Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a28cf2a029 | ||
|
|
60cb271e30 | ||
|
|
1874234d2c | ||
|
|
ef99539e69 | ||
|
|
39bd02722d | ||
|
|
f0b627b900 | ||
|
|
95ae00e01f | ||
|
|
83027b3a0f | ||
|
|
27eba3bd63 | ||
|
|
ba64050eab | ||
|
|
199ffd1095 | ||
|
|
88b9fcb8bb | ||
|
|
2f5944bc09 | ||
|
|
abdfb1abbf | ||
|
|
2f07248211 | ||
|
|
9ae6fc822b | ||
|
|
5094359c4e | ||
|
|
28801b54bf | ||
|
|
4d69f0e837 | ||
|
|
d91b3f8042 | ||
|
|
03a0ab2fcf | ||
|
|
d860d62972 | ||
|
|
add4893939 | ||
|
|
eb251d6e37 |
1
.gitignore
vendored
@@ -16,3 +16,4 @@ scenes/
|
||||
!scenes/infinity-quest-dynamic-scenario/infinity-quest.json
|
||||
!scenes/infinity-quest/assets/
|
||||
!scenes/infinity-quest/infinity-quest.json
|
||||
tts_voice_samples/*.wav
|
||||
25
Dockerfile.backend
Normal file
@@ -0,0 +1,25 @@
|
||||
# Use an official Python runtime as a parent image
|
||||
FROM python:3.11-slim
|
||||
|
||||
# Set the working directory in the container
|
||||
WORKDIR /app
|
||||
|
||||
# Copy the current directory contents into the container at /app
|
||||
COPY ./src /app/src
|
||||
|
||||
# Copy poetry files
|
||||
COPY pyproject.toml /app/
|
||||
# If there's a poetry lock file, include the following line
|
||||
COPY poetry.lock /app/
|
||||
|
||||
# Install poetry
|
||||
RUN pip install poetry
|
||||
|
||||
# Install dependencies
|
||||
RUN poetry install --no-dev
|
||||
|
||||
# Make port 5050 available to the world outside this container
|
||||
EXPOSE 5050
|
||||
|
||||
# Run backend server
|
||||
CMD ["poetry", "run", "python", "src/talemate/server/run.py", "runserver", "--host", "0.0.0.0", "--port", "5050"]
|
||||
17
Dockerfile.frontend
Normal file
@@ -0,0 +1,17 @@
|
||||
# Use an official node runtime as a parent image
|
||||
FROM node:20
|
||||
|
||||
# Set the working directory in the container
|
||||
WORKDIR /app
|
||||
|
||||
# Copy the frontend directory contents into the container at /app
|
||||
COPY ./talemate_frontend /app
|
||||
|
||||
# Install any needed packages specified in package.json
|
||||
RUN npm install
|
||||
|
||||
# Make port 8080 available to the world outside this container
|
||||
EXPOSE 8080
|
||||
|
||||
# Run frontend server
|
||||
CMD ["npm", "run", "serve"]
|
||||
255
README.md
@@ -1,74 +1,67 @@
|
||||
# Talemate
|
||||
|
||||
Allows you to play roleplay scenarios with large language models.
|
||||
Roleplay with AI with a focus on strong narration and consistent world and game state tracking.
|
||||
|
||||
|
||||
|||
|
||||
|||
|
||||
|------------------------------------------|------------------------------------------|
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
|
||||
> :warning: **It does not run any large language models itself but relies on existing APIs. Currently supports OpenAI, text-generation-webui and LMStudio. 0.18.0 also adds support for generic OpenAI api implementations, but generation quality on that will vary.**
|
||||
Supported APIs:
|
||||
- [OpenAI](https://platform.openai.com/overview)
|
||||
- [Anthropic](https://www.anthropic.com/)
|
||||
- [mistral.ai](https://mistral.ai/)
|
||||
- [Cohere](https://www.cohere.com/)
|
||||
- [Groq](https://www.groq.com/)
|
||||
- [Google Gemini](https://console.cloud.google.com/)
|
||||
|
||||
This means you need to either have:
|
||||
- an [OpenAI](https://platform.openai.com/overview) api key
|
||||
- setup local (or remote via runpod) LLM inference via:
|
||||
- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui)
|
||||
- [LMStudio](https://lmstudio.ai/)
|
||||
- Any other OpenAI api implementation that implements the v1/completions endpoint
|
||||
- tested llamacpp with the `api_like_OAI.py` wrapper
|
||||
- let me know if you have tested any other implementations and they failed / worked or landed somewhere in between
|
||||
Supported self-hosted APIs:
|
||||
- [KoboldCpp](https://koboldai.org/cpp) ([Local](https://koboldai.org/cpp), [Runpod](https://koboldai.org/runpodcpp), [VastAI](https://koboldai.org/vastcpp), also includes image gen support)
|
||||
- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (local or with runpod support)
|
||||
- [LMStudio](https://lmstudio.ai/)
|
||||
|
||||
## Current features
|
||||
Generic OpenAI api implementations (tested and confirmed working):
|
||||
- [DeepInfra](https://deepinfra.com/)
|
||||
- [llamacpp](https://github.com/ggerganov/llama.cpp) with the `api_like_OAI.py` wrapper
|
||||
- let me know if you have tested any other implementations and they failed / worked or landed somewhere in between
|
||||
|
||||
- responive modern ui
|
||||
- agents
|
||||
- conversation: handles character dialogue
|
||||
- narration: handles narrative exposition
|
||||
- summarization: handles summarization to compress context while maintain history
|
||||
- director: can be used to direct the story / characters
|
||||
- editor: improves AI responses (very hit and miss at the moment)
|
||||
- world state: generates world snapshot and handles passage of time (objects and characters)
|
||||
- creator: character / scenario creator
|
||||
- tts: text to speech via elevenlabs, coqui studio, coqui local
|
||||
- multi-client support (agents can be connected to separate APIs)
|
||||
- long term memory
|
||||
- chromadb integration
|
||||
- passage of time
|
||||
- narrative world state
|
||||
- Automatically keep track and reinforce selected character and world truths / states.
|
||||
- narrative tools
|
||||
- creative tools
|
||||
- manage multiple NPCs
|
||||
- AI backed character creation with template support (jinja2)
|
||||
- AI backed scenario creation
|
||||
- context managegement
|
||||
- Manage character details and attributes
|
||||
- Manage world information / past events
|
||||
- Pin important information to the context (Manually or conditionally through AI)
|
||||
- runpod integration
|
||||
- overridable templates for all prompts. (jinja2)
|
||||
## Core Features
|
||||
|
||||
## Planned features
|
||||
- Multiple AI agents for dialogue, narration, summarization, direction, editing, world state management, character/scenario creation, text-to-speech, and visual generation
|
||||
- Support for multiple AI clients and APIs
|
||||
- Long-term memory using ChromaDB and passage of time tracking
|
||||
- Narrative world state management to reinforce character and world truths
|
||||
- Creative tools for managing NPCs, AI-assisted character, and scenario creation with template support
|
||||
- Context management for character details, world information, past events, and pinned information
|
||||
- Integration with Runpod
|
||||
- Customizable templates for all prompts using Jinja2
|
||||
- Modern, responsive UI
|
||||
|
||||
Kinda making it up as i go along, but i want to lean more into gameplay through AI, keeping track of gamestates, moving away from simply roleplaying towards a more game-ified experience.
|
||||
# Instructions
|
||||
|
||||
In no particular order:
|
||||
Please read the documents in the `docs` folder for more advanced configuration and usage.
|
||||
|
||||
|
||||
- Extension support
|
||||
- modular agents and clients
|
||||
- Improved world state
|
||||
- Dynamic player choice generation
|
||||
- Better creative tools
|
||||
- node based scenario / character creation
|
||||
- Improved and consistent long term memory and accurate current state of the world
|
||||
- Improved director agent
|
||||
- Right now this doesn't really work well on anything but GPT-4 (and even there it's debatable). It tends to steer the story in a way that introduces pacing issues. It needs a model that is creative but also reasons really well i think.
|
||||
- Gameplay loop governed by AI
|
||||
- objectives
|
||||
- quests
|
||||
- win / lose conditions
|
||||
- stable-diffusion client for in place visual generation
|
||||
- [Quickstart](#quickstart)
|
||||
- [Installation](#installation)
|
||||
- [Windows](#windows)
|
||||
- [Linux](#linux)
|
||||
- [Docker](#docker)
|
||||
- [Connecting to an LLM](#connecting-to-an-llm)
|
||||
- [OpenAI / mistral.ai / Anthropic](#openai--mistralai--anthropic)
|
||||
- [Text-generation-webui / LMStudio](#text-generation-webui--lmstudio)
|
||||
- [Specifying the correct prompt template](#specifying-the-correct-prompt-template)
|
||||
- [Recommended Models](#recommended-models)
|
||||
- [DeepInfra via OpenAI Compatible client](#deepinfra-via-openai-compatible-client)
|
||||
- [Google Gemini](#google-gemini)
|
||||
- [Google Cloud Setup](#google-cloud-setup)
|
||||
- [Ready to go](#ready-to-go)
|
||||
- [Load the introductory scenario "Infinity Quest"](#load-the-introductory-scenario-infinity-quest)
|
||||
- [Loading character cards](#loading-character-cards)
|
||||
- [Text-to-Speech (TTS)](docs/tts.md)
|
||||
- [Visual Generation](docs/visual.md)
|
||||
- [ChromaDB (long term memory) configuration](docs/chromadb.md)
|
||||
- [Runpod Integration](docs/runpod.md)
|
||||
- [Prompt template overrides](docs/templates.md)
|
||||
|
||||
# Quickstart
|
||||
|
||||
@@ -93,39 +86,44 @@ There is also a [troubleshooting guide](docs/troubleshoot.md) that might help.
|
||||
|
||||
`nodejs v19 or v20` :warning: `v21` not supported yet.
|
||||
|
||||
1. `git clone git@github.com:vegu-ai/talemate`
|
||||
1. `git clone https://github.com/vegu-ai/talemate.git`
|
||||
1. `cd talemate`
|
||||
1. `source install.sh`
|
||||
1. Start the backend: `python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050`.
|
||||
1. Open a new terminal, navigate to the `talemate_frontend` directory, and start the frontend server by running `npm run serve`.
|
||||
|
||||
## Connecting to an LLM
|
||||
### Docker
|
||||
|
||||
1. `git clone https://github.com/vegu-ai/talemate.git`
|
||||
1. `cd talemate`
|
||||
1. `docker-compose up`
|
||||
1. Navigate your browser to http://localhost:8080
|
||||
|
||||
:warning: When connecting local APIs running on the hostmachine (e.g. text-generation-webui), you need to use `host.docker.internal` as the hostname.
|
||||
|
||||
#### To shut down the Docker container
|
||||
|
||||
Just closing the terminal window will not stop the Docker container. You need to run `docker-compose down` to stop the container.
|
||||
|
||||
#### How to install Docker
|
||||
|
||||
1. Download and install Docker Desktop from the [official Docker website](https://www.docker.com/products/docker-desktop).
|
||||
|
||||
# Connecting to an LLM
|
||||
|
||||
On the right hand side click the "Add Client" button. If there is no button, you may need to toggle the client options by clicking this button:
|
||||
|
||||

|
||||
|
||||
### Text-generation-webui
|
||||

|
||||
|
||||
> :warning: As of version 0.13.0 the legacy text-generator-webui API `--extension api` is no longer supported, please use their new `--extension openai` api implementation instead.
|
||||
## OpenAI / mistral.ai / Anthropic
|
||||
|
||||
In the modal if you're planning to connect to text-generation-webui, you can likely leave everything as is and just click Save.
|
||||
|
||||

|
||||
|
||||
|
||||
#### Recommended Models
|
||||
|
||||
Any of the top models in any of the size classes here should work well (i wouldn't recommend going lower than 7B):
|
||||
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/18yp9u4/llm_comparisontest_api_edition_gpt4_vs_gemini_vs/
|
||||
|
||||
|
||||
### OpenAI
|
||||
The setup is the same for all three, the example below is for OpenAI.
|
||||
|
||||
If you want to add an OpenAI client, just change the client type and select the apropriate model.
|
||||
|
||||

|
||||

|
||||
|
||||
If you are setting this up for the first time, you should now see the client, but it will have a red dot next to it, stating that it requires an API key.
|
||||
|
||||
@@ -133,17 +131,106 @@ If you are setting this up for the first time, you should now see the client, bu
|
||||
|
||||
Click the `SET API KEY` button. This will open a modal where you can enter your API key.
|
||||
|
||||

|
||||

|
||||
|
||||
Click `Save` and after a moment the client should have a green dot next to it, indicating that it is ready to go.
|
||||
|
||||

|
||||
|
||||
## Text-generation-webui / LMStudio
|
||||
|
||||
> :warning: As of version 0.13.0 the legacy text-generator-webui API `--extension api` is no longer supported, please use their new `--extension openai` api implementation instead.
|
||||
|
||||
In the modal if you're planning to connect to text-generation-webui, you can likely leave everything as is and just click Save.
|
||||
|
||||

|
||||
|
||||
### Specifying the correct prompt template
|
||||
|
||||
For good results it is **vital** that the correct prompt template is specified for whichever model you have loaded.
|
||||
|
||||
Talemate does come with a set of pre-defined templates for some popular models, but going forward, due to the sheet number of models released every day, understanding and specifying the correct prompt template is something you should familiarize yourself with.
|
||||
|
||||
If the text-gen-webui client shows a yellow triangle next to it, it means that the prompt template is not set, and it is currently using the default `VICUNA` style prompt template.
|
||||
|
||||

|
||||
|
||||
Click the two cogwheels to the right of the triangle to open the client settings.
|
||||
|
||||

|
||||
|
||||
You can first try by clicking the `DETERMINE VIA HUGGINGFACE` button, depending on the model's README file, it may be able to determine the correct prompt template for you. (basically the readme needs to contain an example of the template)
|
||||
|
||||
If that doesn't work, you can manually select the prompt template from the dropdown.
|
||||
|
||||
In the case for `bartowski_Nous-Hermes-2-Mistral-7B-DPO-exl2_8_0` that is `ChatML` - select it from the dropdown and click `Save`.
|
||||
|
||||

|
||||
|
||||
### Recommended Models
|
||||
|
||||
As of 2024.05.06 my personal regular drivers (the ones i test with) are:
|
||||
|
||||
- meta-llama_Meta-Llama-3-8B-Instruct
|
||||
- brucethemoose_Yi-34B-200K-RPMerge
|
||||
- rAIfle_Verdict-8x7B
|
||||
- meta-llama_Meta-Llama-3-70B-Instruct
|
||||
|
||||
That said, any of the top models in any of the size classes here should work well (i wouldn't recommend going lower than 7B):
|
||||
|
||||
[https://oobabooga.github.io/benchmark.html](https://oobabooga.github.io/benchmark.html)
|
||||
|
||||
## DeepInfra via OpenAI Compatible client
|
||||
|
||||
You can use the OpenAI compatible client to connect to [DeepInfra](https://deepinfra.com/).
|
||||
|
||||

|
||||
|
||||
```
|
||||
API URL: https://api.deepinfra.com/v1/openai
|
||||
```
|
||||
|
||||
Models on DeepInfra that work well with Talemate:
|
||||
|
||||
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://deepinfra.com/mistralai/Mixtral-8x7B-Instruct-v0.1) (max context 32k, 8k recommended)
|
||||
- [cognitivecomputations/dolphin-2.6-mixtral-8x7b](https://deepinfra.com/cognitivecomputations/dolphin-2.6-mixtral-8x7b) (max context 32k, 8k recommended)
|
||||
- [lizpreciatior/lzlv_70b_fp16_hf](https://deepinfra.com/lizpreciatior/lzlv_70b_fp16_hf) (max context 4k)
|
||||
|
||||
## Google Gemini
|
||||
|
||||
### Google Cloud Setup
|
||||
|
||||
Unlike the other clients the setup for Google Gemini is a bit more involved as you will need to set up a google cloud project and credentials for it.
|
||||
|
||||
Please follow their [instructions for setup](https://cloud.google.com/vertex-ai/docs/start/client-libraries) - which includes setting up a project, enabling the Vertex AI API, creating a service account, and downloading the credentials.
|
||||
|
||||
Once you have downloaded the credentials, copy the JSON file into the talemate directory. You can rename it to something that's easier to remember, like `my-credentials.json`.
|
||||
|
||||
### Add the client
|
||||
|
||||

|
||||
|
||||
The `Disable Safety Settings` option will turn off the google reponse validation for what they consider harmful content. Use at your own risk.
|
||||
|
||||
### Conmplete the google cloud setup in talemate
|
||||
|
||||

|
||||
|
||||
Click the `SETUP GOOGLE API CREDENTIALS` button that will appear on the client.
|
||||
|
||||
The google cloud setup modal will appear, fill in the path to the credentials file and select a location that is close to you.
|
||||
|
||||

|
||||
|
||||
Click save and after a moment the client should have a green dot next to it, indicating that it is ready to go.
|
||||
|
||||

|
||||
|
||||
## Ready to go
|
||||
|
||||
You will know you are good to go when the client and all the agents have a green dot next to them.
|
||||
|
||||

|
||||

|
||||
|
||||
## Load the introductory scenario "Infinity Quest"
|
||||
|
||||
@@ -164,13 +251,3 @@ Expand the "Load" menu in the top left corner and either click on "Upload a char
|
||||
Once a character is uploaded, talemate may actually take a moment because it needs to convert it to a talemate format and will also run additional LLM prompts to generate character attributes and world state.
|
||||
|
||||
Make sure you save the scene after the character is loaded as it can then be loaded as normal talemate scenario in the future.
|
||||
|
||||
## Further documentation
|
||||
|
||||
Please read the documents in the `docs` folder for more advanced configuration and usage.
|
||||
|
||||
- [Prompt template overrides](docs/templates.md)
|
||||
- [Text-to-Speech (TTS)](docs/tts.md)
|
||||
- [ChromaDB (long term memory)](docs/chromadb.md)
|
||||
- [Runpod Integration](docs/runpod.md)
|
||||
- Creative mode
|
||||
|
||||
@@ -48,6 +48,7 @@ game:
|
||||
# embeddings: instructor
|
||||
# instructor_device: cuda
|
||||
# instructor_model: hkunlp/instructor-xl
|
||||
# openai_model: text-embedding-3-small
|
||||
|
||||
## Remote LLMs
|
||||
|
||||
|
||||
27
docker-compose.yml
Normal file
@@ -0,0 +1,27 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
talemate-backend:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.backend
|
||||
ports:
|
||||
- "5050:5050"
|
||||
volumes:
|
||||
# can uncomment for dev purposes
|
||||
#- ./src/talemate:/app/src/talemate
|
||||
- ./config.yaml:/app/config.yaml
|
||||
- ./scenes:/app/scenes
|
||||
- ./templates:/app/templates
|
||||
- ./chroma:/app/chroma
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
|
||||
talemate-frontend:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.frontend
|
||||
ports:
|
||||
- "8080:8080"
|
||||
volumes:
|
||||
- ./talemate_frontend:/app
|
||||
@@ -56,6 +56,7 @@ Then add the following to `config.yaml` for chromadb:
|
||||
```yaml
|
||||
chromadb:
|
||||
embeddings: openai
|
||||
openai_model: text-embedding-3-small
|
||||
```
|
||||
|
||||
**Note**: As with everything openai, using this isn't free. It's way cheaper than their text completion though. ALSO - if you send super explicit content they may flag / ban your key, so keep that in mind (i hear they usually send warnings first though), and always monitor your usage on their dashboard.
|
||||
**Note**: As with everything openai, using this isn't free. It's way cheaper than their text completion though. Always monitor your usage on their dashboard.
|
||||
|
||||
48
docs/dev/agents/example/test/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from talemate.agents.base import Agent, AgentAction
|
||||
from talemate.agents.registry import register
|
||||
from talemate.events import GameLoopEvent
|
||||
import talemate.emit.async_signals
|
||||
from talemate.emit import emit
|
||||
|
||||
@register()
|
||||
class TestAgent(Agent):
|
||||
|
||||
agent_type = "test"
|
||||
verbose_name = "Test"
|
||||
|
||||
def __init__(self, client):
|
||||
self.client = client
|
||||
self.is_enabled = True
|
||||
self.actions = {
|
||||
"test": AgentAction(
|
||||
enabled=True,
|
||||
label="Test",
|
||||
description="Test",
|
||||
),
|
||||
}
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
|
||||
@property
|
||||
def has_toggle(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return True
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
|
||||
|
||||
async def on_game_loop(self, emission: GameLoopEvent):
|
||||
"""
|
||||
Called on the beginning of every game loop
|
||||
"""
|
||||
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
emit("status", status="info", message="Annoying you with a test message every game loop.")
|
||||
130
docs/dev/client/example/runpod_vllm/__init__.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
An attempt to write a client against the runpod serverless vllm worker.
|
||||
|
||||
This is close to functional, but since runpod serverless gpu availability is currently terrible, i have
|
||||
been unable to properly test it.
|
||||
|
||||
Putting it here for now since i think it makes a decent example of how to write a client against a new service.
|
||||
"""
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
import runpod
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from talemate.client.base import ClientBase, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.emit import emit
|
||||
from talemate.config import Client as BaseClientConfig
|
||||
|
||||
log = structlog.get_logger("talemate.client.runpod_vllm")
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 4096
|
||||
model: str = ""
|
||||
runpod_id: str = ""
|
||||
|
||||
class ClientConfig(BaseClientConfig):
|
||||
runpod_id: str = ""
|
||||
|
||||
@register()
|
||||
class RunPodVLLMClient(ClientBase):
|
||||
client_type = "runpod_vllm"
|
||||
conversation_retries = 5
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
title: str = "Runpod VLLM"
|
||||
name_prefix: str = "Runpod VLLM"
|
||||
enable_api_auth: bool = True
|
||||
manual_model: bool = True
|
||||
defaults: Defaults = Defaults()
|
||||
extra_fields: dict[str, ExtraField] = {
|
||||
"runpod_id": ExtraField(
|
||||
name="runpod_id",
|
||||
type="text",
|
||||
label="Runpod ID",
|
||||
required=True,
|
||||
description="The Runpod ID to connect to.",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def __init__(self, model=None, runpod_id=None, **kwargs):
|
||||
self.model_name = model
|
||||
self.runpod_id = runpod_id
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return False
|
||||
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
log.debug("set_client", kwargs=kwargs, runpod_id=self.runpod_id)
|
||||
self.runpod_id = kwargs.get("runpod_id", self.runpod_id)
|
||||
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
|
||||
keys = list(parameters.keys())
|
||||
|
||||
valid_keys = ["temperature", "top_p", "max_tokens"]
|
||||
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
async def get_model_name(self):
|
||||
return self.model_name
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
prompt = prompt.strip()
|
||||
|
||||
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
|
||||
|
||||
try:
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
endpoint = runpod.AsyncioEndpoint(self.runpod_id, session)
|
||||
|
||||
run_request = await endpoint.run({
|
||||
"input": {
|
||||
"prompt": prompt,
|
||||
}
|
||||
#"parameters": parameters
|
||||
})
|
||||
|
||||
while (await run_request.status()) not in ["COMPLETED", "FAILED", "CANCELLED"]:
|
||||
status = await run_request.status()
|
||||
log.debug("generate", status=status)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
status = await run_request.status()
|
||||
|
||||
log.debug("generate", status=status)
|
||||
|
||||
response = await run_request.output()
|
||||
|
||||
log.debug("generate", response=response)
|
||||
|
||||
return response["choices"][0]["tokens"][0]
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit(
|
||||
"status", message="Error during generation (check logs)", status="error"
|
||||
)
|
||||
return ""
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
if "runpod_id" in kwargs:
|
||||
self.api_auth = kwargs["runpod_id"]
|
||||
log.warning("reconfigure", kwargs=kwargs)
|
||||
self.set_client(**kwargs)
|
||||
67
docs/dev/client/example/test/__init__.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import pydantic
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from talemate.client.base import ClientBase
|
||||
from talemate.client.registry import register
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
api_url: str = "http://localhost:1234"
|
||||
max_token_length: int = 4096
|
||||
|
||||
@register()
|
||||
class TestClient(ClientBase):
|
||||
client_type = "test"
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "test"
|
||||
title: str = "Test"
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
|
||||
"""
|
||||
Talemate adds a bunch of parameters to the prompt, but not all of them are valid for all clients.
|
||||
|
||||
This method is called before the prompt is sent to the client, and it allows the client to remove
|
||||
any parameters that it doesn't support.
|
||||
"""
|
||||
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
|
||||
keys = list(parameters.keys())
|
||||
|
||||
valid_keys = ["temperature", "top_p"]
|
||||
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
async def get_model_name(self):
|
||||
|
||||
"""
|
||||
This should return the name of the model that is being used.
|
||||
"""
|
||||
|
||||
return "Mock test model"
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
|
||||
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name, messages=[human_message], **parameters
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
self.log.error("generate error", e=e)
|
||||
return ""
|
||||
BIN
docs/img/0.19.0/Screenshot_15.png
Normal file
|
After Width: | Height: | Size: 418 KiB |
BIN
docs/img/0.19.0/Screenshot_16.png
Normal file
|
After Width: | Height: | Size: 413 KiB |
BIN
docs/img/0.19.0/Screenshot_17.png
Normal file
|
After Width: | Height: | Size: 364 KiB |
BIN
docs/img/0.20.0/comfyui-base-workflow.png
Normal file
|
After Width: | Height: | Size: 128 KiB |
BIN
docs/img/0.20.0/visual-config-a1111.png
Normal file
|
After Width: | Height: | Size: 32 KiB |
BIN
docs/img/0.20.0/visual-config-comfyui.png
Normal file
|
After Width: | Height: | Size: 34 KiB |
BIN
docs/img/0.20.0/visual-config-openai.png
Normal file
|
After Width: | Height: | Size: 30 KiB |
BIN
docs/img/0.20.0/visual-queue.png
Normal file
|
After Width: | Height: | Size: 933 KiB |
BIN
docs/img/0.20.0/visualize-scene-tools.png
Normal file
|
After Width: | Height: | Size: 13 KiB |
BIN
docs/img/0.20.0/visualizer-busy.png
Normal file
|
After Width: | Height: | Size: 3.5 KiB |
BIN
docs/img/0.20.0/visualizer-ready.png
Normal file
|
After Width: | Height: | Size: 2.9 KiB |
BIN
docs/img/0.20.0/visualze-new-images.png
Normal file
|
After Width: | Height: | Size: 1.8 KiB |
BIN
docs/img/0.21.0/deepinfra-setup.png
Normal file
|
After Width: | Height: | Size: 56 KiB |
BIN
docs/img/0.21.0/no-clients.png
Normal file
|
After Width: | Height: | Size: 7.1 KiB |
BIN
docs/img/0.21.0/openai-add-api-key.png
Normal file
|
After Width: | Height: | Size: 35 KiB |
BIN
docs/img/0.21.0/openai-setup.png
Normal file
|
After Width: | Height: | Size: 20 KiB |
BIN
docs/img/0.21.0/prompt-template-default.png
Normal file
|
After Width: | Height: | Size: 17 KiB |
BIN
docs/img/0.21.0/ready-to-go.png
Normal file
|
After Width: | Height: | Size: 43 KiB |
BIN
docs/img/0.21.0/select-prompt-template.png
Normal file
|
After Width: | Height: | Size: 47 KiB |
BIN
docs/img/0.21.0/selected-prompt-template.png
Normal file
|
After Width: | Height: | Size: 49 KiB |
BIN
docs/img/0.21.0/text-gen-webui-setup.png
Normal file
|
After Width: | Height: | Size: 26 KiB |
BIN
docs/img/0.25.0/google-add-client.png
Normal file
|
After Width: | Height: | Size: 25 KiB |
BIN
docs/img/0.25.0/google-cloud-setup.png
Normal file
|
After Width: | Height: | Size: 59 KiB |
BIN
docs/img/0.25.0/google-ready.png
Normal file
|
After Width: | Height: | Size: 5.3 KiB |
BIN
docs/img/0.25.0/google-setup-incomplete.png
Normal file
|
After Width: | Height: | Size: 7.5 KiB |
15
docs/tts.md
@@ -17,21 +17,6 @@ elevenlabs:
|
||||
api_key: <YOUR_ELEVENLABS_API_KEY>
|
||||
```
|
||||
|
||||
## Configuring Coqui TTS
|
||||
|
||||
To use Coqui TTS with Talemate, follow these steps:
|
||||
|
||||
1. Visit [Coqui](https://app.coqui.ai) and sign up for an account.
|
||||
2. Go to the [account page](https://app.coqui.ai/account) and scroll to the bottom to find your API key.
|
||||
3. In the `config.yaml` file, under the `coqui` section, set the `api_key` field with your Coqui API key.
|
||||
|
||||
Example configuration snippet:
|
||||
|
||||
```yaml
|
||||
coqui:
|
||||
api_key: <YOUR_COQUI_API_KEY>
|
||||
```
|
||||
|
||||
## Configuring Local TTS API
|
||||
|
||||
For running a local TTS API, Talemate requires specific dependencies to be installed.
|
||||
|
||||
117
docs/visual.md
Normal file
@@ -0,0 +1,117 @@
|
||||
# Visual Agent
|
||||
|
||||
The visual agent currently allows for some bare bones visual generation using various stable-diffusion APIs. This is early development and experimental.
|
||||
|
||||
Its important to note that the visualization agent actually specifies two clients. One is the backend for the visual generation, and the other is the text generation client to use for prompt generation.
|
||||
|
||||
The client for prompt generation can be assigned to the agent as you would for any other agent. The client for visual generation is assigned in the Visualizer config.
|
||||
|
||||
## Index
|
||||
|
||||
- [OpenAI](#openai)
|
||||
- [AUTOMATIC1111](#automatic1111)
|
||||
- [ComfyUI](#comfyui)
|
||||
- [How to use](#how-to-use)
|
||||
|
||||
## OpenAI
|
||||
|
||||
Most straightforward to use, as it runs on the OpenAI API. You will need to have an API key and set it in the application config.
|
||||
|
||||

|
||||
|
||||
Then open the Visualizer config by clicking the agent's name in the agent list and choose `OpenAI` as the backend.
|
||||
|
||||

|
||||
|
||||
Note: `Client` here refers to the text-generation client to use for prompt generation. While `Backend` refers to the visual generation backend. You are **NOT** required to use the OpenAI client for prompt generation even if you are using the OpenAI backend for image generation.
|
||||
|
||||
## AUTOMATIC1111
|
||||
|
||||
This requires you to setup a local instance of the AUTOMATIC1111 API. Follow the instructions from their [GitHub](https://github.com/AUTOMATIC1111/stable-diffusion-webui) to get it running.
|
||||
|
||||
Once you have it running, you will want to adjust the `webui-user.bat` in the AUTOMATIC1111 directory to include the following command arguments:
|
||||
|
||||
```bat
|
||||
set COMMANDLINE_ARGS=--api --listen --port 7861
|
||||
```
|
||||
|
||||
Then run the `webui-user.bat` to start the API.
|
||||
|
||||
Once your AUTOAMTIC1111 API is running (check with your browser) you can set the Visualizer config to use the `AUTOMATIC1111` backend
|
||||
|
||||

|
||||
|
||||
#### Extra Configuration
|
||||
|
||||
- `api url`: the url of the API, usually `http://localhost:7861`
|
||||
- `steps`: render steps
|
||||
- `model type`: sdxl or sd1.5 - this will dictate the resolution of the image generation and actually matters for the quality so make sure this is set to the correct model type for the model you are using.
|
||||
|
||||
## ComfyUI
|
||||
|
||||
This requires you to setup a local instance of the ComfyUI API. Follow the instructions from their [GitHub](https://github.com/comfyanonymous/ComfyUI) to get it running.
|
||||
|
||||
Once you're setup, copy their `start.bat` file to a new `start-listen.bat` file and change the contents to.
|
||||
|
||||
```bat
|
||||
call venv\Scripts\activate
|
||||
call python main.py --port 8188 --listen 0.0.0.0
|
||||
```
|
||||
|
||||
Then run the `start-listen.bat` to start the API.
|
||||
|
||||
Once your ComfyUI API is running (check with your browser) you can set the Visualizer config to use the `ComfyUI` backend.
|
||||
|
||||

|
||||
|
||||
### Extra Configuration
|
||||
|
||||
- `api url`: the url of the API, usually `http://localhost:8188`
|
||||
- `workflow`: the workflow file to use. This is a comfyui api workflow file that needs to exist in `./templates/comfyui-workflows` inside the talemate directory. Talemate provides two very barebones workflows with `default-sdxl.json` and `default-sd15.json`. You can create your own workflows and place them in this directory to use them. :warning: The workflow file must be generated using the API Workflow export not the UI export. Please refer to their documentation for more information.
|
||||
- `checkpoint`: the model to use - this will load a list of all available models in your comfyui instance. Select which one you want to use for the image generation.
|
||||
|
||||
### Custom Workflows
|
||||
|
||||
When creating custom workflows for ideal compatibility with Talemate, ensure the following.
|
||||
|
||||
- A `CheckpointLoaderSimple` node named `Talemate Load Checkpoint`
|
||||
- A `EmptyLatentImage` node name `Talemate Resolution`
|
||||
- A `ClipTextEncode` node named `Talemate Positive Prompt`
|
||||
- A `ClipTextEncode` node named `Talemate Negative Prompt`
|
||||
- A `SaveImage` node at the end of the workflow.
|
||||
|
||||

|
||||
|
||||
## How to use
|
||||
|
||||
Once you're done setting up the visualizer agent should have a green dot next to it and display both the selected image generation backend and the selected prompt generation client.
|
||||
|
||||

|
||||
|
||||
Your hotbar should then also enable the visualization menu for you to use (once you have a scene loaded).
|
||||
|
||||

|
||||
|
||||
Right now you can generate a portrait for any NPC in the scene or a background image for the scene itself.
|
||||
|
||||
Image generation by default will actually happen in the background, allowing you to continue using Talemate while the image is being generated.
|
||||
|
||||
You can tell if an image is being generated by the blueish spinner next to the visualization agent.
|
||||
|
||||

|
||||
|
||||
Once the image is generated, it will be avaible for you to view via the visual queue button on top of the screen.
|
||||
|
||||

|
||||
|
||||
Click it to open the visual queue and view the generated images.
|
||||
|
||||

|
||||
|
||||
### Character Portrait
|
||||
|
||||
For character potraits you can chose whether or not to replace the main portrait for the character (the one being displated in the left sidebar when a talemate scene is active).
|
||||
|
||||
### Background Image
|
||||
|
||||
Right now there is nothing to do with the background image, other than to view it in the visual queue. More functionality will be added in the future.
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
# create a virtual environment
|
||||
python -m venv talemate_env
|
||||
python3 -m venv talemate_env
|
||||
|
||||
# activate the virtual environment
|
||||
source talemate_env/bin/activate
|
||||
|
||||
3931
poetry.lock
generated
@@ -4,13 +4,13 @@ build-backend = "poetry.masonry.api"
|
||||
|
||||
[tool.poetry]
|
||||
name = "talemate"
|
||||
version = "0.18.1"
|
||||
version = "0.25.2"
|
||||
description = "AI-backed roleplay and narrative tools"
|
||||
authors = ["FinalWombat"]
|
||||
license = "GNU Affero General Public License v3.0"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
python = ">=3.10,<3.12"
|
||||
astroid = "^2.8"
|
||||
jedi = "^0.18"
|
||||
black = "*"
|
||||
@@ -18,9 +18,13 @@ rope = "^0.22"
|
||||
isort = "^5.10"
|
||||
jinja2 = "^3.0"
|
||||
openai = ">=1"
|
||||
mistralai = ">=0.1.8"
|
||||
cohere = ">=5.2.2"
|
||||
anthropic = ">=0.19.1"
|
||||
groq = ">=0.5.0"
|
||||
requests = "^2.26"
|
||||
colorama = ">=0.4.6"
|
||||
Pillow = "^9.5"
|
||||
Pillow = ">=9.5"
|
||||
httpx = "<1"
|
||||
piexif = "^1.1"
|
||||
typing-inspect = "0.8.0"
|
||||
@@ -33,17 +37,20 @@ python-dotenv = "^1.0.0"
|
||||
websockets = "^11.0.3"
|
||||
structlog = "^23.1.0"
|
||||
runpod = "^1.2.0"
|
||||
google-cloud-aiplatform = ">=1.50.0"
|
||||
nest_asyncio = "^1.5.7"
|
||||
isodate = ">=0.6.1"
|
||||
thefuzz = ">=0.20.0"
|
||||
tiktoken = ">=0.5.1"
|
||||
nltk = ">=3.8.1"
|
||||
huggingface-hub = ">=0.20.2"
|
||||
RestrictedPython = ">7.1"
|
||||
|
||||
# ChromaDB
|
||||
chromadb = ">=0.4.17,<1"
|
||||
InstructorEmbedding = "^1.0.1"
|
||||
torch = ">=2.1.0"
|
||||
torchaudio = ">=2.3.0"
|
||||
sentence-transformers="^2.2.2"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
|
||||
|
After Width: | Height: | Size: 1.6 MiB |
@@ -98,6 +98,7 @@
|
||||
}
|
||||
],
|
||||
"immutable_save": true,
|
||||
"experimental": true,
|
||||
"goal": null,
|
||||
"goals": [],
|
||||
"context": "an epic sci-fi adventure aimed at an adult audience.",
|
||||
@@ -109,10 +110,10 @@
|
||||
"variables": {}
|
||||
},
|
||||
"assets": {
|
||||
"cover_image": "52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df",
|
||||
"cover_image": "e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404",
|
||||
"assets": {
|
||||
"52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df": {
|
||||
"id": "52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df",
|
||||
"e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404": {
|
||||
"id": "e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404",
|
||||
"file_type": "png",
|
||||
"media_type": "image/png"
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
{%- set _ = emit_system("warning", "This is a dynamic scenario generation experiment for Infinity Quest. It will likely require a strong LLM to generate something coherent. GPT-4 or 34B+ if local. Temper your expectations.") -%}
|
||||
|
||||
{#- emit status update to the UX -#}
|
||||
{%- set _ = emit_status("busy", "Generating scenario ... [1/3]") -%}
|
||||
{%- set _ = emit_status("busy", "Generating scenario ... [1/3]", as_scene_message=True) -%}
|
||||
|
||||
{#- thematic tags will be used to randomize generation -#}
|
||||
{%- set tags = thematic_generator.generate("color", "state_of_matter", "scifi_trope") -%}
|
||||
@@ -17,17 +17,17 @@
|
||||
|
||||
|
||||
{#- generate introductory text -#}
|
||||
{%- set _ = emit_status("busy", "Generating scenario ... [2/3]") -%}
|
||||
{%- set _ = emit_status("busy", "Generating scenario ... [2/3]", as_scene_message=True) -%}
|
||||
{%- set tmpl__scenario_intro = render_template('generate-scenario-intro', premise=instr__premise) %}
|
||||
{%- set instr__intro = "*"+render_and_request(tmpl__scenario_intro)+"*" -%}
|
||||
|
||||
{#- generate win conditions -#}
|
||||
{%- set _ = emit_status("busy", "Generating scenario ... [3/3]") -%}
|
||||
{%- set _ = emit_status("busy", "Generating scenario ... [3/3]", as_scene_message=True) -%}
|
||||
{%- set tmpl__win_conditions = render_template('generate-win-conditions', premise=instr__premise) %}
|
||||
{%- set instr__win_conditions = render_and_request(tmpl__win_conditions) -%}
|
||||
|
||||
{#- emit status update to the UX -#}
|
||||
{%- set status = emit_status("info", "Scenario ready.") -%}
|
||||
{%- set status = emit_status("success", "Scenario ready.", as_scene_message=True) -%}
|
||||
|
||||
{# set gamestate variables #}
|
||||
{%- set _ = game_state.set_var("instr.premise", instr__premise, commit=True) -%}
|
||||
|
||||
|
After Width: | Height: | Size: 1.7 MiB |
535
scenes/simulation-suite/game.py
Normal file
@@ -0,0 +1,535 @@
|
||||
|
||||
def game(TM):
|
||||
|
||||
MSG_PROCESSED_INSTRUCTIONS = "Simulation suite processed instructions"
|
||||
|
||||
MSG_HELP = "Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating \"Computer,\" followed by an instruction. For example ... \"Computer, i want to experience being on a derelict spaceship.\""
|
||||
|
||||
PROMPT_NARRATE_ROUND = "Narrate the simulation and reveal some new details to the player in one paragraph. YOU MUST NOT ADDRESS THE COMPUTER OR THE SIMULATION."
|
||||
|
||||
PROMPT_STARTUP = "Narrate the computer asking the user to state the nature of their desired simulation in a synthetic and soft sounding voice."
|
||||
|
||||
CTX_PIN_UNAWARE = "Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION."
|
||||
|
||||
AUTO_NARRATE_INTERVAL = 10
|
||||
|
||||
def parse_sim_call_arguments(call:str) -> str:
|
||||
"""
|
||||
Returns the value between the parentheses of a simulation call
|
||||
|
||||
Example:
|
||||
|
||||
call = 'change_environment("a house")'
|
||||
|
||||
parse_sim_call_arguments(call) -> "a house"
|
||||
"""
|
||||
|
||||
try:
|
||||
return call.split("(", 1)[1].split(")")[0]
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
class SimulationSuite:
|
||||
|
||||
def __init__(self):
|
||||
# do we update the world state at the end of the round
|
||||
self.update_world_state = False
|
||||
|
||||
self.simulation_reset = False
|
||||
|
||||
self.added_npcs = []
|
||||
|
||||
TM.log.debug("SIMULATION SUITE INIT...")
|
||||
|
||||
self.player_character = TM.scene.get_player_character()
|
||||
self.player_message = TM.scene.last_player_message()
|
||||
self.last_processed_call = TM.game_state.get_var("instr.lastprocessed_call", -1)
|
||||
self.player_message_is_instruction = (
|
||||
self.player_message and
|
||||
self.player_message.raw.lower().startswith("computer") and
|
||||
not self.player_message.hidden and
|
||||
not self.last_processed_call > self.player_message.id
|
||||
)
|
||||
|
||||
|
||||
def run(self):
|
||||
if not TM.game_state.has_var("instr.simulation_stopped"):
|
||||
self.simulation()
|
||||
|
||||
self.finalize_round()
|
||||
|
||||
def simulation(self):
|
||||
|
||||
if not TM.game_state.has_var("instr.simulation_started"):
|
||||
self.startup()
|
||||
else:
|
||||
self.simulation_calls()
|
||||
|
||||
if self.update_world_state:
|
||||
self.run_update_world_state(force=True)
|
||||
|
||||
|
||||
def startup(self):
|
||||
TM.emit_status("busy", "Simulation suite powering up.", as_scene_message=True)
|
||||
TM.game_state.set_var("instr.simulation_started", "yes", commit=False)
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=PROMPT_STARTUP,
|
||||
emit_message=False
|
||||
)
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="passthrough",
|
||||
narration=MSG_HELP
|
||||
)
|
||||
TM.agents.world_state.manager(
|
||||
action_name="save_world_entry",
|
||||
entry_id="sim.quarantined",
|
||||
text=CTX_PIN_UNAWARE,
|
||||
meta={},
|
||||
pin=True
|
||||
)
|
||||
TM.game_state.set_var("instr.simulation_started", "yes", commit=False)
|
||||
TM.emit_status("success", "Simulation suite ready", as_scene_message=True)
|
||||
self.update_world_state = True
|
||||
|
||||
def simulation_calls(self):
|
||||
"""
|
||||
Calls the simulation suite main prompt to determine the appropriate
|
||||
simulation calls
|
||||
"""
|
||||
|
||||
if not self.player_message_is_instruction or self.player_message.id == self.last_processed_call:
|
||||
return
|
||||
|
||||
# First instruction?
|
||||
if not TM.game_state.has_var("instr.has_issued_instructions"):
|
||||
|
||||
# determine the context of the simulation
|
||||
|
||||
context_context = TM.agents.creator.determine_content_context_for_description(
|
||||
description=self.player_message.raw,
|
||||
)
|
||||
TM.scene.set_content_context(context_context)
|
||||
|
||||
|
||||
calls = TM.client.render_and_request(
|
||||
"computer",
|
||||
dedupe_enabled=False,
|
||||
player_instruction=self.player_message.raw,
|
||||
scene=TM.scene,
|
||||
)
|
||||
|
||||
self.calls = calls = calls.split("\n")
|
||||
|
||||
calls = self.prepare_calls(calls)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE CALLS", callse=calls)
|
||||
|
||||
# calls that are processed
|
||||
processed = []
|
||||
|
||||
for call in calls:
|
||||
processed_call = self.process_call(call)
|
||||
if processed_call:
|
||||
processed.append(processed_call)
|
||||
|
||||
|
||||
if processed:
|
||||
TM.log.debug("SIMULATION SUITE CALLS", calls=processed)
|
||||
TM.game_state.set_var("instr.has_issued_instructions", "yes", commit=False)
|
||||
|
||||
TM.emit_status("busy", "Simulation suite altering environment.", as_scene_message=True)
|
||||
compiled = "\n".join(processed)
|
||||
if not self.simulation_reset and compiled:
|
||||
narration = TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=f"The computer calls the following functions:\n\n```\n{compiled}\n```\n\nand the simulation adjusts the environment according to the user's wishes.\n\nWrite the narrative that describes the changes to the player in the context of the simulation starting up. YOU MUST NOT REFERENCE THE COMPUTER OR THE SIMULATION.",
|
||||
emit_message=True
|
||||
)
|
||||
|
||||
# on the first narration we update the scene description and remove any mention of the computer
|
||||
# or the simulation from the previous narration
|
||||
is_initial_narration = TM.game_state.get_var("instr.intro_narration", False)
|
||||
if not is_initial_narration:
|
||||
TM.scene.set_description(str(narration))
|
||||
TM.scene.set_intro(str(narration))
|
||||
TM.log.debug("SIMULATION SUITE: initial narration", intro=str(narration))
|
||||
TM.scene.pop_history(typ="narrator", all=True, reverse=True)
|
||||
TM.scene.pop_history(typ="director", all=True, reverse=True)
|
||||
TM.game_state.set_var("instr.intro_narration", True, commit=False)
|
||||
|
||||
self.update_world_state = True
|
||||
|
||||
self.set_simulation_title(compiled)
|
||||
|
||||
def set_simulation_title(self, compiled_calls):
|
||||
|
||||
"""
|
||||
Generates a fitting title for the simulation based on the user's instructions
|
||||
"""
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: set simulation title", name=TM.scene.title, compiled_calls=compiled_calls)
|
||||
|
||||
if not compiled_calls:
|
||||
return
|
||||
|
||||
if TM.scene.title != "Simulation Suite":
|
||||
# name already changed, no need to do it again
|
||||
return
|
||||
|
||||
title = TM.agents.creator.contextual_generate_from_args(
|
||||
"scene:simulation title",
|
||||
"Create a fitting title for the simulated scenario that the user has requested. You response MUST be a short but exciting, descriptive title.",
|
||||
length=75
|
||||
)
|
||||
|
||||
title = title.strip('"').strip()
|
||||
|
||||
TM.scene.set_title(title)
|
||||
|
||||
def prepare_calls(self, calls):
|
||||
"""
|
||||
Loops through calls and if a `set_player_name` call and a `set_player_persona` call are both
|
||||
found, ensure that the `set_player_name` call is processed first by moving it in front of the
|
||||
`set_player_persona` call.
|
||||
"""
|
||||
|
||||
set_player_name_call_exists = -1
|
||||
set_player_persona_call_exists = -1
|
||||
|
||||
i = 0
|
||||
for call in calls:
|
||||
if "set_player_name" in call:
|
||||
set_player_name_call_exists = i
|
||||
elif "set_player_persona" in call:
|
||||
set_player_persona_call_exists = i
|
||||
i = i + 1
|
||||
|
||||
if set_player_name_call_exists > -1 and set_player_persona_call_exists > -1:
|
||||
|
||||
if set_player_name_call_exists > set_player_persona_call_exists:
|
||||
calls.insert(set_player_persona_call_exists, calls.pop(set_player_name_call_exists))
|
||||
TM.log.debug("SIMULATION SUITE: prepare calls - moved set_player_persona call", calls=calls)
|
||||
|
||||
return calls
|
||||
|
||||
def process_call(self, call:str) -> str:
|
||||
"""
|
||||
Processes a simulation call
|
||||
|
||||
Simulation alls are pseudo functions that are called by the simulation suite
|
||||
|
||||
We grab the function name by splitting against ( and taking the first element
|
||||
if the SimulationSuite has a method with the name _call_{function_name} then we call it
|
||||
|
||||
if a function name could be found but we do not have a method to call we dont do anything
|
||||
but we still return it as procssed as the AI can still interpret it as something later on
|
||||
"""
|
||||
|
||||
if "(" not in call:
|
||||
return None
|
||||
|
||||
function_name = call.split("(")[0]
|
||||
|
||||
if hasattr(self, f"call_{function_name}"):
|
||||
TM.log.debug("SIMULATION SUITE CALL", call=call, function_name=function_name)
|
||||
|
||||
inject = f"The computer executes the function `{call}`"
|
||||
|
||||
return getattr(self, f"call_{function_name}")(call, inject)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def call_set_simulation_goal(self, call:str, inject:str) -> str:
|
||||
"""
|
||||
Set's the simulation goal as a permanent pin
|
||||
"""
|
||||
TM.emit_status("busy", "Simulation suite setting goal.", as_scene_message=True)
|
||||
TM.agents.world_state.manager(
|
||||
action_name="save_world_entry",
|
||||
entry_id="sim.goal",
|
||||
text=self.player_message.raw,
|
||||
meta={},
|
||||
pin=True
|
||||
)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description="The computer sets the goal for the simulation.",
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
def call_change_environment(self, call:str, inject:str) -> str:
|
||||
"""
|
||||
Simulation changes the environment, this is entirely interpreted by the AI
|
||||
and we dont need to do any logic on our end, so we just return the call
|
||||
"""
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description="The computer changes the environment of the simulation."
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def call_answer_question(self, call:str, inject:str) -> str:
|
||||
"""
|
||||
The player asked the simulation a query, we need to process this and have
|
||||
the AI produce an answer
|
||||
"""
|
||||
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=f"The computer calls the following function:\n\n{call}\n\nand answers the player's question.",
|
||||
emit_message=True
|
||||
)
|
||||
|
||||
|
||||
def call_set_player_persona(self, call:str, inject:str) -> str:
|
||||
|
||||
"""
|
||||
The simulation suite is altering the player persona
|
||||
"""
|
||||
|
||||
TM.emit_status("busy", "Simulation suite altering user persona.", as_scene_message=True)
|
||||
character_attributes = TM.agents.world_state.extract_character_sheet(
|
||||
name=self.player_character.name, text=inject, alteration_instructions=self.player_message.raw
|
||||
)
|
||||
self.player_character.update(base_attributes=character_attributes)
|
||||
|
||||
character_description = TM.agents.creator.determine_character_description(character=self.player_character)
|
||||
self.player_character.update(description=character_description)
|
||||
TM.log.debug("SIMULATION SUITE: transform player", attributes=character_attributes, description=character_description)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description="The computer transforms the player persona."
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def call_set_player_name(self, call:str, inject:str) -> str:
|
||||
|
||||
"""
|
||||
The simulation suite is altering the player name
|
||||
"""
|
||||
|
||||
TM.emit_status("busy", "Simulation suite adjusting user identity.", as_scene_message=True)
|
||||
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits.")
|
||||
TM.log.debug("SIMULATION SUITE: player name", character_name=character_name)
|
||||
if character_name != self.player_character.name:
|
||||
self.player_character.rename(character_name)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description=f"The computer changes the player's identity to {character_name}."
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def call_add_ai_character(self, call:str, inject:str) -> str:
|
||||
|
||||
# sometimes the AI will call this function an pass an inanimate object as the parameter
|
||||
# we need to determine if this is the case and just ignore it
|
||||
is_inanimate = TM.client.query_text_eval(f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)", call)
|
||||
|
||||
if is_inanimate:
|
||||
TM.log.debug("SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped", call=call)
|
||||
return
|
||||
|
||||
# sometimes the AI will ask if the function adds a group of characters, we need to
|
||||
# determine if this is the case
|
||||
adds_group = TM.client.query_text_eval(f"does the function `{call}` add MULTIPLE ai characters?", call)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: add npc", adds_group=adds_group)
|
||||
|
||||
TM.emit_status("busy", "Simulation suite adding character.", as_scene_message=True)
|
||||
|
||||
if not adds_group:
|
||||
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the character to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.")
|
||||
else:
|
||||
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.", group=True)
|
||||
|
||||
# sometimes add_ai_character and change_ai_character are called in the same instruction targeting
|
||||
# the same character, if this happens we need to combine into a single add_ai_character call
|
||||
|
||||
has_change_ai_character_call = TM.client.query_text_eval(f"Are there any calls to `change_ai_character` in the instruction for {character_name}?", "\n".join(self.calls))
|
||||
|
||||
if has_change_ai_character_call:
|
||||
|
||||
combined_arg = TM.client.render_and_request(
|
||||
"combine-add-and-alter-ai-character",
|
||||
dedupe_enabled=False,
|
||||
calls="\n".join(self.calls),
|
||||
character_name=character_name,
|
||||
scene=TM.scene,
|
||||
).replace("COMBINED ARGUMENT:", "").strip()
|
||||
|
||||
call = f"add_ai_character({combined_arg})"
|
||||
inject = f"The computer executes the function `{call}`"
|
||||
|
||||
|
||||
TM.emit_status("busy", f"Simulation suite adding character: {character_name}", as_scene_message=True)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: add npc", name=character_name)
|
||||
|
||||
npc = TM.agents.director.persist_character(name=character_name, content=self.player_message.raw+f"\n\n{inject}", determine_name=False)
|
||||
|
||||
self.added_npcs.append(npc.name)
|
||||
|
||||
TM.agents.world_state.manager(
|
||||
action_name="add_detail_reinforcement",
|
||||
character_name=npc.name,
|
||||
question="Goal",
|
||||
instructions=f"Generate a goal for {npc.name}, based on the user's chosen simulation",
|
||||
interval=25,
|
||||
run_immediately=True
|
||||
)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: added npc", npc=npc)
|
||||
|
||||
TM.agents.visual.generate_character_portrait(character_name=npc.name)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description=f"The computer adds {npc.name} to the simulation."
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
def call_remove_ai_character(self, call:str, inject:str) -> str:
|
||||
TM.emit_status("busy", "Simulation suite removing character.", as_scene_message=True)
|
||||
|
||||
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the character being removed?", allowed_names=TM.scene.npc_character_names())
|
||||
|
||||
npc = TM.scene.get_character(character_name)
|
||||
|
||||
if npc:
|
||||
TM.log.debug("SIMULATION SUITE: remove npc", npc=npc.name)
|
||||
TM.agents.world_state.manager(action_name="deactivate_character", character_name=npc.name)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description=f"The computer removes {npc.name} from the simulation."
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
def call_change_ai_character(self, call:str, inject:str) -> str:
|
||||
TM.emit_status("busy", "Simulation suite altering character.", as_scene_message=True)
|
||||
|
||||
character_name = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the character receiving the changes (before the change)?", allowed_names=TM.scene.npc_character_names())
|
||||
|
||||
if character_name in self.added_npcs:
|
||||
# we dont want to change the character if it was just added
|
||||
return
|
||||
|
||||
character_name_after = TM.agents.creator.determine_character_name(character_name=f"{inject} - what is the name of the character receiving the changes (after the changes)?")
|
||||
|
||||
npc = TM.scene.get_character(character_name)
|
||||
|
||||
if npc:
|
||||
TM.emit_status("busy", f"Changing {character_name} -> {character_name_after}", as_scene_message=True)
|
||||
|
||||
TM.log.debug("SIMULATION SUITE: transform npc", npc=npc)
|
||||
|
||||
character_attributes = TM.agents.world_state.extract_character_sheet(name=npc.name, alteration_instructions=self.player_message.raw)
|
||||
|
||||
npc.update(base_attributes=character_attributes)
|
||||
character_description = TM.agents.creator.determine_character_description(character=npc)
|
||||
|
||||
npc.update(description=character_description)
|
||||
TM.log.debug("SIMULATION SUITE: transform npc", attributes=character_attributes, description=character_description)
|
||||
|
||||
if character_name_after != character_name:
|
||||
npc.rename(character_name_after)
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description=f"The computer transforms {npc.name}."
|
||||
)
|
||||
|
||||
return call
|
||||
|
||||
def call_end_simulation(self, call:str, inject:str) -> str:
|
||||
|
||||
explicit_command = TM.client.query_text_eval("has the player explicitly asked to end the simulation?", self.player_message.raw)
|
||||
|
||||
if explicit_command:
|
||||
TM.emit_status("busy", "Simulation suite ending current simulation.", as_scene_message=True)
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=f"Narrate the computer ending the simulation, dissolving the environment and all artificial characters, erasing all memory of it and finally returning the player to the inactive simulation suite. List of artificial characters: {', '.join(TM.scene.npc_character_names())}. The player is also transformed back to their normal, non-descript persona as the form of {self.player_character.name} ceases to exist.",
|
||||
emit_message=True
|
||||
)
|
||||
TM.scene.restore()
|
||||
|
||||
self.simulation_reset = True
|
||||
|
||||
TM.game_state.unset_var("instr.has_issued_instructions")
|
||||
TM.game_state.unset_var("instr.lastprocessed_call")
|
||||
TM.game_state.unset_var("instr.simulation_started")
|
||||
|
||||
TM.agents.director.log_action(
|
||||
action=parse_sim_call_arguments(call),
|
||||
action_description="The computer ends the simulation."
|
||||
)
|
||||
|
||||
def finalize_round(self):
|
||||
|
||||
# track rounds
|
||||
rounds = TM.game_state.get_var("instr.rounds", 0)
|
||||
|
||||
# increase rounds
|
||||
TM.game_state.set_var("instr.rounds", rounds + 1, commit=False)
|
||||
|
||||
has_issued_instructions = TM.game_state.has_var("instr.has_issued_instructions")
|
||||
|
||||
if self.update_world_state:
|
||||
self.run_update_world_state()
|
||||
|
||||
if self.player_message_is_instruction:
|
||||
self.player_message.hide()
|
||||
TM.game_state.set_var("instr.lastprocessed_call", self.player_message.id, commit=False)
|
||||
TM.emit_status("success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True)
|
||||
|
||||
elif self.player_message and not has_issued_instructions:
|
||||
# simulation started, player message is NOT an instruction, and player has not given
|
||||
# any instructions
|
||||
self.guide_player()
|
||||
|
||||
elif self.player_message and not TM.scene.npc_character_names():
|
||||
# simulation started, player message is NOT an instruction, but there are no npcs to interact with
|
||||
self.narrate_round()
|
||||
|
||||
elif rounds % AUTO_NARRATE_INTERVAL == 0 and rounds and TM.scene.npc_character_names() and has_issued_instructions:
|
||||
# every N rounds, narrate the round
|
||||
self.narrate_round()
|
||||
|
||||
def guide_player(self):
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="paraphrase",
|
||||
narration=MSG_HELP,
|
||||
emit_message=True
|
||||
)
|
||||
|
||||
def narrate_round(self):
|
||||
TM.agents.narrator.action_to_narration(
|
||||
action_name="progress_story",
|
||||
narrative_direction=PROMPT_NARRATE_ROUND,
|
||||
emit_message=True
|
||||
)
|
||||
|
||||
def run_update_world_state(self, force=False):
|
||||
TM.log.debug("SIMULATION SUITE: update world state", force=force)
|
||||
TM.emit_status("busy", "Simulation suite updating world state.", as_scene_message=True)
|
||||
TM.agents.world_state.update_world_state(force=force)
|
||||
TM.emit_status("success", "Simulation suite updated world state.", as_scene_message=True)
|
||||
|
||||
SimulationSuite().run()
|
||||
53
scenes/simulation-suite/simulation-suite.json
Normal file
@@ -0,0 +1,53 @@
|
||||
{
|
||||
"name": "Simulation Suite",
|
||||
"title": "Simulation Suite",
|
||||
"environment": "scene",
|
||||
"immutable_save": true,
|
||||
"restore_from": "simulation-suite.json",
|
||||
"experimental": true,
|
||||
"help": "Address the computer by starting your statements with 'Computer, ' followed by an instruction.\n\nExamples:\n'Computer, i would like to experience an adventure on a derelict space station'\n'Computer, add a horrific alien creature that is chasing me.'",
|
||||
"description": "",
|
||||
"intro": "*You have entered the simulation suite. No simulation is currently active and you are in a non-descript space with paneled walls surrounding you. The control panel next to you is pulsating with a green light, indicating readiness to receive a prompt to start the simulation.*",
|
||||
"archived_history": [],
|
||||
"history": [],
|
||||
"ts": "PT1S",
|
||||
"characters": [
|
||||
{
|
||||
"name": "You",
|
||||
"gender": "unknown",
|
||||
"color": "cornflowerblue",
|
||||
"base_attributes": {},
|
||||
"is_player": true
|
||||
}
|
||||
],
|
||||
"context": "a simulated experience",
|
||||
"game_state": {
|
||||
"ops":{
|
||||
"run_on_start": true,
|
||||
"always_direct": true
|
||||
},
|
||||
"variables": {}
|
||||
},
|
||||
"world_state": {
|
||||
"character_name_mappings": {
|
||||
"You": [
|
||||
"user",
|
||||
"player",
|
||||
"player character",
|
||||
"user character",
|
||||
"the user",
|
||||
"the player"
|
||||
]
|
||||
}
|
||||
},
|
||||
"assets": {
|
||||
"cover_image": "4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103",
|
||||
"assets": {
|
||||
"4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103": {
|
||||
"id": "4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103",
|
||||
"file_type": "png",
|
||||
"media_type": "image/png"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
<|SECTION:EXAMPLES|>
|
||||
combine the arguments of the function calls `add_ai_character` and `change_ai_character` for "Sarah" into a single text string argument to be passed to a single `add_ai_character` function call.
|
||||
```
|
||||
set_simulation_goal("player experiences a rollercoaster ride")
|
||||
change_environment("theme park, riding a rollercoaster")
|
||||
set_player_persona("young female experiencing rollercoaster ride")
|
||||
set_player_name("Susanne")
|
||||
add_ai_character("a female friend of player named Sarah")
|
||||
change_ai_character("Sarah hates rollercoasters")
|
||||
```
|
||||
COMBINED ARGUMENT: "a female friend of player named Sarah, Sarah hates rollercoasters"
|
||||
|
||||
TASK: combine the arguments of the function calls `add_ai_character` and `change_ai_character` for "George" into a single text string argument to be passed to a single `add_ai_character` function call.
|
||||
```
|
||||
change_environment("building on fire")
|
||||
change_ai_character("George is injured")
|
||||
add_ai_character("a firefighter named Stephen")
|
||||
change_ai_character("Stephen is afraid of heights")
|
||||
```
|
||||
COMBINED ARGUMENT: "a firefighter named Stephen, Stephen is afraid of heights"
|
||||
|
||||
<|CLOSE_SECTION|>
|
||||
<|SECTION:TASK|>
|
||||
TASK: combine the arguments of the function calls `add_ai_character` and `change_ai_character` for "{{ character_name }}" into a single text string argument to be passed to a single `add_ai_character` function call.
|
||||
```
|
||||
{{ calls }}
|
||||
```
|
||||
{{ set_prepared_response("COMBINED ARGUMENT:") }}
|
||||
132
scenes/simulation-suite/templates/computer.jinja2
Normal file
@@ -0,0 +1,132 @@
|
||||
<|SECTION:CONTEXT|>
|
||||
{% set scene_history=scene.context_history(budget=1024) %}
|
||||
{% for scene_context in scene_history -%}
|
||||
{{ loop.index }}. {{ scene_context }}
|
||||
{% endfor %}
|
||||
<|CLOSE_SECTION|>
|
||||
<|SECTION:FUNCTIONS|>
|
||||
The player has instructed the computer to alter the current simulation.
|
||||
|
||||
You have access to the following functions, you can call as many as you want to fulfill the player's requests.
|
||||
|
||||
You must at least call one of the following functions:
|
||||
|
||||
- change_environment
|
||||
- add_ai_character
|
||||
- change_ai_character
|
||||
- remove_ai_character
|
||||
- set_player_persona
|
||||
- set_player_name
|
||||
- end_simulation
|
||||
- answer_question
|
||||
- set_simulation_goal
|
||||
|
||||
`add_ai_character` and `change_ai_character` are exclusive if they are targeting the same character.
|
||||
|
||||
Set the player persona at the beginning of a new simulation or if the player requests a change.
|
||||
|
||||
Only end the simulation if the player requests it explicitly.
|
||||
|
||||
Your response MUST ONLY CONTAIN the new simulation stack.
|
||||
<|CLOSE_SECTION|>
|
||||
<|SECTION:EXAMPLES|>
|
||||
Request: Computer, I want to be on a mountain top
|
||||
```simulation-stack
|
||||
change_environment("mountain top")
|
||||
set_player_persona("mountain climber")
|
||||
set_player_name("Hank")
|
||||
```
|
||||
|
||||
Request: Computer, I want to be more muscular and taller
|
||||
```simulation-stack
|
||||
set_player_persona("make player more muscular and taller")
|
||||
```
|
||||
|
||||
Request: Computer, the building should be on fire
|
||||
```simulation-stack
|
||||
change_environment("building on fire")
|
||||
```
|
||||
|
||||
Request: Computer, a rocket hits the building and George is now injured
|
||||
```simulation-stack
|
||||
change_environment("building on fire")
|
||||
change_ai_character("George is injured")
|
||||
```
|
||||
|
||||
Request: Computer, I want to experience a rollercoaster ride with a friend
|
||||
```simulation-stack
|
||||
set_simulation_goal("player experiences a rollercoaster ride")
|
||||
change_environment("theme park, riding a rollercoaster")
|
||||
set_player_persona("young female experiencing rollercoaster ride")
|
||||
set_player_name("Susanne")
|
||||
add_ai_character("a female friend of player named Sarah")
|
||||
```
|
||||
|
||||
Request: Computer, I want to experience the international space station, to experience the overview effect
|
||||
```simulation-stack
|
||||
set_simulation_goal("player experiences the overview effect")
|
||||
change_environment("international space station")
|
||||
set_player_persona("astronaut experiencing first trip to ISS")
|
||||
set_player_name("George")
|
||||
add_ai_character("astronaut named Henry")
|
||||
```
|
||||
|
||||
Request: Computer, remove the goblin and add an elven woman instead
|
||||
```simulation-stack
|
||||
remove_ai_character("goblin")
|
||||
add_ai_character("elven woman named Elune")
|
||||
```
|
||||
|
||||
Request: Computer, change the skiing instructor to be older.
|
||||
```simulation-stack
|
||||
change_ai_character("make skiing instructor older")
|
||||
```
|
||||
|
||||
Request: Computer, change my grandma to my grandpa
|
||||
```simulation-stack
|
||||
remove_ai_character("grandma")
|
||||
add_ai_character("grandpa named Steven")
|
||||
```
|
||||
|
||||
Request: Computer, remove the skiing instructor and add my friend instead.
|
||||
```simulation-stack
|
||||
remove_ai_character("skiing instructor")
|
||||
add_ai_character("player's friend named Tara")
|
||||
```
|
||||
|
||||
Request: Computer, replace the skiing instructor with my friend.
|
||||
```simulation-stack
|
||||
remove_ai_character("skiing instructor")
|
||||
add_ai_character("player's friend named Lisa")
|
||||
```
|
||||
|
||||
Request: Computer, I want to end the simulation
|
||||
```simulation-stack
|
||||
end_simulation("simulation ended")
|
||||
```
|
||||
|
||||
Request: Computer, shut down the simulation
|
||||
```simulation-stack
|
||||
end_simulation("simulation ended")
|
||||
```
|
||||
|
||||
Request: Computer, what do you know about the game of thrones?
|
||||
```simulation-stack
|
||||
answer_question("what do you know about the game of thrones?")
|
||||
```
|
||||
|
||||
Request: Computer, i want to be a wizard in a dark goblin infested dungeon in a fantasy world, looking for secret treasure and fighting goblins.
|
||||
```simulation-stack
|
||||
set_simulation_goal("player wants to find secret treasure and fight creatures")
|
||||
change_environment("dark dungeon in a fantasy world")
|
||||
set_player_persona("powerful wizard")
|
||||
set_player_name("Lanadel")
|
||||
add_ai_character("a goblin named Gobbo")
|
||||
```
|
||||
|
||||
<|CLOSE_SECTION|>
|
||||
<|SECTION:TASK|>
|
||||
Respond with the simulation stack for the following request:
|
||||
|
||||
Request: {{ player_instruction }}
|
||||
{{ bot_token }}```simulation-stack
|
||||
@@ -2,4 +2,4 @@ from .agents import Agent
|
||||
from .client import TextGeneratorWebuiClient
|
||||
from .tale_mate import *
|
||||
|
||||
VERSION = "0.18.1"
|
||||
VERSION = "0.25.2"
|
||||
|
||||
@@ -8,4 +8,5 @@ from .narrator import NarratorAgent
|
||||
from .registry import AGENT_CLASSES, get_agent_class, register
|
||||
from .summarize import SummarizeAgent
|
||||
from .tts import TTSAgent
|
||||
from .visual import VisualAgent
|
||||
from .world_state import WorldStateAgent
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import dataclasses
|
||||
import re
|
||||
from abc import ABC
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
|
||||
import pydantic
|
||||
@@ -19,6 +20,11 @@ from talemate.events import GameLoopStartEvent
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"AgentAction",
|
||||
"AgentActionConditional",
|
||||
"AgentActionConfig",
|
||||
"AgentDetail",
|
||||
"AgentEmission",
|
||||
"set_processing",
|
||||
]
|
||||
|
||||
@@ -42,11 +48,24 @@ class AgentActionConfig(pydantic.BaseModel):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class AgentActionConditional(pydantic.BaseModel):
|
||||
attribute: str
|
||||
value: Union[int, float, str, bool, None] = None
|
||||
|
||||
|
||||
class AgentAction(pydantic.BaseModel):
|
||||
enabled: bool = True
|
||||
label: str
|
||||
description: str = ""
|
||||
config: Union[dict[str, AgentActionConfig], None] = None
|
||||
condition: Union[AgentActionConditional, None] = None
|
||||
|
||||
|
||||
class AgentDetail(pydantic.BaseModel):
|
||||
value: Union[str, None] = None
|
||||
description: Union[str, None] = None
|
||||
icon: Union[str, None] = None
|
||||
color: str = "grey"
|
||||
|
||||
|
||||
def set_processing(fn):
|
||||
@@ -58,6 +77,7 @@ def set_processing(fn):
|
||||
the function fails.
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
with ActiveAgent(self, fn):
|
||||
try:
|
||||
@@ -71,8 +91,7 @@ def set_processing(fn):
|
||||
# some concurrency error?
|
||||
log.error("error emitting agent status", exc=exc)
|
||||
|
||||
wrapper.__name__ = fn.__name__
|
||||
|
||||
wrapper.exposed = True
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -86,6 +105,9 @@ class Agent(ABC):
|
||||
set_processing = set_processing
|
||||
requires_llm_client = True
|
||||
auto_break_repetition = False
|
||||
websocket_handler = None
|
||||
essential = True
|
||||
ready_check_error = None
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
@@ -110,13 +132,20 @@ class Agent(ABC):
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
if self.ready:
|
||||
if not self.enabled:
|
||||
return "disabled"
|
||||
return "idle" if getattr(self, "processing", 0) == 0 else "busy"
|
||||
else:
|
||||
if not self.enabled:
|
||||
return "disabled"
|
||||
|
||||
if not self.ready:
|
||||
return "uninitialized"
|
||||
|
||||
if getattr(self, "processing", 0) > 0:
|
||||
return "busy"
|
||||
|
||||
if getattr(self, "processing_bg", 0) > 0:
|
||||
return "busy_bg"
|
||||
|
||||
return "idle"
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
# by default, agents are enabled, an agent class that
|
||||
@@ -160,7 +189,48 @@ class Agent(ABC):
|
||||
|
||||
return config_options
|
||||
|
||||
def apply_config(self, *args, **kwargs):
|
||||
@property
|
||||
def meta(self):
|
||||
return {
|
||||
"essential": self.essential,
|
||||
}
|
||||
|
||||
@property
|
||||
def sanitized_action_config(self):
|
||||
if not getattr(self, "actions", None):
|
||||
return {}
|
||||
|
||||
return {k: v.model_dump() for k, v in self.actions.items()}
|
||||
|
||||
async def _handle_ready_check(self, fut: asyncio.Future):
|
||||
callback_failure = getattr(self, "on_ready_check_failure", None)
|
||||
if fut.cancelled():
|
||||
if callback_failure:
|
||||
await callback_failure()
|
||||
return
|
||||
|
||||
if fut.exception():
|
||||
exc = fut.exception()
|
||||
self.ready_check_error = exc
|
||||
log.error("agent ready check error", agent=self.agent_type, exc=exc)
|
||||
if callback_failure:
|
||||
await callback_failure(exc)
|
||||
return
|
||||
|
||||
callback = getattr(self, "on_ready_check_success", None)
|
||||
if callback:
|
||||
await callback()
|
||||
|
||||
async def ready_check(self, task: asyncio.Task = None):
|
||||
self.ready_check_error = None
|
||||
if task:
|
||||
task.add_done_callback(
|
||||
lambda fut: asyncio.create_task(self._handle_ready_check(fut))
|
||||
)
|
||||
return
|
||||
return True
|
||||
|
||||
async def apply_config(self, *args, **kwargs):
|
||||
if self.has_toggle and "enabled" in kwargs:
|
||||
self.is_enabled = kwargs.get("enabled", False)
|
||||
|
||||
@@ -228,27 +298,55 @@ class Agent(ABC):
|
||||
if getattr(self, "processing", None) is None:
|
||||
self.processing = 0
|
||||
|
||||
if not processing:
|
||||
if processing is False:
|
||||
self.processing -= 1
|
||||
self.processing = max(0, self.processing)
|
||||
else:
|
||||
elif processing is True:
|
||||
self.processing += 1
|
||||
|
||||
status = "busy" if self.processing > 0 else "idle"
|
||||
if not self.enabled:
|
||||
status = "disabled"
|
||||
|
||||
emit(
|
||||
"agent_status",
|
||||
message=self.verbose_name or "",
|
||||
id=self.agent_type,
|
||||
status=status,
|
||||
status=self.status,
|
||||
details=self.agent_details,
|
||||
meta=self.meta,
|
||||
data=self.config_options(agent=self),
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def _handle_background_processing(self, fut: asyncio.Future):
|
||||
try:
|
||||
if fut.cancelled():
|
||||
return
|
||||
|
||||
if fut.exception():
|
||||
log.error(
|
||||
"background processing error",
|
||||
agent=self.agent_type,
|
||||
exc=fut.exception(),
|
||||
)
|
||||
await self.emit_status()
|
||||
return
|
||||
|
||||
log.info("background processing done", agent=self.agent_type)
|
||||
finally:
|
||||
self.processing_bg -= 1
|
||||
await self.emit_status()
|
||||
|
||||
async def set_background_processing(self, task: asyncio.Task):
|
||||
log.info("set_background_processing", agent=self.agent_type)
|
||||
if not hasattr(self, "processing_bg"):
|
||||
self.processing_bg = 0
|
||||
|
||||
self.processing_bg += 1
|
||||
|
||||
await self.emit_status()
|
||||
task.add_done_callback(
|
||||
lambda fut: asyncio.create_task(self._handle_background_processing(fut))
|
||||
)
|
||||
|
||||
def connect(self, scene):
|
||||
self.scene = scene
|
||||
talemate.emit.async_signals.get("game_loop_start").connect(
|
||||
|
||||
@@ -13,6 +13,7 @@ active_agent = contextvars.ContextVar("active_agent", default=None)
|
||||
class ActiveAgentContext(pydantic.BaseModel):
|
||||
agent: object
|
||||
fn: Callable
|
||||
agent_stack: list = pydantic.Field(default_factory=list)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@@ -21,12 +22,23 @@ class ActiveAgentContext(pydantic.BaseModel):
|
||||
def action(self):
|
||||
return self.fn.__name__
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.agent.verbose_name}.{self.action}"
|
||||
|
||||
|
||||
class ActiveAgent:
|
||||
def __init__(self, agent, fn):
|
||||
self.agent = ActiveAgentContext(agent=agent, fn=fn)
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
previous_agent = active_agent.get()
|
||||
|
||||
if previous_agent:
|
||||
self.agent.agent_stack = previous_agent.agent_stack + [str(self.agent)]
|
||||
else:
|
||||
self.agent.agent_stack = [str(self.agent)]
|
||||
|
||||
self.token = active_agent.set(self.agent)
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
|
||||
@@ -22,7 +22,14 @@ from talemate.events import GameLoopEvent
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import CharacterMessage, DirectorMessage
|
||||
|
||||
from .base import Agent, AgentAction, AgentActionConfig, AgentEmission, set_processing
|
||||
from .base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
AgentEmission,
|
||||
set_processing,
|
||||
)
|
||||
from .registry import register
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -78,14 +85,23 @@ class ConversationAgent(Agent):
|
||||
self.actions = {
|
||||
"generation_override": AgentAction(
|
||||
enabled=True,
|
||||
label="Generation Override",
|
||||
description="Override generation parameters",
|
||||
label="Generation Settings",
|
||||
config={
|
||||
"format": AgentActionConfig(
|
||||
type="text",
|
||||
label="Format",
|
||||
description="The generation format of the scene context, as seen by the AI.",
|
||||
choices=[
|
||||
{"label": "Screenplay", "value": "movie_script"},
|
||||
{"label": "Chat (legacy)", "value": "chat"},
|
||||
],
|
||||
value="movie_script",
|
||||
),
|
||||
"length": AgentActionConfig(
|
||||
type="number",
|
||||
label="Generation Length (tokens)",
|
||||
description="Maximum number of tokens to generate for a conversation response.",
|
||||
value=96,
|
||||
value=128,
|
||||
min=32,
|
||||
max=512,
|
||||
step=32,
|
||||
@@ -166,6 +182,42 @@ class ConversationAgent(Agent):
|
||||
),
|
||||
}
|
||||
|
||||
@property
|
||||
def conversation_format(self):
|
||||
if self.actions["generation_override"].enabled:
|
||||
return self.actions["generation_override"].config["format"].value
|
||||
return "movie_script"
|
||||
|
||||
@property
|
||||
def conversation_format_label(self):
|
||||
value = self.conversation_format
|
||||
|
||||
choices = self.actions["generation_override"].config["format"].choices
|
||||
|
||||
for choice in choices:
|
||||
if choice["value"] == value:
|
||||
return choice["label"]
|
||||
|
||||
return value
|
||||
|
||||
@property
|
||||
def agent_details(self) -> dict:
|
||||
|
||||
details = {
|
||||
"client": AgentDetail(
|
||||
icon="mdi-network-outline",
|
||||
value=self.client.name if self.client else None,
|
||||
description="The client to use for prompt generation",
|
||||
).model_dump(),
|
||||
"format": AgentDetail(
|
||||
icon="mdi-format-float-none",
|
||||
value=self.conversation_format_label,
|
||||
description="Generation format of the scene context, as seen by the AI",
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
return details
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
|
||||
@@ -299,7 +351,7 @@ class ConversationAgent(Agent):
|
||||
|
||||
# AI will attempt to figure out who should talk next
|
||||
next_actor = await self.select_talking_actor(character_names)
|
||||
next_actor = next_actor.strip().strip('"').strip(".")
|
||||
next_actor = next_actor.split("\n")[0].strip().strip('"').strip(".")
|
||||
|
||||
for character_name in scene.character_names:
|
||||
if (
|
||||
@@ -425,8 +477,9 @@ class ConversationAgent(Agent):
|
||||
self.actions["generation_override"].config["instructions"].value
|
||||
)
|
||||
|
||||
conversation_format = self.conversation_format
|
||||
prompt = Prompt.get(
|
||||
"conversation.dialogue",
|
||||
f"conversation.dialogue-{conversation_format}",
|
||||
vars={
|
||||
"scene": scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
@@ -440,6 +493,7 @@ class ConversationAgent(Agent):
|
||||
"partial_message": char_message,
|
||||
"director_message": director_message,
|
||||
"extra_instructions": extra_instructions,
|
||||
"decensor": self.client.decensor_enabled,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -521,11 +575,16 @@ class ConversationAgent(Agent):
|
||||
if "#" in result:
|
||||
result = result.split("#")[0]
|
||||
|
||||
if "(Internal" in result:
|
||||
result = result.split("(Internal")[0]
|
||||
|
||||
result = result.replace(" :", ":")
|
||||
result = result.replace("[", "*").replace("]", "*")
|
||||
result = result.replace("(", "*").replace(")", "*")
|
||||
result = result.replace("**", "*")
|
||||
|
||||
result = util.handle_endofline_special_delimiter(result)
|
||||
|
||||
return result
|
||||
|
||||
def set_generation_overrides(self):
|
||||
@@ -605,14 +664,27 @@ class ConversationAgent(Agent):
|
||||
|
||||
result = result.replace(" :", ":")
|
||||
|
||||
total_result = total_result.split("#")[0]
|
||||
total_result = total_result.split("#")[0].strip()
|
||||
|
||||
total_result = util.handle_endofline_special_delimiter(total_result)
|
||||
|
||||
log.info("conversation agent", total_result=total_result)
|
||||
|
||||
if total_result.startswith(":\n") or total_result.startswith(": "):
|
||||
total_result = total_result[2:]
|
||||
|
||||
# movie script format
|
||||
# {uppercase character name}
|
||||
# {dialogue}
|
||||
total_result = total_result.replace(f"{character.name.upper()}\n", f"")
|
||||
|
||||
# chat format
|
||||
# {character name}: {dialogue}
|
||||
total_result = total_result.replace(f"{character.name}:", "")
|
||||
|
||||
# Removes partial sentence at the end
|
||||
total_result = util.clean_dialogue(total_result, main_name=character.name)
|
||||
|
||||
# Remove "{character.name}:" - all occurences
|
||||
total_result = total_result.replace(f"{character.name}:", "")
|
||||
|
||||
# Check if total_result starts with character name, if not, prepend it
|
||||
if not total_result.startswith(character.name):
|
||||
total_result = f"{character.name}: {total_result}"
|
||||
@@ -660,4 +732,4 @@ class ConversationAgent(Agent):
|
||||
):
|
||||
if prompt_param.get("extra_stopping_strings") is None:
|
||||
prompt_param["extra_stopping_strings"] = []
|
||||
prompt_param["extra_stopping_strings"] += ["["]
|
||||
prompt_param["extra_stopping_strings"] += ["#"]
|
||||
|
||||
@@ -9,13 +9,13 @@ from talemate.agents.registry import register
|
||||
from talemate.emit import emit
|
||||
from talemate.prompts import Prompt
|
||||
|
||||
from .assistant import AssistantMixin
|
||||
from .character import CharacterCreatorMixin
|
||||
from .scenario import ScenarioCreatorMixin
|
||||
|
||||
|
||||
@register()
|
||||
class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
|
||||
|
||||
class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, AssistantMixin, Agent):
|
||||
"""
|
||||
Creates characters and scenarios and other fun stuff!
|
||||
"""
|
||||
|
||||
182
src/talemate/agents/creator/assistant.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Tuple, Union
|
||||
|
||||
import pydantic
|
||||
|
||||
import talemate.util as util
|
||||
from talemate.agents.base import set_processing
|
||||
from talemate.emit import emit
|
||||
from talemate.prompts import Prompt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Character, Scene
|
||||
|
||||
|
||||
class ContentGenerationContext(pydantic.BaseModel):
|
||||
"""
|
||||
A context for generating content.
|
||||
"""
|
||||
|
||||
context: str
|
||||
instructions: str = ""
|
||||
length: int = 100
|
||||
character: Union[str, None] = None
|
||||
original: Union[str, None] = None
|
||||
partial: str = ""
|
||||
|
||||
@property
|
||||
def computed_context(self) -> Tuple[str, str]:
|
||||
typ, context = self.context.split(":", 1)
|
||||
return typ, context
|
||||
|
||||
|
||||
class AssistantMixin:
|
||||
"""
|
||||
Creator mixin that allows quick contextual generation of content.
|
||||
"""
|
||||
|
||||
async def contextual_generate_from_args(
|
||||
self,
|
||||
context: str,
|
||||
instructions: str = "",
|
||||
length: int = 100,
|
||||
character: Union[str, None] = None,
|
||||
original: Union[str, None] = None,
|
||||
partial: str = "",
|
||||
):
|
||||
"""
|
||||
Request content from the assistant.
|
||||
"""
|
||||
|
||||
generation_context = ContentGenerationContext(
|
||||
context=context,
|
||||
instructions=instructions,
|
||||
length=length,
|
||||
character=character,
|
||||
original=original,
|
||||
partial=partial,
|
||||
)
|
||||
|
||||
return await self.contextual_generate(generation_context)
|
||||
|
||||
contextual_generate_from_args.exposed = True
|
||||
|
||||
@set_processing
|
||||
async def contextual_generate(
|
||||
self,
|
||||
generation_context: ContentGenerationContext,
|
||||
):
|
||||
"""
|
||||
Request content from the assistant.
|
||||
"""
|
||||
|
||||
context_typ, context_name = generation_context.computed_context
|
||||
|
||||
if generation_context.length < 100:
|
||||
kind = "create_short"
|
||||
elif generation_context.length < 500:
|
||||
kind = "create_concise"
|
||||
else:
|
||||
kind = "create"
|
||||
|
||||
content = await Prompt.request(
|
||||
f"creator.contextual-generate",
|
||||
self.client,
|
||||
kind,
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"generation_context": generation_context,
|
||||
"context_typ": context_typ,
|
||||
"context_name": context_name,
|
||||
"can_coerce": self.client.can_be_coerced,
|
||||
"character": (
|
||||
self.scene.get_character(generation_context.character)
|
||||
if generation_context.character
|
||||
else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
if not generation_context.partial:
|
||||
content = util.strip_partial_sentences(content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
@set_processing
|
||||
async def autocomplete_dialogue(
|
||||
self,
|
||||
input: str,
|
||||
character: "Character",
|
||||
emit_signal: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Autocomplete dialogue.
|
||||
"""
|
||||
|
||||
response = await Prompt.request(
|
||||
f"creator.autocomplete-dialogue",
|
||||
self.client,
|
||||
"create_short",
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"input": input.strip(),
|
||||
"character": character,
|
||||
"can_coerce": self.client.can_be_coerced,
|
||||
},
|
||||
pad_prepended_response=False,
|
||||
dedupe_enabled=False,
|
||||
)
|
||||
|
||||
response = util.clean_dialogue(response, character.name)[
|
||||
len(character.name + ":") :
|
||||
].strip()
|
||||
|
||||
if response.startswith(input):
|
||||
response = response[len(input) :]
|
||||
|
||||
self.scene.log.debug(
|
||||
"autocomplete_suggestion", suggestion=response, input=input
|
||||
)
|
||||
|
||||
if emit_signal:
|
||||
emit("autocomplete_suggestion", response)
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
async def autocomplete_narrative(
|
||||
self,
|
||||
input: str,
|
||||
emit_signal: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Autocomplete narrative.
|
||||
"""
|
||||
|
||||
response = await Prompt.request(
|
||||
f"creator.autocomplete-narrative",
|
||||
self.client,
|
||||
"create_short",
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"input": input.strip(),
|
||||
"can_coerce": self.client.can_be_coerced,
|
||||
},
|
||||
pad_prepended_response=False,
|
||||
dedupe_enabled=False,
|
||||
)
|
||||
|
||||
if response.startswith(input):
|
||||
response = response[len(input) :]
|
||||
|
||||
self.scene.log.debug(
|
||||
"autocomplete_suggestion", suggestion=response, input=input
|
||||
)
|
||||
|
||||
if emit_signal:
|
||||
emit("autocomplete_suggestion", response)
|
||||
|
||||
return response
|
||||
@@ -193,6 +193,25 @@ class CharacterCreatorMixin:
|
||||
)
|
||||
return content_context.strip()
|
||||
|
||||
@set_processing
|
||||
async def determine_character_dialogue_instructions(
|
||||
self,
|
||||
character: Character,
|
||||
):
|
||||
instructions = await Prompt.request(
|
||||
f"creator.determine-character-dialogue-instructions",
|
||||
self.client,
|
||||
"create_concise",
|
||||
vars={
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
|
||||
r = instructions.strip().split("\n")[0].strip('"').strip()
|
||||
return r
|
||||
|
||||
@set_processing
|
||||
async def determine_character_attributes(
|
||||
self,
|
||||
@@ -208,6 +227,27 @@ class CharacterCreatorMixin:
|
||||
)
|
||||
return attributes
|
||||
|
||||
@set_processing
|
||||
async def determine_character_name(
|
||||
self,
|
||||
character_name: str,
|
||||
allowed_names: list[str] = None,
|
||||
group: bool = False,
|
||||
) -> str:
|
||||
name = await Prompt.request(
|
||||
f"creator.determine-character-name",
|
||||
self.client,
|
||||
"analyze_freeform_short",
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"character_name": character_name,
|
||||
"allowed_names": allowed_names or [],
|
||||
"group": group,
|
||||
},
|
||||
)
|
||||
return name.split('"', 1)[0].strip().strip(".").strip()
|
||||
|
||||
@set_processing
|
||||
async def determine_character_description(
|
||||
self, character: Character, text: str = ""
|
||||
|
||||
@@ -7,7 +7,6 @@ from talemate.prompts import Prompt
|
||||
|
||||
|
||||
class ScenarioCreatorMixin:
|
||||
|
||||
"""
|
||||
Adds scenario creation functionality to the creator agent
|
||||
"""
|
||||
@@ -129,4 +128,19 @@ class ScenarioCreatorMixin:
|
||||
"text": text,
|
||||
},
|
||||
)
|
||||
return description
|
||||
return description.strip()
|
||||
|
||||
@set_processing
|
||||
async def determine_content_context_for_description(
|
||||
self,
|
||||
description: str,
|
||||
):
|
||||
content_context = await Prompt.request(
|
||||
f"creator.determine-content-context",
|
||||
self.client,
|
||||
"create_short",
|
||||
vars={
|
||||
"description": description,
|
||||
},
|
||||
)
|
||||
return content_context.lstrip().split("\n")[0].strip('"').strip()
|
||||
|
||||
34
src/talemate/agents/custom/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import structlog
|
||||
|
||||
log = structlog.get_logger("talemate.agents.custom")
|
||||
|
||||
# import every submodule in this directory
|
||||
#
|
||||
# each directory in this directory is a submodule
|
||||
|
||||
# get the current directory
|
||||
current_directory = os.path.dirname(__file__)
|
||||
|
||||
# get all subdirectories
|
||||
subdirectories = [
|
||||
os.path.join(current_directory, name)
|
||||
for name in os.listdir(current_directory)
|
||||
if os.path.isdir(os.path.join(current_directory, name))
|
||||
]
|
||||
|
||||
# import every submodule
|
||||
|
||||
for subdirectory in subdirectories:
|
||||
# get the name of the submodule
|
||||
submodule_name = os.path.basename(subdirectory)
|
||||
|
||||
if submodule_name.startswith("__"):
|
||||
continue
|
||||
|
||||
log.info("activating custom agent", module=submodule_name)
|
||||
|
||||
# import the submodule
|
||||
importlib.import_module(f".{submodule_name}", __package__)
|
||||
@@ -0,0 +1,5 @@
|
||||
Each agent should be in its own subdirectory.
|
||||
|
||||
The subdirectory itself must be a valid python module.
|
||||
|
||||
Check out docs/dev/agents/example/test for a very simplistic custom agent example.
|
||||
@@ -15,6 +15,7 @@ from talemate.agents.conversation import ConversationAgentEmission
|
||||
from talemate.automated_action import AutomatedAction
|
||||
from talemate.emit import emit, wait_for_input
|
||||
from talemate.events import GameLoopActorIterEvent, GameLoopStartEvent, SceneStateEvent
|
||||
from talemate.game.engine import GameInstructionsMixin
|
||||
from talemate.prompts import Prompt
|
||||
from talemate.scene_message import DirectorMessage, NarratorMessage
|
||||
|
||||
@@ -28,7 +29,7 @@ log = structlog.get_logger("talemate.agent.director")
|
||||
|
||||
|
||||
@register()
|
||||
class DirectorAgent(Agent):
|
||||
class DirectorAgent(GameInstructionsMixin, Agent):
|
||||
agent_type = "director"
|
||||
verbose_name = "Director"
|
||||
|
||||
@@ -64,6 +65,22 @@ class DirectorAgent(Agent):
|
||||
description="If enabled, direction will be given to actors based on their goals.",
|
||||
value=True,
|
||||
),
|
||||
"actor_direction_mode": AgentActionConfig(
|
||||
type="text",
|
||||
label="Actor Direction Mode",
|
||||
description="The mode to use when directing actors",
|
||||
value="direction",
|
||||
choices=[
|
||||
{
|
||||
"label": "Direction",
|
||||
"value": "direction",
|
||||
},
|
||||
{
|
||||
"label": "Inner Monologue",
|
||||
"value": "internal_monologue",
|
||||
},
|
||||
],
|
||||
),
|
||||
},
|
||||
),
|
||||
}
|
||||
@@ -80,6 +97,22 @@ class DirectorAgent(Agent):
|
||||
def experimental(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def direct_enabled(self):
|
||||
return self.actions["direct"].enabled
|
||||
|
||||
@property
|
||||
def direct_actors_enabled(self):
|
||||
return self.actions["direct"].config["direct_actors"].value
|
||||
|
||||
@property
|
||||
def direct_scene_enabled(self):
|
||||
return self.actions["direct"].config["direct_scene"].value
|
||||
|
||||
@property
|
||||
def actor_direction_mode(self):
|
||||
return self.actions["direct"].config["actor_direction_mode"].value
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("agent.conversation.before_generate").connect(
|
||||
@@ -97,13 +130,13 @@ class DirectorAgent(Agent):
|
||||
"""
|
||||
|
||||
if not self.enabled:
|
||||
if self.scene.game_state.has_scene_instructions:
|
||||
if await self.scene_has_instructions(self.scene):
|
||||
self.is_enabled = True
|
||||
log.warning("on_scene_init - enabling director", scene=self.scene)
|
||||
else:
|
||||
return
|
||||
|
||||
if not self.scene.game_state.has_scene_instructions:
|
||||
if not await self.scene_has_instructions(self.scene):
|
||||
return
|
||||
|
||||
if not self.scene.game_state.ops.run_on_start:
|
||||
@@ -123,7 +156,7 @@ class DirectorAgent(Agent):
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
if not self.scene.game_state.has_scene_instructions:
|
||||
if not await self.scene_has_instructions(self.scene):
|
||||
return
|
||||
|
||||
if not event.actor.character.is_player:
|
||||
@@ -182,7 +215,10 @@ class DirectorAgent(Agent):
|
||||
|
||||
# no character, see if there are NPC characters at all
|
||||
# if not we always want to direct narration
|
||||
always_direct = not self.scene.npc_character_names
|
||||
always_direct = (
|
||||
not self.scene.npc_character_names
|
||||
or self.scene.game_state.ops.always_direct
|
||||
)
|
||||
|
||||
next_direct = self.next_direct_scene
|
||||
|
||||
@@ -205,7 +241,7 @@ class DirectorAgent(Agent):
|
||||
Run game state instructions, if they exist.
|
||||
"""
|
||||
|
||||
if not self.scene.game_state.has_scene_instructions:
|
||||
if not await self.scene_has_instructions(self.scene):
|
||||
return
|
||||
|
||||
await self.direct_scene(None, None)
|
||||
@@ -250,8 +286,35 @@ class DirectorAgent(Agent):
|
||||
emit("director", message, character=character)
|
||||
self.scene.push_history(message)
|
||||
else:
|
||||
# run scene instructions
|
||||
self.scene.game_state.scene_instructions
|
||||
await self.run_scene_instructions(self.scene)
|
||||
|
||||
@set_processing
|
||||
async def persist_characters_from_worldstate(
|
||||
self, exclude: list[str] = None
|
||||
) -> List[Character]:
|
||||
log.warning(
|
||||
"persist_characters_from_worldstate",
|
||||
world_state_characters=self.scene.world_state.characters,
|
||||
scene_characters=self.scene.character_names,
|
||||
)
|
||||
|
||||
created_characters = []
|
||||
|
||||
for character_name in self.scene.world_state.characters.keys():
|
||||
|
||||
if exclude and character_name.lower() in exclude:
|
||||
continue
|
||||
|
||||
if character_name in self.scene.character_names:
|
||||
continue
|
||||
|
||||
character = await self.persist_character(name=character_name)
|
||||
|
||||
created_characters.append(character)
|
||||
|
||||
self.scene.emit_status()
|
||||
|
||||
return created_characters
|
||||
|
||||
@set_processing
|
||||
async def persist_character(
|
||||
@@ -259,11 +322,17 @@ class DirectorAgent(Agent):
|
||||
name: str,
|
||||
content: str = None,
|
||||
attributes: str = None,
|
||||
determine_name: bool = True,
|
||||
):
|
||||
world_state = instance.get_agent("world_state")
|
||||
creator = instance.get_agent("creator")
|
||||
|
||||
self.scene.log.debug("persist_character", name=name)
|
||||
|
||||
if determine_name:
|
||||
name = await creator.determine_character_name(name)
|
||||
self.scene.log.debug("persist_character", adjusted_name=name)
|
||||
|
||||
character = self.scene.Character(name=name)
|
||||
character.color = random.choice(
|
||||
[
|
||||
@@ -297,6 +366,16 @@ class DirectorAgent(Agent):
|
||||
|
||||
self.scene.log.debug("persist_character", description=description)
|
||||
|
||||
dialogue_instructions = await creator.determine_character_dialogue_instructions(
|
||||
character
|
||||
)
|
||||
|
||||
character.dialogue_instructions = dialogue_instructions
|
||||
|
||||
self.scene.log.debug(
|
||||
"persist_character", dialogue_instructions=dialogue_instructions
|
||||
)
|
||||
|
||||
actor = self.scene.Actor(
|
||||
character=character, agent=instance.get_agent("conversation")
|
||||
)
|
||||
@@ -328,6 +407,13 @@ class DirectorAgent(Agent):
|
||||
self.scene.context = response.strip()
|
||||
self.scene.emit_status()
|
||||
|
||||
async def log_action(self, action: str, action_description: str):
|
||||
message = DirectorMessage(message=action_description, action=action)
|
||||
self.scene.push_history(message)
|
||||
emit("director", message)
|
||||
|
||||
log_action.exposed = True
|
||||
|
||||
def inject_prompt_paramters(
|
||||
self, prompt_param: dict, kind: str, agent_function_name: str
|
||||
):
|
||||
|
||||
@@ -40,11 +40,6 @@ class EditorAgent(Agent):
|
||||
self.client = client
|
||||
self.is_enabled = True
|
||||
self.actions = {
|
||||
"edit_dialogue": AgentAction(
|
||||
enabled=False,
|
||||
label="Edit dialogue",
|
||||
description="Will attempt to improve the quality of dialogue based on the character and scene. Runs automatically after each AI dialogue.",
|
||||
),
|
||||
"fix_exposition": AgentAction(
|
||||
enabled=True,
|
||||
label="Fix exposition",
|
||||
@@ -63,6 +58,11 @@ class EditorAgent(Agent):
|
||||
label="Add detail",
|
||||
description="Will attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.",
|
||||
),
|
||||
"check_continuity_errors": AgentAction(
|
||||
enabled=False,
|
||||
label="Check continuity errors",
|
||||
description="Will attempt to fix continuity errors in the dialogue. Runs automatically after each AI dialogue. (super experimental)",
|
||||
),
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -100,10 +100,10 @@ class EditorAgent(Agent):
|
||||
for text in emission.generation:
|
||||
edit = await self.add_detail(text, emission.character)
|
||||
|
||||
edit = await self.edit_conversation(edit, emission.character)
|
||||
|
||||
edit = await self.fix_exposition(edit, emission.character)
|
||||
|
||||
edit = await self.check_continuity_errors(edit, emission.character)
|
||||
|
||||
edited.append(edit)
|
||||
|
||||
emission.generation = edited
|
||||
@@ -126,35 +126,6 @@ class EditorAgent(Agent):
|
||||
|
||||
emission.generation = edited
|
||||
|
||||
@set_processing
|
||||
async def edit_conversation(self, content: str, character: Character):
|
||||
"""
|
||||
Edits a conversation
|
||||
"""
|
||||
|
||||
if not self.actions["edit_dialogue"].enabled:
|
||||
return content
|
||||
|
||||
response = await Prompt.request(
|
||||
"editor.edit-dialogue",
|
||||
self.client,
|
||||
"edit_dialogue",
|
||||
vars={
|
||||
"content": content,
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"max_length": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
|
||||
response = response.split("[end]")[0]
|
||||
|
||||
response = util.replace_exposition_markers(response)
|
||||
response = util.clean_dialogue(response, main_name=character.name)
|
||||
response = util.strip_partial_sentences(response)
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
async def fix_exposition(self, content: str, character: Character):
|
||||
"""
|
||||
@@ -169,7 +140,7 @@ class EditorAgent(Agent):
|
||||
content = util.strip_partial_sentences(content)
|
||||
character_prefix = f"{character.name}: "
|
||||
message = content.split(character_prefix)[1]
|
||||
content = f"{character_prefix}*{message.strip('*')}*"
|
||||
content = f'{character_prefix}"{message.strip()}"'
|
||||
return content
|
||||
elif '"' in content:
|
||||
# silly hack to clean up some LLMs that always start with a quote
|
||||
@@ -227,3 +198,114 @@ class EditorAgent(Agent):
|
||||
response = util.strip_partial_sentences(response)
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
async def check_continuity_errors(
|
||||
self,
|
||||
content: str,
|
||||
character: Character,
|
||||
force: bool = False,
|
||||
fix: bool = True,
|
||||
message_id: int = None,
|
||||
) -> str:
|
||||
"""
|
||||
Edits a text to ensure that it is consistent with the scene
|
||||
so far
|
||||
"""
|
||||
|
||||
if not self.actions["check_continuity_errors"].enabled and not force:
|
||||
return content
|
||||
|
||||
MAX_CONTENT_LENGTH = 255
|
||||
count = util.count_tokens(content)
|
||||
|
||||
if count > MAX_CONTENT_LENGTH:
|
||||
log.warning(
|
||||
"check_continuity_errors content too long",
|
||||
length=count,
|
||||
max=MAX_CONTENT_LENGTH,
|
||||
content=content[:255],
|
||||
)
|
||||
return content
|
||||
|
||||
log.debug(
|
||||
"check_continuity_errors START",
|
||||
content=content,
|
||||
character=character,
|
||||
force=force,
|
||||
fix=fix,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
response = await Prompt.request(
|
||||
"editor.check-continuity-errors",
|
||||
self.client,
|
||||
"basic_analytical_medium2",
|
||||
vars={
|
||||
"content": content,
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"message_id": message_id,
|
||||
},
|
||||
)
|
||||
|
||||
# loop through response line by line, checking for lines beginning
|
||||
# with "ERROR {number}:
|
||||
|
||||
errors = []
|
||||
|
||||
for line in response.split("\n"):
|
||||
if "ERROR" not in line:
|
||||
continue
|
||||
|
||||
errors.append(line)
|
||||
|
||||
if not errors:
|
||||
log.debug("check_continuity_errors NO ERRORS")
|
||||
return content
|
||||
|
||||
log.debug("check_continuity_errors ERRORS", fix=fix, errors=errors)
|
||||
|
||||
if not fix:
|
||||
return content
|
||||
|
||||
state = {}
|
||||
|
||||
response = await Prompt.request(
|
||||
"editor.fix-continuity-errors",
|
||||
self.client,
|
||||
"editor_creative_medium2",
|
||||
vars={
|
||||
"content": content,
|
||||
"character": character,
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"errors": errors,
|
||||
"set_state": lambda k, v: state.update({k: v}),
|
||||
},
|
||||
)
|
||||
|
||||
content_fix_identifer = state.get("content_fix_identifier")
|
||||
|
||||
try:
|
||||
content = response.strip().strip("```").split("```")[0].strip()
|
||||
content = content.replace(content_fix_identifer, "").strip()
|
||||
content = content.strip(":")
|
||||
|
||||
# if content doesnt start with {character_name}: then add it
|
||||
if not content.startswith(f"{character.name}:"):
|
||||
content = f"{character.name}: {content}"
|
||||
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"check_continuity_errors FAILED",
|
||||
content_fix_identifer=content_fix_identifer,
|
||||
response=response,
|
||||
e=e,
|
||||
)
|
||||
return content
|
||||
|
||||
log.debug("check_continuity_errors FIXED", content=content)
|
||||
|
||||
return content
|
||||
|
||||
@@ -30,7 +30,7 @@ if not chromadb:
|
||||
log.info("ChromaDB not found, disabling Chroma agent")
|
||||
|
||||
|
||||
from .base import Agent
|
||||
from .base import Agent, AgentDetail
|
||||
|
||||
|
||||
class MemoryDocument(str):
|
||||
@@ -368,10 +368,30 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
if self.embeddings == "openai" and not self.openai_api_key:
|
||||
return "No OpenAI API key set"
|
||||
|
||||
return f"ChromaDB: {self.embeddings}"
|
||||
details = {
|
||||
"backend": AgentDetail(
|
||||
icon="mdi-server-outline",
|
||||
value="ChromaDB",
|
||||
description="The backend to use for long-term memory",
|
||||
).model_dump(),
|
||||
"embeddings": AgentDetail(
|
||||
icon="mdi-cube-unfolded",
|
||||
value=self.embeddings,
|
||||
description="The embeddings model.",
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
if self.embeddings == "openai" and not self.openai_api_key:
|
||||
# return "No OpenAI API key set"
|
||||
details["error"] = {
|
||||
"icon": "mdi-alert",
|
||||
"value": "No OpenAI API key set",
|
||||
"description": "You must provide an OpenAI API key to use OpenAI embeddings",
|
||||
"color": "error",
|
||||
}
|
||||
|
||||
return details
|
||||
|
||||
@property
|
||||
def embeddings(self):
|
||||
@@ -425,7 +445,13 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
|
||||
def make_collection_name(self, scene):
|
||||
if self.USE_OPENAI:
|
||||
suffix = "-openai"
|
||||
model_name = self.config.get("chromadb").get(
|
||||
"openai_model", "text-embedding-3-small"
|
||||
)
|
||||
if model_name == "text-embedding-ada-002":
|
||||
suffix = "-openai"
|
||||
else:
|
||||
suffix = f"-openai-{model_name}"
|
||||
elif self.USE_INSTRUCTOR:
|
||||
suffix = "-instructor"
|
||||
model = self.config.get("chromadb").get(
|
||||
@@ -472,12 +498,19 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
"You must provide an the openai ai key in the config if you want to use it for chromadb embeddings"
|
||||
)
|
||||
|
||||
model_name = self.config.get("chromadb").get(
|
||||
"openai_model", "text-embedding-3-small"
|
||||
)
|
||||
|
||||
log.info(
|
||||
"crhomadb", status="using openai", openai_key=openai_key[:5] + "..."
|
||||
"crhomadb",
|
||||
status="using openai",
|
||||
openai_key=openai_key[:5] + "...",
|
||||
model=model_name,
|
||||
)
|
||||
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=openai_key,
|
||||
model_name="text-embedding-ada-002",
|
||||
model_name=model_name,
|
||||
)
|
||||
self.db = self.db_client.get_or_create_collection(
|
||||
collection_name, embedding_function=openai_ef
|
||||
@@ -687,6 +720,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
|
||||
|
||||
doc = _results["documents"][0][i]
|
||||
meta = _results["metadatas"][0][i]
|
||||
|
||||
if not meta:
|
||||
log.warning("chromadb agent get", error="no meta", doc=doc)
|
||||
continue
|
||||
|
||||
ts = meta.get("ts")
|
||||
|
||||
# skip pin_only entries
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import random
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
|
||||
import structlog
|
||||
@@ -40,7 +41,8 @@ def set_processing(fn):
|
||||
"""
|
||||
|
||||
@_set_processing
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
@wraps(fn)
|
||||
async def narration_wrapper(self, *args, **kwargs):
|
||||
response = await fn(self, *args, **kwargs)
|
||||
emission = NarratorAgentEmission(
|
||||
agent=self,
|
||||
@@ -49,13 +51,11 @@ def set_processing(fn):
|
||||
await talemate.emit.async_signals.get("agent.narrator.generated").send(emission)
|
||||
return emission.generation[0]
|
||||
|
||||
wrapper.__name__ = fn.__name__
|
||||
return wrapper
|
||||
return narration_wrapper
|
||||
|
||||
|
||||
@register()
|
||||
class NarratorAgent(Agent):
|
||||
|
||||
"""
|
||||
Handles narration of the story
|
||||
"""
|
||||
@@ -524,23 +524,102 @@ class NarratorAgent(Agent):
|
||||
|
||||
return response
|
||||
|
||||
@set_processing
|
||||
async def paraphrase(self, narration: str):
|
||||
"""
|
||||
Paraphrase a narration
|
||||
"""
|
||||
|
||||
response = await Prompt.request(
|
||||
"narrator.paraphrase",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars={
|
||||
"text": narration,
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
|
||||
log.info("paraphrase", narration=narration, response=response)
|
||||
|
||||
response = self.clean_result(response.strip().strip("*"))
|
||||
response = f"*{response}*"
|
||||
|
||||
return response
|
||||
|
||||
async def passthrough(self, narration: str) -> str:
|
||||
"""
|
||||
Pass through narration message as is
|
||||
"""
|
||||
narration = narration.replace("*", "")
|
||||
narration = f"*{narration}*"
|
||||
narration = util.ensure_dialog_format(narration)
|
||||
return narration
|
||||
|
||||
def action_to_source(
|
||||
self,
|
||||
action_name: str,
|
||||
parameters: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a source string for a given action and parameters
|
||||
|
||||
The source string is used to identify the source of a NarratorMessage
|
||||
and will also help regenerate the action and parameters from the source string
|
||||
later on
|
||||
"""
|
||||
|
||||
args = []
|
||||
|
||||
if action_name == "paraphrase":
|
||||
args.append(parameters.get("narration"))
|
||||
elif action_name == "narrate_character_entry":
|
||||
args.append(parameters.get("character").name)
|
||||
# args.append(parameters.get("direction"))
|
||||
elif action_name == "narrate_character_exit":
|
||||
args.append(parameters.get("character").name)
|
||||
# args.append(parameters.get("direction"))
|
||||
elif action_name == "narrate_character":
|
||||
args.append(parameters.get("character").name)
|
||||
elif action_name == "narrate_query":
|
||||
args.append(parameters.get("query"))
|
||||
elif action_name == "narrate_time_passage":
|
||||
args.append(parameters.get("duration"))
|
||||
args.append(parameters.get("time_passed"))
|
||||
args.append(parameters.get("narrative"))
|
||||
elif action_name == "progress_story":
|
||||
args.append(parameters.get("narrative_direction"))
|
||||
elif action_name == "narrate_after_dialogue":
|
||||
args.append(parameters.get("character"))
|
||||
|
||||
arg_str = ";".join(args) if args else ""
|
||||
|
||||
return f"{action_name}:{arg_str}".rstrip(":")
|
||||
|
||||
async def action_to_narration(
|
||||
self,
|
||||
action_name: str,
|
||||
*args,
|
||||
emit_message: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# calls self[action_name] and returns the result as a NarratorMessage
|
||||
# that is pushed to the history
|
||||
|
||||
fn = getattr(self, action_name)
|
||||
narration = await fn(*args, **kwargs)
|
||||
narrator_message = NarratorMessage(
|
||||
narration, source=f"{action_name}:{args[0] if args else ''}".rstrip(":")
|
||||
)
|
||||
narration = await fn(**kwargs)
|
||||
source = self.action_to_source(action_name, kwargs)
|
||||
|
||||
narrator_message = NarratorMessage(narration, source=source)
|
||||
self.scene.push_history(narrator_message)
|
||||
|
||||
if emit_message:
|
||||
emit("narrator", narrator_message)
|
||||
|
||||
return narrator_message
|
||||
|
||||
action_to_narration.exposed = True
|
||||
|
||||
# LLM client related methods. These are called during or after the client
|
||||
|
||||
def inject_prompt_paramters(
|
||||
|
||||
@@ -61,6 +61,7 @@ class SummarizeAgent(Agent):
|
||||
{"label": "Short & Concise", "value": "short"},
|
||||
{"label": "Balanced", "value": "balanced"},
|
||||
{"label": "Lengthy & Detailed", "value": "long"},
|
||||
{"label": "Factual List", "value": "facts"},
|
||||
],
|
||||
),
|
||||
"include_previous": AgentActionConfig(
|
||||
@@ -77,6 +78,15 @@ class SummarizeAgent(Agent):
|
||||
)
|
||||
}
|
||||
|
||||
@property
|
||||
def threshold(self):
|
||||
return self.actions["archive"].config["threshold"].value
|
||||
|
||||
@property
|
||||
def estimated_entry_count(self):
|
||||
all_tokens = sum([util.count_tokens(entry) for entry in self.scene.history])
|
||||
return all_tokens // self.threshold
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
|
||||
@@ -140,7 +150,9 @@ class SummarizeAgent(Agent):
|
||||
if recent_entry:
|
||||
ts = recent_entry.get("ts", ts)
|
||||
|
||||
for i in range(start, len(scene.history)):
|
||||
# we ignore the most recent entry, as the user may still chose to
|
||||
# regenerate it
|
||||
for i in range(start, max(start, len(scene.history) - 1)):
|
||||
dialogue = scene.history[i]
|
||||
|
||||
# log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")
|
||||
@@ -262,9 +274,11 @@ class SummarizeAgent(Agent):
|
||||
"dialogue": text,
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"summarization_method": self.actions["archive"].config["method"].value
|
||||
if method is None
|
||||
else method,
|
||||
"summarization_method": (
|
||||
self.actions["archive"].config["method"].value
|
||||
if method is None
|
||||
else method
|
||||
),
|
||||
"extra_context": extra_context or "",
|
||||
"extra_instructions": extra_instructions or "",
|
||||
},
|
||||
|
||||
@@ -15,6 +15,7 @@ import nltk
|
||||
import pydantic
|
||||
import structlog
|
||||
from nltk.tokenize import sent_tokenize
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
import talemate.config as config
|
||||
import talemate.emit.async_signals
|
||||
@@ -24,7 +25,14 @@ from talemate.emit.signals import handlers
|
||||
from talemate.events import GameLoopNewMessageEvent
|
||||
from talemate.scene_message import CharacterMessage, NarratorMessage
|
||||
|
||||
from .base import Agent, AgentAction, AgentActionConfig, set_processing
|
||||
from .base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConditional,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
set_processing,
|
||||
)
|
||||
from .registry import register
|
||||
|
||||
try:
|
||||
@@ -109,7 +117,6 @@ class VoiceLibrary(pydantic.BaseModel):
|
||||
|
||||
@register()
|
||||
class TTSAgent(Agent):
|
||||
|
||||
"""
|
||||
Text to speech agent
|
||||
"""
|
||||
@@ -117,6 +124,7 @@ class TTSAgent(Agent):
|
||||
agent_type = "tts"
|
||||
verbose_name = "Voice"
|
||||
requires_llm_client = False
|
||||
essential = False
|
||||
|
||||
@classmethod
|
||||
def config_options(cls, agent=None):
|
||||
@@ -135,11 +143,12 @@ class TTSAgent(Agent):
|
||||
|
||||
self.voices = {
|
||||
"elevenlabs": VoiceLibrary(api="elevenlabs"),
|
||||
"coqui": VoiceLibrary(api="coqui"),
|
||||
"tts": VoiceLibrary(api="tts"),
|
||||
"openai": VoiceLibrary(api="openai"),
|
||||
}
|
||||
self.config = config.load_config()
|
||||
self.playback_done_event = asyncio.Event()
|
||||
self.preselect_voice = None
|
||||
self.actions = {
|
||||
"_config": AgentAction(
|
||||
enabled=True,
|
||||
@@ -149,10 +158,9 @@ class TTSAgent(Agent):
|
||||
"api": AgentActionConfig(
|
||||
type="text",
|
||||
choices=[
|
||||
# TODO at local TTS support
|
||||
{"value": "tts", "label": "TTS (Local)"},
|
||||
{"value": "elevenlabs", "label": "Eleven Labs"},
|
||||
{"value": "coqui", "label": "Coqui Studio"},
|
||||
{"value": "openai", "label": "OpenAI"},
|
||||
],
|
||||
value="tts",
|
||||
label="API",
|
||||
@@ -192,6 +200,25 @@ class TTSAgent(Agent):
|
||||
),
|
||||
},
|
||||
),
|
||||
"openai": AgentAction(
|
||||
enabled=True,
|
||||
condition=AgentActionConditional(
|
||||
attribute="_config.config.api", value="openai"
|
||||
),
|
||||
label="OpenAI Settings",
|
||||
config={
|
||||
"model": AgentActionConfig(
|
||||
type="text",
|
||||
value="tts-1",
|
||||
choices=[
|
||||
{"value": "tts-1", "label": "TTS 1"},
|
||||
{"value": "tts-1-hd", "label": "TTS 1 HD"},
|
||||
],
|
||||
label="Model",
|
||||
description="TTS model to use",
|
||||
),
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
self.actions["_config"].model_dump()
|
||||
@@ -230,27 +257,45 @@ class TTSAgent(Agent):
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
suffix = ""
|
||||
|
||||
if not self.ready:
|
||||
suffix = f" - {self.not_ready_reason}"
|
||||
else:
|
||||
suffix = f" - {self.voice_id_to_label(self.default_voice_id)}"
|
||||
details = {
|
||||
"api": AgentDetail(
|
||||
icon="mdi-server-outline",
|
||||
value=self.api_label,
|
||||
description="The backend to use for TTS",
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
api = self.api
|
||||
choices = self.actions["_config"].config["api"].choices
|
||||
api_label = api
|
||||
for choice in choices:
|
||||
if choice["value"] == api:
|
||||
api_label = choice["label"]
|
||||
break
|
||||
if self.ready and self.enabled:
|
||||
details["voice"] = AgentDetail(
|
||||
icon="mdi-account-voice",
|
||||
value=self.voice_id_to_label(self.default_voice_id) or "",
|
||||
description="The voice to use for TTS",
|
||||
color="info",
|
||||
).model_dump()
|
||||
elif self.enabled:
|
||||
details["error"] = AgentDetail(
|
||||
icon="mdi-alert",
|
||||
value=self.not_ready_reason,
|
||||
description=self.not_ready_reason,
|
||||
color="error",
|
||||
).model_dump()
|
||||
|
||||
return f"{api_label}{suffix}"
|
||||
return details
|
||||
|
||||
@property
|
||||
def api(self):
|
||||
return self.actions["_config"].config["api"].value
|
||||
|
||||
@property
|
||||
def api_label(self):
|
||||
choices = self.actions["_config"].config["api"].choices
|
||||
api = self.api
|
||||
for choice in choices:
|
||||
if choice["value"] == api:
|
||||
return choice["label"]
|
||||
return api
|
||||
|
||||
@property
|
||||
def token(self):
|
||||
api = self.api
|
||||
@@ -278,6 +323,8 @@ class TTSAgent(Agent):
|
||||
if not self.enabled:
|
||||
return "disabled"
|
||||
if self.ready:
|
||||
if getattr(self, "processing_bg", 0) > 0:
|
||||
return "busy_bg" if not getattr(self, "processing", False) else "busy"
|
||||
return "active" if not getattr(self, "processing", False) else "busy"
|
||||
if self.requires_token and not self.token:
|
||||
return "error"
|
||||
@@ -295,7 +342,11 @@ class TTSAgent(Agent):
|
||||
|
||||
return 250
|
||||
|
||||
def apply_config(self, *args, **kwargs):
|
||||
@property
|
||||
def openai_api_key(self):
|
||||
return self.config.get("openai", {}).get("api_key")
|
||||
|
||||
async def apply_config(self, *args, **kwargs):
|
||||
try:
|
||||
api = kwargs["actions"]["_config"]["config"]["api"]["value"]
|
||||
except KeyError:
|
||||
@@ -304,10 +355,22 @@ class TTSAgent(Agent):
|
||||
api_changed = api != self.api
|
||||
|
||||
log.debug(
|
||||
"apply_config", api=api, api_changed=api != self.api, current_api=self.api
|
||||
"apply_config",
|
||||
api=api,
|
||||
api_changed=api != self.api,
|
||||
current_api=self.api,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
super().apply_config(*args, **kwargs)
|
||||
try:
|
||||
self.preselect_voice = kwargs["actions"]["_config"]["config"]["voice_id"][
|
||||
"value"
|
||||
]
|
||||
except KeyError:
|
||||
self.preselect_voice = self.default_voice_id
|
||||
|
||||
await super().apply_config(*args, **kwargs)
|
||||
|
||||
if api_changed:
|
||||
try:
|
||||
@@ -400,6 +463,11 @@ class TTSAgent(Agent):
|
||||
library.voices = await list_fn()
|
||||
library.last_synced = time.time()
|
||||
|
||||
if self.preselect_voice:
|
||||
if self.voice(self.preselect_voice):
|
||||
self.actions["_config"].config["voice_id"].value = self.preselect_voice
|
||||
self.preselect_voice = None
|
||||
|
||||
# if the current voice cannot be found, reset it
|
||||
if not self.voice(self.default_voice_id):
|
||||
self.actions["_config"].config["voice_id"].value = ""
|
||||
@@ -425,9 +493,10 @@ class TTSAgent(Agent):
|
||||
|
||||
# Start generating audio chunks in the background
|
||||
generation_task = asyncio.create_task(self.generate_chunks(generate_fn, chunks))
|
||||
await self.set_background_processing(generation_task)
|
||||
|
||||
# Wait for both tasks to complete
|
||||
await asyncio.gather(generation_task)
|
||||
# await asyncio.gather(generation_task)
|
||||
|
||||
async def generate_chunks(self, generate_fn, chunks):
|
||||
for chunk in chunks:
|
||||
@@ -552,96 +621,32 @@ class TTSAgent(Agent):
|
||||
|
||||
return voices
|
||||
|
||||
# COQUI STUDIO
|
||||
# OPENAI
|
||||
|
||||
async def _generate_coqui(self, text: str) -> Union[bytes, None]:
|
||||
api_key = self.token
|
||||
if not api_key:
|
||||
return
|
||||
async def _generate_openai(self, text: str, chunk_size: int = 1024):
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = "https://app.coqui.ai/api/v2/samples/xtts/render/"
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
data = {
|
||||
"voice_id": self.default_voice_id,
|
||||
"text": text,
|
||||
"language": "en", # Assuming English language for simplicity; this could be parameterized
|
||||
}
|
||||
client = AsyncOpenAI(api_key=self.openai_api_key)
|
||||
|
||||
# Make the POST request to Coqui API
|
||||
response = await client.post(url, json=data, headers=headers, timeout=300)
|
||||
if response.status_code in [200, 201]:
|
||||
# Parse the JSON response to get the audio URL
|
||||
response_data = response.json()
|
||||
audio_url = response_data.get("audio_url")
|
||||
if audio_url:
|
||||
# Make a GET request to download the audio file
|
||||
audio_response = await client.get(audio_url)
|
||||
if audio_response.status_code == 200:
|
||||
# delete the sample from Coqui Studio
|
||||
# await self._cleanup_coqui(response_data.get('id'))
|
||||
return audio_response.content
|
||||
else:
|
||||
log.error(f"Error downloading audio: {audio_response.text}")
|
||||
else:
|
||||
log.error("No audio URL in response")
|
||||
else:
|
||||
log.error(f"Error generating audio: {response.text}")
|
||||
model = self.actions["openai"].config["model"].value
|
||||
|
||||
async def _cleanup_coqui(self, sample_id: str):
|
||||
api_key = self.token
|
||||
if not api_key or not sample_id:
|
||||
return
|
||||
response = await client.audio.speech.create(
|
||||
model=model, voice=self.default_voice_id, input=text
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = f"https://app.coqui.ai/api/v2/samples/xtts/{sample_id}"
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
bytes_io = io.BytesIO()
|
||||
for chunk in response.iter_bytes(chunk_size=chunk_size):
|
||||
if chunk:
|
||||
bytes_io.write(chunk)
|
||||
|
||||
# Make the DELETE request to Coqui API
|
||||
response = await client.delete(url, headers=headers)
|
||||
# Put the audio data in the queue for playback
|
||||
return bytes_io.getvalue()
|
||||
|
||||
if response.status_code == 204:
|
||||
log.info(f"Successfully deleted sample with ID: {sample_id}")
|
||||
else:
|
||||
log.error(
|
||||
f"Error deleting sample with ID: {sample_id}: {response.text}"
|
||||
)
|
||||
|
||||
async def _list_voices_coqui(self) -> dict[str, str]:
|
||||
url_speakers = "https://app.coqui.ai/api/v2/speakers"
|
||||
url_custom_voices = "https://app.coqui.ai/api/v2/voices"
|
||||
|
||||
voices = []
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {"Authorization": f"Bearer {self.token}"}
|
||||
response = await client.get(
|
||||
url_speakers, headers=headers, params={"per_page": 1000}
|
||||
)
|
||||
speakers = response.json()["result"]
|
||||
voices.extend(
|
||||
[
|
||||
Voice(value=speaker["id"], label=speaker["name"])
|
||||
for speaker in speakers
|
||||
]
|
||||
)
|
||||
|
||||
response = await client.get(
|
||||
url_custom_voices, headers=headers, params={"per_page": 1000}
|
||||
)
|
||||
custom_voices = response.json()["result"]
|
||||
voices.extend(
|
||||
[
|
||||
Voice(value=voice["id"], label=voice["name"])
|
||||
for voice in custom_voices
|
||||
]
|
||||
)
|
||||
|
||||
# sort by name
|
||||
voices.sort(key=lambda x: x.label)
|
||||
|
||||
return voices
|
||||
async def _list_voices_openai(self) -> dict[str, str]:
|
||||
return [
|
||||
Voice(value="alloy", label="Alloy"),
|
||||
Voice(value="echo", label="Echo"),
|
||||
Voice(value="fable", label="Fable"),
|
||||
Voice(value="onyx", label="Onyx"),
|
||||
Voice(value="nova", label="Nova"),
|
||||
Voice(value="shimmer", label="Shimmer"),
|
||||
]
|
||||
|
||||
467
src/talemate/agents/visual/__init__.py
Normal file
@@ -0,0 +1,467 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
import structlog
|
||||
|
||||
import talemate.agents.visual.automatic1111
|
||||
import talemate.agents.visual.comfyui
|
||||
import talemate.agents.visual.openai_image
|
||||
from talemate.agents.base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConditional,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
set_processing,
|
||||
)
|
||||
from talemate.agents.registry import register
|
||||
from talemate.client.base import ClientBase
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers as signal_handlers
|
||||
from talemate.prompts.base import Prompt
|
||||
|
||||
from .commands import * # noqa
|
||||
from .context import VIS_TYPES, VisualContext, visual_context
|
||||
from .handlers import HANDLERS
|
||||
from .schema import RESOLUTION_MAP, RenderSettings
|
||||
from .style import MAJOR_STYLES, STYLE_MAP, Style, combine_styles
|
||||
from .websocket_handler import VisualWebsocketHandler
|
||||
|
||||
__all__ = [
|
||||
"VisualAgent",
|
||||
]
|
||||
|
||||
BACKENDS = [
|
||||
{"value": mixin_backend, "label": mixin["label"]}
|
||||
for mixin_backend, mixin in HANDLERS.items()
|
||||
]
|
||||
|
||||
log = structlog.get_logger("talemate.agents.visual")
|
||||
|
||||
|
||||
class VisualBase(Agent):
|
||||
"""
|
||||
The visual agent
|
||||
"""
|
||||
|
||||
agent_type = "visual"
|
||||
verbose_name = "Visualizer"
|
||||
essential = False
|
||||
websocket_handler = VisualWebsocketHandler
|
||||
|
||||
ACTIONS = {}
|
||||
|
||||
def __init__(self, client: ClientBase, *kwargs):
|
||||
self.client = client
|
||||
self.is_enabled = False
|
||||
self.backend_ready = False
|
||||
self.initialized = False
|
||||
self.config = load_config()
|
||||
self.actions = {
|
||||
"_config": AgentAction(
|
||||
enabled=True,
|
||||
label="Configure",
|
||||
description="Visual agent configuration",
|
||||
config={
|
||||
"backend": AgentActionConfig(
|
||||
type="text",
|
||||
choices=BACKENDS,
|
||||
value="automatic1111",
|
||||
label="Backend",
|
||||
description="The backend to use for visual processing",
|
||||
),
|
||||
"default_style": AgentActionConfig(
|
||||
type="text",
|
||||
value="graphic_novel",
|
||||
choices=MAJOR_STYLES,
|
||||
label="Default Style",
|
||||
description="The default style to use for visual processing",
|
||||
),
|
||||
},
|
||||
),
|
||||
"automatic_generation": AgentAction(
|
||||
enabled=False,
|
||||
label="Automatic Generation",
|
||||
description="Allow automatic generation of visual content",
|
||||
),
|
||||
"process_in_background": AgentAction(
|
||||
enabled=True,
|
||||
label="Process in Background",
|
||||
description="Process renders in the background",
|
||||
),
|
||||
}
|
||||
|
||||
for action_name, action in self.ACTIONS.items():
|
||||
self.actions[action_name] = action
|
||||
|
||||
signal_handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.is_enabled
|
||||
|
||||
@property
|
||||
def has_toggle(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def backend(self):
|
||||
return self.actions["_config"].config["backend"].value
|
||||
|
||||
@property
|
||||
def backend_name(self):
|
||||
key = self.actions["_config"].config["backend"].value
|
||||
|
||||
for backend in BACKENDS:
|
||||
if backend["value"] == key:
|
||||
return backend["label"]
|
||||
|
||||
@property
|
||||
def default_style(self):
|
||||
return STYLE_MAP.get(
|
||||
self.actions["_config"].config["default_style"].value, Style()
|
||||
)
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
return self.backend_ready
|
||||
|
||||
@property
|
||||
def api_url(self):
|
||||
try:
|
||||
return self.actions[self.backend].config["api_url"].value
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
details = {
|
||||
"backend": AgentDetail(
|
||||
icon="mdi-server-outline",
|
||||
value=self.backend_name,
|
||||
description="The backend to use for visual processing",
|
||||
).model_dump(),
|
||||
"client": AgentDetail(
|
||||
icon="mdi-network-outline",
|
||||
value=self.client.name if self.client else None,
|
||||
description="The client to use for prompt generation",
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
if not self.ready and self.enabled:
|
||||
details["status"] = AgentDetail(
|
||||
icon="mdi-alert",
|
||||
value=f"{self.backend_name} not ready",
|
||||
color="error",
|
||||
description=self.ready_check_error
|
||||
or f"{self.backend_name} is not ready for processing",
|
||||
).model_dump()
|
||||
|
||||
return details
|
||||
|
||||
@property
|
||||
def process_in_background(self):
|
||||
return self.actions["process_in_background"].enabled
|
||||
|
||||
@property
|
||||
def allow_automatic_generation(self):
|
||||
return self.actions["automatic_generation"].enabled
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
asyncio.create_task(self.emit_status())
|
||||
|
||||
async def on_ready_check_success(self):
|
||||
prev_ready = self.backend_ready
|
||||
self.backend_ready = True
|
||||
if not prev_ready:
|
||||
await self.emit_status()
|
||||
|
||||
async def on_ready_check_failure(self, error):
|
||||
prev_ready = self.backend_ready
|
||||
self.backend_ready = False
|
||||
self.ready_check_error = str(error)
|
||||
if prev_ready:
|
||||
await self.emit_status()
|
||||
|
||||
async def ready_check(self):
|
||||
if not self.enabled:
|
||||
return
|
||||
backend = self.backend
|
||||
fn = getattr(self, f"{backend.lower()}_ready", None)
|
||||
task = asyncio.create_task(fn())
|
||||
await super().ready_check(task)
|
||||
|
||||
async def apply_config(self, *args, **kwargs):
|
||||
|
||||
try:
|
||||
backend = kwargs["actions"]["_config"]["config"]["backend"]["value"]
|
||||
except KeyError:
|
||||
backend = self.backend
|
||||
|
||||
backend_changed = backend != self.backend
|
||||
was_disabled = not self.enabled
|
||||
|
||||
if backend_changed:
|
||||
self.backend_ready = False
|
||||
|
||||
log.info(
|
||||
"apply_config",
|
||||
backend=backend,
|
||||
backend_changed=backend_changed,
|
||||
old_backend=self.backend,
|
||||
)
|
||||
|
||||
await super().apply_config(*args, **kwargs)
|
||||
|
||||
backend_fn = getattr(self, f"{self.backend.lower()}_apply_config", None)
|
||||
if backend_fn:
|
||||
|
||||
if not backend_changed and was_disabled and self.enabled:
|
||||
# If the backend has not changed, but the agent was previously disabled
|
||||
# and is now enabled, we need to trigger the backend apply_config function
|
||||
backend_changed = True
|
||||
|
||||
task = asyncio.create_task(
|
||||
backend_fn(backend_changed=backend_changed, *args, **kwargs)
|
||||
)
|
||||
await self.set_background_processing(task)
|
||||
|
||||
if not self.backend_ready:
|
||||
await self.ready_check()
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def resolution_from_format(self, format: str, model_type: str = "sdxl"):
|
||||
if model_type not in RESOLUTION_MAP:
|
||||
raise ValueError(f"Model type {model_type} not found in resolution map")
|
||||
return RESOLUTION_MAP[model_type].get(
|
||||
format, RESOLUTION_MAP[model_type]["portrait"]
|
||||
)
|
||||
|
||||
def prepare_prompt(self, prompt: str, styles: list[Style] = None) -> Style:
|
||||
|
||||
prompt_style = Style()
|
||||
prompt_style.load(prompt)
|
||||
|
||||
if styles:
|
||||
prompt_style.prepend(*styles)
|
||||
|
||||
return prompt_style
|
||||
|
||||
def vis_type_styles(self, vis_type: str):
|
||||
if vis_type == VIS_TYPES.CHARACTER:
|
||||
portrait_style = STYLE_MAP["character_portrait"].copy()
|
||||
return portrait_style
|
||||
elif vis_type == VIS_TYPES.ENVIRONMENT:
|
||||
environment_style = STYLE_MAP["environment"].copy()
|
||||
return environment_style
|
||||
return Style()
|
||||
|
||||
async def apply_image(self, image: str):
|
||||
context = visual_context.get()
|
||||
|
||||
log.debug("apply_image", image=image[:100], context=context)
|
||||
|
||||
if context.vis_type == VIS_TYPES.CHARACTER:
|
||||
await self.apply_image_character(image, context.character_name)
|
||||
|
||||
async def apply_image_character(self, image: str, character_name: str):
|
||||
character = self.scene.get_character(character_name)
|
||||
|
||||
if not character:
|
||||
log.error("character not found", character_name=character_name)
|
||||
return
|
||||
|
||||
if character.cover_image:
|
||||
log.info("character cover image already set", character_name=character_name)
|
||||
return
|
||||
|
||||
asset = self.scene.assets.add_asset_from_image_data(
|
||||
f"data:image/png;base64,{image}"
|
||||
)
|
||||
character.cover_image = asset.id
|
||||
self.scene.assets.cover_image = asset.id
|
||||
self.scene.emit_status()
|
||||
|
||||
async def emit_image(self, image: str):
|
||||
context = visual_context.get()
|
||||
await self.apply_image(image)
|
||||
emit(
|
||||
"image_generated",
|
||||
websocket_passthrough=True,
|
||||
data={
|
||||
"base64": image,
|
||||
"context": context.model_dump() if context else None,
|
||||
},
|
||||
)
|
||||
|
||||
@set_processing
|
||||
async def generate(
|
||||
self, format: str = "portrait", prompt: str = None, automatic: bool = False
|
||||
):
|
||||
|
||||
context = visual_context.get()
|
||||
|
||||
if not self.enabled:
|
||||
log.warning("generate", skipped="Visual agent not enabled")
|
||||
return
|
||||
|
||||
if automatic and not self.allow_automatic_generation:
|
||||
log.warning(
|
||||
"generate",
|
||||
skipped="Automatic generation disabled",
|
||||
prompt=prompt,
|
||||
format=format,
|
||||
context=context,
|
||||
)
|
||||
return
|
||||
|
||||
if not context and not prompt:
|
||||
log.error("generate", error="No context or prompt provided")
|
||||
return
|
||||
|
||||
# Handle prompt generation based on context
|
||||
|
||||
if not prompt and context.prompt:
|
||||
prompt = context.prompt
|
||||
|
||||
if context.vis_type == VIS_TYPES.ENVIRONMENT and not prompt:
|
||||
prompt = await self.generate_environment_prompt(
|
||||
instructions=context.instructions
|
||||
)
|
||||
elif context.vis_type == VIS_TYPES.CHARACTER and not prompt:
|
||||
prompt = await self.generate_character_prompt(
|
||||
context.character_name, instructions=context.instructions
|
||||
)
|
||||
else:
|
||||
prompt = prompt or context.prompt
|
||||
|
||||
initial_prompt = prompt
|
||||
|
||||
# Augment the prompt with styles based on context
|
||||
|
||||
thematic_style = self.default_style
|
||||
vis_type_styles = self.vis_type_styles(context.vis_type)
|
||||
prompt = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])
|
||||
|
||||
if context.vis_type == VIS_TYPES.CHARACTER:
|
||||
prompt.keywords.append("character portrait")
|
||||
|
||||
if not prompt:
|
||||
log.error(
|
||||
"generate", error="No prompt provided and no context to generate from"
|
||||
)
|
||||
return
|
||||
|
||||
context.prompt = initial_prompt
|
||||
context.prepared_prompt = str(prompt)
|
||||
|
||||
# Handle format (can either come from context or be passed in)
|
||||
|
||||
if not format and context.format:
|
||||
format = context.format
|
||||
elif not format:
|
||||
format = "portrait"
|
||||
|
||||
context.format = format
|
||||
|
||||
# Call the backend specific generate function
|
||||
|
||||
backend = self.backend
|
||||
fn = f"{backend.lower()}_generate"
|
||||
|
||||
log.info(
|
||||
"generate", backend=backend, prompt=prompt, format=format, context=context
|
||||
)
|
||||
|
||||
if not hasattr(self, fn):
|
||||
log.error("generate", error=f"Backend {backend} does not support generate")
|
||||
|
||||
# add the function call to the asyncio task queue
|
||||
|
||||
if self.process_in_background:
|
||||
task = asyncio.create_task(getattr(self, fn)(prompt=prompt, format=format))
|
||||
await self.set_background_processing(task)
|
||||
else:
|
||||
await getattr(self, fn)(prompt=prompt, format=format)
|
||||
|
||||
@set_processing
|
||||
async def generate_environment_prompt(self, instructions: str = None):
|
||||
|
||||
response = await Prompt.request(
|
||||
"visual.generate-environment-prompt",
|
||||
self.client,
|
||||
"visualize",
|
||||
{
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
},
|
||||
)
|
||||
|
||||
return response.strip()
|
||||
|
||||
@set_processing
|
||||
async def generate_character_prompt(
|
||||
self, character_name: str, instructions: str = None
|
||||
):
|
||||
|
||||
character = self.scene.get_character(character_name)
|
||||
|
||||
response = await Prompt.request(
|
||||
"visual.generate-character-prompt",
|
||||
self.client,
|
||||
"visualize",
|
||||
{
|
||||
"scene": self.scene,
|
||||
"character_name": character_name,
|
||||
"character": character,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"instructions": instructions or "",
|
||||
},
|
||||
)
|
||||
|
||||
return response.strip()
|
||||
|
||||
async def generate_environment_background(self, instructions: str = None):
|
||||
with VisualContext(vis_type=VIS_TYPES.ENVIRONMENT, instructions=instructions):
|
||||
await self.generate(format="landscape")
|
||||
|
||||
generate_environment_background.exposed = True
|
||||
|
||||
async def generate_character_portrait(
|
||||
self,
|
||||
character_name: str,
|
||||
instructions: str = None,
|
||||
):
|
||||
with VisualContext(
|
||||
vis_type=VIS_TYPES.CHARACTER,
|
||||
character_name=character_name,
|
||||
instructions=instructions,
|
||||
):
|
||||
await self.generate(format="portrait")
|
||||
|
||||
generate_character_portrait.exposed = True
|
||||
|
||||
|
||||
# apply mixins to the agent (from HANDLERS dict[str, cls])
|
||||
|
||||
for mixin_backend, mixin in HANDLERS.items():
|
||||
mixin_cls = mixin["cls"]
|
||||
VisualBase = type("VisualAgent", (mixin_cls, VisualBase), {})
|
||||
|
||||
extend_actions = getattr(mixin_cls, "EXTEND_ACTIONS", {})
|
||||
|
||||
for action_name, action in extend_actions.items():
|
||||
VisualBase.ACTIONS[action_name] = action
|
||||
|
||||
|
||||
@register()
|
||||
class VisualAgent(VisualBase):
|
||||
pass
|
||||
117
src/talemate/agents/visual/automatic1111.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import base64
|
||||
import io
|
||||
|
||||
import httpx
|
||||
import structlog
|
||||
from PIL import Image
|
||||
|
||||
from talemate.agents.base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConditional,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
set_processing,
|
||||
)
|
||||
|
||||
from .handlers import register
|
||||
from .schema import RenderSettings, Resolution
|
||||
from .style import STYLE_MAP, Style
|
||||
|
||||
log = structlog.get_logger("talemate.agents.visual.automatic1111")
|
||||
|
||||
|
||||
@register(backend_name="automatic1111", label="AUTOMATIC1111")
|
||||
class Automatic1111Mixin:
|
||||
|
||||
automatic1111_default_render_settings = RenderSettings()
|
||||
|
||||
EXTEND_ACTIONS = {
|
||||
"automatic1111": AgentAction(
|
||||
enabled=True,
|
||||
condition=AgentActionConditional(
|
||||
attribute="_config.config.backend", value="automatic1111"
|
||||
),
|
||||
label="Automatic1111 Settings",
|
||||
description="Setting overrides for the automatic1111 backend",
|
||||
config={
|
||||
"api_url": AgentActionConfig(
|
||||
type="text",
|
||||
value="http://localhost:7860",
|
||||
label="API URL",
|
||||
description="The URL of the backend API",
|
||||
),
|
||||
"steps": AgentActionConfig(
|
||||
type="number",
|
||||
value=40,
|
||||
label="Steps",
|
||||
min=5,
|
||||
max=150,
|
||||
step=1,
|
||||
description="number of render steps",
|
||||
),
|
||||
"model_type": AgentActionConfig(
|
||||
type="text",
|
||||
value="sdxl",
|
||||
choices=[
|
||||
{"value": "sdxl", "label": "SDXL"},
|
||||
{"value": "sd15", "label": "SD1.5"},
|
||||
],
|
||||
label="Model Type",
|
||||
description="Right now just differentiates between sdxl and sd15 - affect generation resolution",
|
||||
),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@property
|
||||
def automatic1111_render_settings(self):
|
||||
if self.actions["automatic1111"].enabled:
|
||||
return RenderSettings(
|
||||
steps=self.actions["automatic1111"].config["steps"].value,
|
||||
type_model=self.actions["automatic1111"].config["model_type"].value,
|
||||
)
|
||||
else:
|
||||
return self.automatic1111_default_render_settings
|
||||
|
||||
async def automatic1111_generate(self, prompt: Style, format: str):
|
||||
url = self.api_url
|
||||
resolution = self.resolution_from_format(
|
||||
format, self.automatic1111_render_settings.type_model
|
||||
)
|
||||
render_settings = self.automatic1111_render_settings
|
||||
payload = {
|
||||
"prompt": prompt.positive_prompt,
|
||||
"negative_prompt": prompt.negative_prompt,
|
||||
"steps": render_settings.steps,
|
||||
"width": resolution.width,
|
||||
"height": resolution.height,
|
||||
}
|
||||
|
||||
log.info("automatic1111_generate", payload=payload, url=url)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
url=f"{url}/sdapi/v1/txt2img", json=payload, timeout=90
|
||||
)
|
||||
|
||||
r = response.json()
|
||||
|
||||
# image = Image.open(io.BytesIO(base64.b64decode(r['images'][0])))
|
||||
# image.save('a1111-test.png')
|
||||
|
||||
#'log.info("automatic1111_generate", saved_to="a1111-test.png")
|
||||
|
||||
for image in r["images"]:
|
||||
await self.emit_image(image)
|
||||
|
||||
async def automatic1111_ready(self) -> bool:
|
||||
"""
|
||||
Will send a GET to /sdapi/v1/memory and on 200 will return True
|
||||
"""
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
url=f"{self.api_url}/sdapi/v1/memory", timeout=2
|
||||
)
|
||||
return response.status_code == 200
|
||||
324
src/talemate/agents/visual/comfyui.py
Normal file
@@ -0,0 +1,324 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
import pydantic
|
||||
import structlog
|
||||
from PIL import Image
|
||||
|
||||
from talemate.agents.base import AgentAction, AgentActionConditional, AgentActionConfig
|
||||
|
||||
from .handlers import register
|
||||
from .schema import RenderSettings, Resolution
|
||||
from .style import STYLE_MAP, Style
|
||||
|
||||
log = structlog.get_logger("talemate.agents.visual.comfyui")
|
||||
|
||||
|
||||
class Workflow(pydantic.BaseModel):
|
||||
nodes: dict
|
||||
|
||||
def set_resolution(self, resolution: Resolution):
|
||||
|
||||
# will collect all latent image nodes
|
||||
# if there is multiple will look for the one with the
|
||||
# title "Talemate Resolution"
|
||||
|
||||
# if there is no latent image node with the title "Talemate Resolution"
|
||||
# the first latent image node will be used
|
||||
|
||||
# resolution will be updated on the selected node
|
||||
|
||||
# if no latent image node is found a warning will be logged
|
||||
|
||||
latent_image_node = None
|
||||
|
||||
for node_id, node in self.nodes.items():
|
||||
if node["class_type"] == "EmptyLatentImage":
|
||||
if not latent_image_node:
|
||||
latent_image_node = node
|
||||
elif node["_meta"]["title"] == "Talemate Resolution":
|
||||
latent_image_node = node
|
||||
break
|
||||
|
||||
if not latent_image_node:
|
||||
log.warning("set_resolution", error="No latent image node found")
|
||||
return
|
||||
|
||||
latent_image_node["inputs"]["width"] = resolution.width
|
||||
latent_image_node["inputs"]["height"] = resolution.height
|
||||
|
||||
def set_prompt(self, prompt: str, negative_prompt: str = None):
|
||||
|
||||
# will collect all CLIPTextEncode nodes
|
||||
|
||||
# if there is multiple will look for the one with the
|
||||
# title "Talemate Positive Prompt" and "Talemate Negative Prompt"
|
||||
#
|
||||
# if there is no CLIPTextEncode node with the title "Talemate Positive Prompt"
|
||||
# the first CLIPTextEncode node will be used
|
||||
#
|
||||
# if there is no CLIPTextEncode node with the title "Talemate Negative Prompt"
|
||||
# the second CLIPTextEncode node will be used
|
||||
#
|
||||
# prompt will be updated on the selected node
|
||||
|
||||
# if no CLIPTextEncode node is found an exception will be raised for
|
||||
# the positive prompt
|
||||
|
||||
# if no CLIPTextEncode node is found an exception will be raised for
|
||||
# the negative prompt if it is not None
|
||||
|
||||
positive_prompt_node = None
|
||||
negative_prompt_node = None
|
||||
|
||||
for node_id, node in self.nodes.items():
|
||||
|
||||
if node["class_type"] == "CLIPTextEncode":
|
||||
if not positive_prompt_node:
|
||||
positive_prompt_node = node
|
||||
elif node["_meta"]["title"] == "Talemate Positive Prompt":
|
||||
positive_prompt_node = node
|
||||
elif not negative_prompt_node:
|
||||
negative_prompt_node = node
|
||||
elif node["_meta"]["title"] == "Talemate Negative Prompt":
|
||||
negative_prompt_node = node
|
||||
|
||||
if not positive_prompt_node:
|
||||
raise ValueError("No positive prompt node found")
|
||||
|
||||
positive_prompt_node["inputs"]["text"] = prompt
|
||||
|
||||
if negative_prompt and not negative_prompt_node:
|
||||
raise ValueError("No negative prompt node found")
|
||||
|
||||
if negative_prompt:
|
||||
negative_prompt_node["inputs"]["text"] = negative_prompt
|
||||
|
||||
def set_checkpoint(self, checkpoint: str):
|
||||
|
||||
# will collect all CheckpointLoaderSimple nodes
|
||||
# if there is multiple will look for the one with the
|
||||
# title "Talemate Load Checkpoint"
|
||||
|
||||
# if there is no CheckpointLoaderSimple node with the title "Talemate Load Checkpoint"
|
||||
# the first CheckpointLoaderSimple node will be used
|
||||
|
||||
# checkpoint will be updated on the selected node
|
||||
|
||||
# if no CheckpointLoaderSimple node is found a warning will be logged
|
||||
|
||||
checkpoint_node = None
|
||||
|
||||
for node_id, node in self.nodes.items():
|
||||
if node["class_type"] == "CheckpointLoaderSimple":
|
||||
if not checkpoint_node:
|
||||
checkpoint_node = node
|
||||
elif node["_meta"]["title"] == "Talemate Load Checkpoint":
|
||||
checkpoint_node = node
|
||||
break
|
||||
|
||||
if not checkpoint_node:
|
||||
log.warning("set_checkpoint", error="No checkpoint node found")
|
||||
return
|
||||
|
||||
checkpoint_node["inputs"]["ckpt_name"] = checkpoint
|
||||
|
||||
def set_seeds(self):
|
||||
for node in self.nodes.values():
|
||||
for field in node.get("inputs", {}).keys():
|
||||
if field == "noise_seed":
|
||||
node["inputs"]["noise_seed"] = random.randint(0, 999999999999999)
|
||||
|
||||
|
||||
@register(backend_name="comfyui", label="ComfyUI")
|
||||
class ComfyUIMixin:
|
||||
|
||||
comfyui_default_render_settings = RenderSettings()
|
||||
|
||||
EXTEND_ACTIONS = {
|
||||
"comfyui": AgentAction(
|
||||
enabled=True,
|
||||
condition=AgentActionConditional(
|
||||
attribute="_config.config.backend", value="comfyui"
|
||||
),
|
||||
label="ComfyUI Settings",
|
||||
description="Setting overrides for the comfyui backend",
|
||||
config={
|
||||
"api_url": AgentActionConfig(
|
||||
type="text",
|
||||
value="http://localhost:8188",
|
||||
label="API URL",
|
||||
description="The URL of the backend API",
|
||||
),
|
||||
"workflow": AgentActionConfig(
|
||||
type="text",
|
||||
value="default-sdxl.json",
|
||||
label="Workflow",
|
||||
description="The workflow to use for comfyui (workflow file name inside ./templates/comfyui-workflows)",
|
||||
),
|
||||
"checkpoint": AgentActionConfig(
|
||||
type="text",
|
||||
value="default",
|
||||
label="Checkpoint",
|
||||
choices=[],
|
||||
description="The main checkpoint to use.",
|
||||
),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@property
|
||||
def comfyui_workflow_filename(self):
|
||||
base_name = self.actions["comfyui"].config["workflow"].value
|
||||
|
||||
# make absolute path
|
||||
abs_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"..",
|
||||
"..",
|
||||
"..",
|
||||
"..",
|
||||
"templates",
|
||||
"comfyui-workflows",
|
||||
base_name,
|
||||
)
|
||||
|
||||
return abs_path
|
||||
|
||||
@property
|
||||
def comfyui_workflow_is_sdxl(self) -> bool:
|
||||
"""
|
||||
Returns true if `sdxl` is in worhflow file name (case insensitive)
|
||||
"""
|
||||
|
||||
return "sdxl" in self.comfyui_workflow_filename.lower()
|
||||
|
||||
@property
|
||||
def comfyui_workflow(self) -> Workflow:
|
||||
workflow = self.comfyui_workflow_filename
|
||||
if not workflow:
|
||||
raise ValueError("No comfyui workflow file specified")
|
||||
|
||||
with open(workflow, "r") as f:
|
||||
return Workflow(nodes=json.load(f))
|
||||
|
||||
@property
|
||||
async def comfyui_object_info(self):
|
||||
if hasattr(self, "_comfyui_object_info"):
|
||||
return self._comfyui_object_info
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url=f"{self.api_url}/object_info")
|
||||
self._comfyui_object_info = response.json()
|
||||
|
||||
return self._comfyui_object_info
|
||||
|
||||
@property
|
||||
async def comfyui_checkpoints(self):
|
||||
loader_node = (await self.comfyui_object_info)["CheckpointLoaderSimple"]
|
||||
_checkpoints = loader_node["input"]["required"]["ckpt_name"][0]
|
||||
return [
|
||||
{"label": checkpoint, "value": checkpoint} for checkpoint in _checkpoints
|
||||
]
|
||||
|
||||
async def comfyui_get_image(self, filename: str, subfolder: str, folder_type: str):
|
||||
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||
url_values = urllib.parse.urlencode(data)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url=f"{self.api_url}/view?{url_values}")
|
||||
return response.content
|
||||
|
||||
async def comfyui_get_history(self, prompt_id: str):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url=f"{self.api_url}/history/{prompt_id}")
|
||||
return response.json()
|
||||
|
||||
async def comfyui_get_images(self, prompt_id: str, max_wait: int = 60.0):
|
||||
output_images = {}
|
||||
history = {}
|
||||
|
||||
start = time.time()
|
||||
|
||||
while not history:
|
||||
log.info(
|
||||
"comfyui_get_images", waiting_for_history=True, prompt_id=prompt_id
|
||||
)
|
||||
history = await self.comfyui_get_history(prompt_id)
|
||||
await asyncio.sleep(1.0)
|
||||
if time.time() - start > max_wait:
|
||||
raise TimeoutError("Max wait time exceeded")
|
||||
|
||||
for node_id, node_output in history[prompt_id]["outputs"].items():
|
||||
if "images" in node_output:
|
||||
images_output = []
|
||||
for image in node_output["images"]:
|
||||
image_data = await self.comfyui_get_image(
|
||||
image["filename"], image["subfolder"], image["type"]
|
||||
)
|
||||
images_output.append(image_data)
|
||||
output_images[node_id] = images_output
|
||||
|
||||
return output_images
|
||||
|
||||
async def comfyui_generate(self, prompt: Style, format: str):
|
||||
url = self.api_url
|
||||
workflow = self.comfyui_workflow
|
||||
is_sdxl = self.comfyui_workflow_is_sdxl
|
||||
|
||||
resolution = self.resolution_from_format(format, "sdxl" if is_sdxl else "sd15")
|
||||
|
||||
workflow.set_resolution(resolution)
|
||||
workflow.set_prompt(prompt.positive_prompt, prompt.negative_prompt)
|
||||
workflow.set_seeds()
|
||||
workflow.set_checkpoint(self.actions["comfyui"].config["checkpoint"].value)
|
||||
|
||||
payload = {"prompt": workflow.model_dump().get("nodes")}
|
||||
|
||||
log.info("comfyui_generate", payload=payload, url=url)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url=f"{url}/prompt", json=payload, timeout=90)
|
||||
|
||||
log.info("comfyui_generate", response=response.text)
|
||||
|
||||
r = response.json()
|
||||
|
||||
prompt_id = r["prompt_id"]
|
||||
|
||||
images = await self.comfyui_get_images(prompt_id)
|
||||
for node_id, node_images in images.items():
|
||||
for i, image in enumerate(node_images):
|
||||
await self.emit_image(base64.b64encode(image).decode("utf-8"))
|
||||
# image = Image.open(io.BytesIO(image))
|
||||
# image.save(f'comfyui-test.png')
|
||||
|
||||
async def comfyui_apply_config(
|
||||
self, backend_changed: bool = False, *args, **kwargs
|
||||
):
|
||||
log.debug(
|
||||
"comfyui_apply_config",
|
||||
backend_changed=backend_changed,
|
||||
enabled=self.enabled,
|
||||
)
|
||||
if (not self.initialized or backend_changed) and self.enabled:
|
||||
checkpoints = await self.comfyui_checkpoints
|
||||
selected_checkpoint = self.actions["comfyui"].config["checkpoint"].value
|
||||
self.actions["comfyui"].config["checkpoint"].choices = checkpoints
|
||||
self.actions["comfyui"].config["checkpoint"].value = selected_checkpoint
|
||||
|
||||
async def comfyui_ready(self) -> bool:
|
||||
"""
|
||||
Will send a GET to /system_stats and on 200 will return True
|
||||
"""
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url=f"{self.api_url}/system_stats", timeout=2)
|
||||
return response.status_code == 200
|
||||
68
src/talemate/agents/visual/commands.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from talemate.agents.visual.context import VIS_TYPES, VisualContext
|
||||
from talemate.commands.base import TalemateCommand
|
||||
from talemate.commands.manager import register
|
||||
from talemate.instance import get_agent
|
||||
|
||||
__all__ = [
|
||||
"CmdVisualizeTestGenerate",
|
||||
]
|
||||
|
||||
|
||||
@register
|
||||
class CmdVisualizeTestGenerate(TalemateCommand):
|
||||
"""
|
||||
Generates a visual test
|
||||
"""
|
||||
|
||||
name = "visual_test_generate"
|
||||
description = "Will generate a visual test"
|
||||
aliases = ["vis_test", "vtg"]
|
||||
|
||||
label = "Visualize test"
|
||||
|
||||
async def run(self):
|
||||
visual = get_agent("visual")
|
||||
prompt = self.args[0]
|
||||
with VisualContext(vis_type=VIS_TYPES.UNSPECIFIED):
|
||||
await visual.generate(prompt)
|
||||
return True
|
||||
|
||||
|
||||
@register
|
||||
class CmdVisualizeEnvironment(TalemateCommand):
|
||||
"""
|
||||
Shows the environment
|
||||
"""
|
||||
|
||||
name = "visual_environment"
|
||||
description = "Will show the environment"
|
||||
aliases = ["vis_env"]
|
||||
|
||||
label = "Visualize environment"
|
||||
|
||||
async def run(self):
|
||||
visual = get_agent("visual")
|
||||
await visual.generate_environment_background(
|
||||
instructions=self.args[0] if len(self.args) > 0 else None
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@register
|
||||
class CmdVisualizeCharacter(TalemateCommand):
|
||||
"""
|
||||
Shows a character
|
||||
"""
|
||||
|
||||
name = "visual_character"
|
||||
description = "Will show a character"
|
||||
aliases = ["vis_char"]
|
||||
|
||||
label = "Visualize character"
|
||||
|
||||
async def run(self):
|
||||
visual = get_agent("visual")
|
||||
character_name = self.args[0]
|
||||
instructions = self.args[1] if len(self.args) > 1 else None
|
||||
await visual.generate_character_portrait(character_name, instructions)
|
||||
return True
|
||||
55
src/talemate/agents/visual/context.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import contextvars
|
||||
import enum
|
||||
from typing import Union
|
||||
|
||||
import pydantic
|
||||
|
||||
__all__ = [
|
||||
"VIS_TYPES",
|
||||
"visual_context",
|
||||
"VisualContext",
|
||||
]
|
||||
|
||||
|
||||
class VIS_TYPES(str, enum.Enum):
|
||||
UNSPECIFIED = "UNSPECIFIED"
|
||||
ENVIRONMENT = "ENVIRONMENT"
|
||||
CHARACTER = "CHARACTER"
|
||||
ITEM = "ITEM"
|
||||
|
||||
|
||||
visual_context = contextvars.ContextVar("visual_context", default=None)
|
||||
|
||||
|
||||
class VisualContextState(pydantic.BaseModel):
|
||||
character_name: Union[str, None] = None
|
||||
instructions: Union[str, None] = None
|
||||
vis_type: VIS_TYPES = VIS_TYPES.ENVIRONMENT
|
||||
prompt: Union[str, None] = None
|
||||
prepared_prompt: Union[str, None] = None
|
||||
format: Union[str, None] = None
|
||||
|
||||
|
||||
class VisualContext:
|
||||
def __init__(
|
||||
self,
|
||||
character_name: Union[str, None] = None,
|
||||
instructions: Union[str, None] = None,
|
||||
vis_type: VIS_TYPES = VIS_TYPES.ENVIRONMENT,
|
||||
prompt: Union[str, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.state = VisualContextState(
|
||||
character_name=character_name,
|
||||
instructions=instructions,
|
||||
vis_type=vis_type,
|
||||
prompt=prompt,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
self.token = visual_context.set(self.state)
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
visual_context.reset(self.token)
|
||||
return False
|
||||
17
src/talemate/agents/visual/handlers.py
Normal file
@@ -0,0 +1,17 @@
|
||||
__all__ = [
|
||||
"HANDLERS",
|
||||
"register",
|
||||
]
|
||||
|
||||
HANDLERS = {}
|
||||
|
||||
|
||||
class register:
|
||||
|
||||
def __init__(self, backend_name: str, label: str):
|
||||
self.backend_name = backend_name
|
||||
self.label = label
|
||||
|
||||
def __call__(self, mixin_cls):
|
||||
HANDLERS[self.backend_name] = {"label": self.label, "cls": mixin_cls}
|
||||
return mixin_cls
|
||||
125
src/talemate/agents/visual/openai_image.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import base64
|
||||
import io
|
||||
from urllib.parse import parse_qs, unquote, urlparse
|
||||
|
||||
import httpx
|
||||
import structlog
|
||||
from openai import AsyncOpenAI
|
||||
from PIL import Image
|
||||
|
||||
from talemate.agents.base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConditional,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
set_processing,
|
||||
)
|
||||
|
||||
from .handlers import register
|
||||
from .schema import RenderSettings, Resolution
|
||||
from .style import STYLE_MAP, Style
|
||||
|
||||
log = structlog.get_logger("talemate.agents.visual.openai_image")
|
||||
|
||||
|
||||
@register(backend_name="openai_image", label="OpenAI")
|
||||
class OpenAIImageMixin:
|
||||
|
||||
openai_image_default_render_settings = RenderSettings()
|
||||
|
||||
EXTEND_ACTIONS = {
|
||||
"openai_image": AgentAction(
|
||||
enabled=False,
|
||||
condition=AgentActionConditional(
|
||||
attribute="_config.config.backend", value="openai_image"
|
||||
),
|
||||
label="OpenAI Image Generation Advanced Settings",
|
||||
description="Setting overrides for the openai backend",
|
||||
config={
|
||||
"model_type": AgentActionConfig(
|
||||
type="text",
|
||||
value="dall-e-3",
|
||||
choices=[
|
||||
{"value": "dall-e-3", "label": "DALL-E 3"},
|
||||
{"value": "dall-e-2", "label": "DALL-E 2"},
|
||||
],
|
||||
label="Model Type",
|
||||
description="Image generation model",
|
||||
),
|
||||
"quality": AgentActionConfig(
|
||||
type="text",
|
||||
value="standard",
|
||||
choices=[
|
||||
{"value": "standard", "label": "Standard"},
|
||||
{"value": "hd", "label": "HD"},
|
||||
],
|
||||
label="Quality",
|
||||
description="Image generation quality",
|
||||
),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@property
|
||||
def openai_api_key(self):
|
||||
return self.config.get("openai", {}).get("api_key")
|
||||
|
||||
@property
|
||||
def openai_model_type(self):
|
||||
return self.actions["openai_image"].config["model_type"].value
|
||||
|
||||
@property
|
||||
def openai_quality(self):
|
||||
return self.actions["openai_image"].config["quality"].value
|
||||
|
||||
async def openai_image_generate(self, prompt: Style, format: str):
|
||||
"""
|
||||
#
|
||||
from openai import OpenAI
|
||||
client = OpenAI()
|
||||
|
||||
response = client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt="a white siamese cat",
|
||||
size="1024x1024",
|
||||
quality="standard",
|
||||
n=1,
|
||||
)
|
||||
|
||||
image_url = response.data[0].url
|
||||
"""
|
||||
|
||||
client = AsyncOpenAI(api_key=self.openai_api_key)
|
||||
|
||||
# When using DALL·E 3, images can have a size of 1024x1024, 1024x1792 or 1792x1024 pixels.#
|
||||
|
||||
if format == "portrait":
|
||||
resolution = Resolution(width=1024, height=1792)
|
||||
elif format == "landscape":
|
||||
resolution = Resolution(width=1792, height=1024)
|
||||
else:
|
||||
resolution = Resolution(width=1024, height=1024)
|
||||
|
||||
log.debug("openai_image_generate", resolution=resolution)
|
||||
|
||||
response = await client.images.generate(
|
||||
model=self.openai_model_type,
|
||||
prompt=prompt.positive_prompt,
|
||||
size=f"{resolution.width}x{resolution.height}",
|
||||
quality=self.openai_quality,
|
||||
n=1,
|
||||
response_format="b64_json",
|
||||
)
|
||||
|
||||
await self.emit_image(response.data[0].b64_json)
|
||||
|
||||
async def openai_image_ready(self) -> bool:
|
||||
"""
|
||||
Will send a GET to /sdapi/v1/memory and on 200 will return True
|
||||
"""
|
||||
|
||||
if not self.openai_api_key:
|
||||
raise ValueError("OpenAI API Key not set")
|
||||
|
||||
return True
|
||||
32
src/talemate/agents/visual/schema.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import pydantic
|
||||
|
||||
__all__ = [
|
||||
"RenderSettings",
|
||||
"Resolution",
|
||||
"RESOLUTION_MAP",
|
||||
]
|
||||
|
||||
RESOLUTION_MAP = {}
|
||||
|
||||
|
||||
class RenderSettings(pydantic.BaseModel):
|
||||
type_model: str = "sdxl"
|
||||
steps: int = 40
|
||||
|
||||
|
||||
class Resolution(pydantic.BaseModel):
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
RESOLUTION_MAP["sdxl"] = {
|
||||
"portrait": Resolution(width=832, height=1216),
|
||||
"landscape": Resolution(width=1216, height=832),
|
||||
"square": Resolution(width=1024, height=1024),
|
||||
}
|
||||
|
||||
RESOLUTION_MAP["sd15"] = {
|
||||
"portrait": Resolution(width=512, height=768),
|
||||
"landscape": Resolution(width=768, height=512),
|
||||
"square": Resolution(width=768, height=768),
|
||||
}
|
||||
136
src/talemate/agents/visual/style.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
|
||||
__all__ = [
|
||||
"Style",
|
||||
"STYLE_MAP",
|
||||
"THEME_MAP",
|
||||
"MAJOR_STYLES",
|
||||
"combine_styles",
|
||||
]
|
||||
|
||||
STYLE_MAP = {}
|
||||
THEME_MAP = {}
|
||||
MAJOR_STYLES = {}
|
||||
|
||||
log = structlog.get_logger("talemate.agents.visual.style")
|
||||
|
||||
|
||||
class Style(pydantic.BaseModel):
|
||||
keywords: list[str] = pydantic.Field(default_factory=list)
|
||||
negative_keywords: list[str] = pydantic.Field(default_factory=list)
|
||||
|
||||
@property
|
||||
def positive_prompt(self):
|
||||
return ", ".join(self.keywords)
|
||||
|
||||
@property
|
||||
def negative_prompt(self):
|
||||
return ", ".join(self.negative_keywords)
|
||||
|
||||
def __str__(self):
|
||||
return f"POSITIVE: {self.positive_prompt}\nNEGATIVE: {self.negative_prompt}"
|
||||
|
||||
def load(self, prompt: str, negative_prompt: str = ""):
|
||||
self.keywords = prompt.split(", ")
|
||||
self.negative_keywords = negative_prompt.split(", ")
|
||||
|
||||
# loop through keywords and drop any starting with "no " and add to negative_keywords
|
||||
# with "no " removed
|
||||
for kw in self.keywords:
|
||||
kw = kw.strip()
|
||||
log.debug("Checking keyword", keyword=kw)
|
||||
if kw.startswith("no "):
|
||||
log.debug("Transforming negative keyword", keyword=kw, to=kw[3:])
|
||||
self.keywords.remove(kw)
|
||||
self.negative_keywords.append(kw[3:])
|
||||
|
||||
return self
|
||||
|
||||
def prepend(self, *styles):
|
||||
for style in styles:
|
||||
for idx in range(len(style.keywords) - 1, -1, -1):
|
||||
kw = style.keywords[idx]
|
||||
if kw not in self.keywords:
|
||||
self.keywords.insert(0, kw)
|
||||
|
||||
for idx in range(len(style.negative_keywords) - 1, -1, -1):
|
||||
kw = style.negative_keywords[idx]
|
||||
if kw not in self.negative_keywords:
|
||||
self.negative_keywords.insert(0, kw)
|
||||
|
||||
return self
|
||||
|
||||
def append(self, *styles):
|
||||
for style in styles:
|
||||
for kw in style.keywords:
|
||||
if kw not in self.keywords:
|
||||
self.keywords.append(kw)
|
||||
|
||||
for kw in style.negative_keywords:
|
||||
if kw not in self.negative_keywords:
|
||||
self.negative_keywords.append(kw)
|
||||
|
||||
return self
|
||||
|
||||
def copy(self):
|
||||
return Style(
|
||||
keywords=self.keywords.copy(),
|
||||
negative_keywords=self.negative_keywords.copy(),
|
||||
)
|
||||
|
||||
|
||||
# Almost taken straight from some of the fooocus style presets, credit goes to the original author
|
||||
|
||||
STYLE_MAP["digital_art"] = Style(
|
||||
keywords="digital artwork, masterpiece, best quality, high detail".split(", "),
|
||||
negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
|
||||
)
|
||||
|
||||
STYLE_MAP["concept_art"] = Style(
|
||||
keywords="concept art, conceptual sketch, masterpiece, best quality, high detail".split(
|
||||
", "
|
||||
),
|
||||
negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
|
||||
)
|
||||
|
||||
STYLE_MAP["ink_illustration"] = Style(
|
||||
keywords="ink illustration, painting, masterpiece, best quality".split(", "),
|
||||
negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
|
||||
)
|
||||
|
||||
STYLE_MAP["anime"] = Style(
|
||||
keywords="anime, masterpiece, best quality, illustration".split(", "),
|
||||
negative_keywords="text, watermark, low quality, blurry, photo, 3d".split(", "),
|
||||
)
|
||||
|
||||
STYLE_MAP["graphic_novel"] = Style(
|
||||
keywords="(stylized by Enki Bilal:0.7), best quality, graphic novels, detailed linework, digital art".split(
|
||||
", "
|
||||
),
|
||||
negative_keywords="text, watermark, low quality, blurry, photo, 3d, cgi".split(
|
||||
", "
|
||||
),
|
||||
)
|
||||
|
||||
STYLE_MAP["character_portrait"] = Style(keywords="solo, looking at viewer".split(", "))
|
||||
|
||||
STYLE_MAP["environment"] = Style(
|
||||
keywords="scenery, environment, background, postcard".split(", "),
|
||||
negative_keywords="character, portrait, looking at viewer, people".split(", "),
|
||||
)
|
||||
|
||||
MAJOR_STYLES = [
|
||||
{"value": "digital_art", "label": "Digital Art"},
|
||||
{"value": "concept_art", "label": "Concept Art"},
|
||||
{"value": "ink_illustration", "label": "Ink Illustration"},
|
||||
{"value": "anime", "label": "Anime"},
|
||||
{"value": "graphic_novel", "label": "Graphic Novel"},
|
||||
]
|
||||
|
||||
|
||||
def combine_styles(*styles):
|
||||
keywords = []
|
||||
for style in styles:
|
||||
keywords.extend(style.keywords)
|
||||
return Style(keywords=list(set(keywords)))
|
||||
84
src/talemate/agents/visual/websocket_handler.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import Union
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
|
||||
from talemate.instance import get_agent
|
||||
from talemate.server.websocket_plugin import Plugin
|
||||
|
||||
from .context import VisualContext, VisualContextState
|
||||
|
||||
__all__ = [
|
||||
"VisualWebsocketHandler",
|
||||
]
|
||||
|
||||
log = structlog.get_logger("talemate.server.visual")
|
||||
|
||||
|
||||
class SetCoverImagePayload(pydantic.BaseModel):
|
||||
base64: str
|
||||
context: Union[VisualContextState, None] = None
|
||||
|
||||
|
||||
class RegeneratePayload(pydantic.BaseModel):
|
||||
context: Union[VisualContextState, None] = None
|
||||
|
||||
|
||||
class VisualWebsocketHandler(Plugin):
|
||||
router = "visual"
|
||||
|
||||
async def handle_regenerate(self, data: dict):
|
||||
"""
|
||||
Regenerate the image based on the context.
|
||||
"""
|
||||
|
||||
payload = RegeneratePayload(**data)
|
||||
|
||||
context = payload.context
|
||||
|
||||
visual = get_agent("visual")
|
||||
|
||||
with VisualContext(**context.model_dump()):
|
||||
await visual.generate(format="")
|
||||
|
||||
async def handle_cover_image(self, data: dict):
|
||||
"""
|
||||
Sets the cover image for a character and the scene.
|
||||
"""
|
||||
|
||||
payload = SetCoverImagePayload(**data)
|
||||
|
||||
context = payload.context
|
||||
scene = self.scene
|
||||
|
||||
if context and context.character_name:
|
||||
|
||||
character = scene.get_character(context.character_name)
|
||||
|
||||
if not character:
|
||||
log.error("character not found", character_name=context.character_name)
|
||||
return
|
||||
|
||||
asset = scene.assets.add_asset_from_image_data(payload.base64)
|
||||
|
||||
log.info("setting scene cover image", character_name=context.character_name)
|
||||
scene.assets.cover_image = asset.id
|
||||
|
||||
log.info(
|
||||
"setting character cover image", character_name=context.character_name
|
||||
)
|
||||
character.cover_image = asset.id
|
||||
|
||||
scene.emit_status()
|
||||
self.websocket_handler.request_scene_assets([asset.id])
|
||||
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": "scene_asset_character_cover_image",
|
||||
"asset_id": asset.id,
|
||||
"asset": self.scene.assets.get_asset_bytes_as_base64(asset.id),
|
||||
"media_type": asset.media_type,
|
||||
"character": character.name,
|
||||
}
|
||||
)
|
||||
return
|
||||
@@ -187,7 +187,7 @@ class WorldStateAgent(Agent):
|
||||
|
||||
await self.check_pin_conditions()
|
||||
|
||||
async def update_world_state(self):
|
||||
async def update_world_state(self, force: bool = False):
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
@@ -206,13 +206,15 @@ class WorldStateAgent(Agent):
|
||||
self.next_update % self.actions["update_world_state"].config["turns"].value
|
||||
!= 0
|
||||
or self.next_update == 0
|
||||
):
|
||||
) and not force:
|
||||
self.next_update += 1
|
||||
return
|
||||
|
||||
self.next_update = 0
|
||||
await scene.world_state.request_update()
|
||||
|
||||
update_world_state.exposed = True
|
||||
|
||||
@set_processing
|
||||
async def request_world_state(self):
|
||||
t1 = time.time()
|
||||
@@ -349,11 +351,15 @@ class WorldStateAgent(Agent):
|
||||
self,
|
||||
text: str,
|
||||
instruction: str,
|
||||
short: bool = False,
|
||||
):
|
||||
|
||||
kind = "analyze_freeform_short" if short else "analyze_freeform"
|
||||
|
||||
response = await Prompt.request(
|
||||
"world_state.analyze-text-and-follow-instruction",
|
||||
self.client,
|
||||
"analyze_freeform",
|
||||
kind,
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
@@ -376,11 +382,13 @@ class WorldStateAgent(Agent):
|
||||
self,
|
||||
text: str,
|
||||
query: str,
|
||||
short: bool = False,
|
||||
):
|
||||
kind = "analyze_freeform_short" if short else "analyze_freeform"
|
||||
response = await Prompt.request(
|
||||
"world_state.analyze-text-and-answer-question",
|
||||
self.client,
|
||||
"analyze_freeform",
|
||||
kind,
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
@@ -439,6 +447,7 @@ class WorldStateAgent(Agent):
|
||||
self,
|
||||
name: str,
|
||||
text: str = None,
|
||||
alteration_instructions: str = None,
|
||||
):
|
||||
"""
|
||||
Attempts to extract a character sheet from the given text.
|
||||
@@ -453,6 +462,8 @@ class WorldStateAgent(Agent):
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"text": text,
|
||||
"name": name,
|
||||
"character": self.scene.get_character(name),
|
||||
"alteration_instructions": alteration_instructions or "",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -518,23 +529,37 @@ class WorldStateAgent(Agent):
|
||||
if reset and reinforcement.insert == "sequential":
|
||||
self.scene.pop_history(typ="reinforcement", source=source, all=True)
|
||||
|
||||
if reinforcement.insert == "sequential":
|
||||
kind = "analyze_freeform_medium_short"
|
||||
else:
|
||||
kind = "analyze_freeform"
|
||||
|
||||
answer = await Prompt.request(
|
||||
"world_state.update-reinforcements",
|
||||
self.client,
|
||||
"analyze_freeform",
|
||||
kind,
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"question": reinforcement.question,
|
||||
"instructions": reinforcement.instructions or "",
|
||||
"character": self.scene.get_character(reinforcement.character)
|
||||
if reinforcement.character
|
||||
else None,
|
||||
"character": (
|
||||
self.scene.get_character(reinforcement.character)
|
||||
if reinforcement.character
|
||||
else None
|
||||
),
|
||||
"answer": (reinforcement.answer if not reset else None) or "",
|
||||
"reinforcement": reinforcement,
|
||||
},
|
||||
)
|
||||
|
||||
# sequential reinforcment should be single sentence so we
|
||||
# split on line breaks and take the first line in case the
|
||||
# LLM did not understand the request and returned a longer response
|
||||
|
||||
if reinforcement.insert == "sequential":
|
||||
answer = answer.split("\n")[0]
|
||||
|
||||
reinforcement.answer = answer
|
||||
reinforcement.due = reinforcement.interval
|
||||
|
||||
@@ -735,3 +760,28 @@ class WorldStateAgent(Agent):
|
||||
)
|
||||
|
||||
return is_leaving.lower().startswith("y")
|
||||
|
||||
@set_processing
|
||||
async def manager(self, action_name: str, *args, **kwargs):
|
||||
"""
|
||||
Executes a world state manager action through self.scene.world_state_manager
|
||||
"""
|
||||
|
||||
manager = self.scene.world_state_manager
|
||||
|
||||
try:
|
||||
fn = getattr(manager, action_name, None)
|
||||
|
||||
if not fn:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
return await fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"worldstate.manager",
|
||||
action_name=action_name,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
error=e,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
import os
|
||||
|
||||
import talemate.client.runpod
|
||||
from talemate.client.anthropic import AnthropicClient
|
||||
from talemate.client.cohere import CohereClient
|
||||
from talemate.client.google import GoogleClient
|
||||
from talemate.client.groq import GroqClient
|
||||
from talemate.client.koboldccp import KoboldCppClient
|
||||
from talemate.client.lmstudio import LMStudioClient
|
||||
from talemate.client.mistral import MistralAIClient
|
||||
from talemate.client.openai import OpenAIClient
|
||||
from talemate.client.openai_compat import OpenAICompatibleClient
|
||||
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
|
||||
from talemate.client.textgenwebui import TextGeneratorWebuiClient
|
||||
from talemate.client.textgenwebui import TextGeneratorWebuiClient
|
||||
225
src/talemate/client/anthropic.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
from anthropic import AsyncAnthropic, PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
|
||||
__all__ = [
|
||||
"AnthropicClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
]
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "claude-3-sonnet-20240229"
|
||||
|
||||
|
||||
@register()
|
||||
class AnthropicClient(ClientBase):
|
||||
"""
|
||||
Anthropic client for generating text.
|
||||
"""
|
||||
|
||||
client_type = "anthropic"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = False
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "Anthropic"
|
||||
title: str = "Anthropic"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="claude-3-sonnet-20240229", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def anthropic_api_key(self):
|
||||
return self.config.get("anthropic", {}).get("api_key")
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.anthropic_api_key:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "No API key set"
|
||||
error_action = ErrorAction(
|
||||
title="Set API Key",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"anthropic_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.anthropic_api_key:
|
||||
self.client = AsyncAnthropic(api_key="sk-1111")
|
||||
log.error("No anthropic API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "claude-3-opus-20240229"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncAnthropic(api_key=self.anthropic_api_key)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
if self.api_key_status is False:
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.api_key_status = True
|
||||
|
||||
log.info(
|
||||
"anthropic set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return response.usage.output_tokens
|
||||
|
||||
def prompt_tokens(self, response: str):
|
||||
return response.usage.input_tokens
|
||||
|
||||
async def status(self):
|
||||
self.emit_status()
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
keys = list(parameters.keys())
|
||||
valid_keys = ["temperature", "top_p", "max_tokens"]
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.anthropic_api_key:
|
||||
raise Exception("No anthropic API key set")
|
||||
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.messages.create(
|
||||
model=self.model_name,
|
||||
system=system_message,
|
||||
messages=[human_message],
|
||||
**parameters,
|
||||
)
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(response)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
log.debug("generated response", response=response.content)
|
||||
|
||||
response = response.content[0].text
|
||||
|
||||
if expected_response and expected_response.startswith("{"):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
except PermissionDeniedError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="anthropic API: Permission Denied", status="error")
|
||||
return ""
|
||||
except Exception as e:
|
||||
raise
|
||||
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
A unified client base, based on the openai API
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
@@ -8,6 +10,7 @@ from typing import Callable, Union
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
import urllib3
|
||||
from openai import AsyncOpenAI, PermissionDeniedError
|
||||
|
||||
import talemate.client.presets as presets
|
||||
@@ -22,14 +25,24 @@ from talemate.emit import emit
|
||||
# Set up logging level for httpx to WARNING to suppress debug logs.
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
|
||||
REMOTE_SERVICES = [
|
||||
# TODO: runpod.py should add this to the list
|
||||
".runpod.net"
|
||||
]
|
||||
log = structlog.get_logger("client.base")
|
||||
|
||||
STOPPING_STRINGS = ["<|im_end|>", "</s>"]
|
||||
|
||||
|
||||
class PromptData(pydantic.BaseModel):
|
||||
kind: str
|
||||
prompt: str
|
||||
response: str
|
||||
prompt_tokens: int
|
||||
response_tokens: int
|
||||
client_name: str
|
||||
client_type: str
|
||||
time: Union[float, int]
|
||||
agent_stack: list[str] = pydantic.Field(default_factory=list)
|
||||
generation_parameters: dict = pydantic.Field(default_factory=dict)
|
||||
|
||||
|
||||
class ErrorAction(pydantic.BaseModel):
|
||||
title: str
|
||||
action_name: str
|
||||
@@ -39,7 +52,16 @@ class ErrorAction(pydantic.BaseModel):
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
api_url: str = "http://localhost:5000"
|
||||
max_token_length: int = 4096
|
||||
max_token_length: int = 8192
|
||||
double_coercion: str = None
|
||||
|
||||
|
||||
class ExtraField(pydantic.BaseModel):
|
||||
name: str
|
||||
type: str
|
||||
label: str
|
||||
required: bool
|
||||
description: str
|
||||
|
||||
|
||||
class ClientBase:
|
||||
@@ -49,12 +71,15 @@ class ClientBase:
|
||||
name: str = None
|
||||
enabled: bool = True
|
||||
current_status: str = None
|
||||
max_token_length: int = 4096
|
||||
max_token_length: int = 8192
|
||||
processing: bool = False
|
||||
connected: bool = False
|
||||
conversation_retries: int = 2
|
||||
conversation_retries: int = 0
|
||||
auto_break_repetition_enabled: bool = True
|
||||
|
||||
decensor_enabled: bool = True
|
||||
auto_determine_prompt_template: bool = False
|
||||
finalizers: list[str] = []
|
||||
double_coercion: Union[str, None] = None
|
||||
client_type = "base"
|
||||
|
||||
class Meta(pydantic.BaseModel):
|
||||
@@ -73,9 +98,13 @@ class ClientBase:
|
||||
):
|
||||
self.api_url = api_url
|
||||
self.name = name or self.client_type
|
||||
self.auto_determine_prompt_template_attempt = None
|
||||
self.log = structlog.get_logger(f"client.{self.client_type}")
|
||||
self.double_coercion = kwargs.get("double_coercion", None)
|
||||
if "max_token_length" in kwargs:
|
||||
self.max_token_length = kwargs["max_token_length"]
|
||||
self.max_token_length = (
|
||||
int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192
|
||||
)
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def __str__(self):
|
||||
@@ -85,10 +114,18 @@ class ClientBase:
|
||||
def experimental(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def can_be_coerced(self):
|
||||
"""
|
||||
Determines whether or not his client can pass LLM coercion. (e.g., is able
|
||||
to predefine partial LLM output in the prompt)
|
||||
"""
|
||||
return self.Meta().requires_prompt_template
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")
|
||||
|
||||
def prompt_template(self, sys_msg, prompt):
|
||||
def prompt_template(self, sys_msg: str, prompt: str):
|
||||
"""
|
||||
Applies the appropriate prompt template for the model.
|
||||
"""
|
||||
@@ -97,12 +134,24 @@ class ClientBase:
|
||||
self.log.warning("prompt template not applied", reason="no model loaded")
|
||||
return f"{sys_msg}\n{prompt}"
|
||||
|
||||
return model_prompt(self.model_name, sys_msg, prompt)[0]
|
||||
# is JSON coercion active?
|
||||
# Check for <|BOT|>{ in the prompt
|
||||
json_coercion = "<|BOT|>{" in prompt
|
||||
|
||||
if self.can_be_coerced and self.double_coercion and not json_coercion:
|
||||
double_coercion = self.double_coercion
|
||||
double_coercion = f"{double_coercion}\n\n"
|
||||
else:
|
||||
double_coercion = None
|
||||
|
||||
return model_prompt(self.model_name, sys_msg, prompt, double_coercion)[0]
|
||||
|
||||
def prompt_template_example(self):
|
||||
if not getattr(self, "model_name", None):
|
||||
return None, None
|
||||
return model_prompt(self.model_name, "sysmsg", "prompt<|BOT|>{LLM coercion}")
|
||||
return model_prompt(
|
||||
self.model_name, "{sysmsg}", "{prompt}<|BOT|>{LLM coercion}"
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
"""
|
||||
@@ -119,26 +168,54 @@ class ClientBase:
|
||||
self.api_url = kwargs["api_url"]
|
||||
|
||||
if kwargs.get("max_token_length"):
|
||||
self.max_token_length = kwargs["max_token_length"]
|
||||
self.max_token_length = int(kwargs["max_token_length"])
|
||||
|
||||
if "enabled" in kwargs:
|
||||
self.enabled = bool(kwargs["enabled"])
|
||||
|
||||
if "double_coercion" in kwargs:
|
||||
self.double_coercion = kwargs["double_coercion"]
|
||||
|
||||
def host_is_remote(self, url: str) -> bool:
|
||||
"""
|
||||
Returns whether or not the host is a remote service.
|
||||
|
||||
It checks common local hostnames / ip prefixes.
|
||||
|
||||
- localhost
|
||||
"""
|
||||
|
||||
host = urllib3.util.parse_url(url).host
|
||||
|
||||
if host.lower() == "localhost":
|
||||
return False
|
||||
|
||||
# use ipaddress module to check for local ip prefixes
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
except ValueError:
|
||||
return True
|
||||
|
||||
if ip.is_loopback or ip.is_private:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def toggle_disabled_if_remote(self):
|
||||
"""
|
||||
If the client is targeting a remote recognized service, this
|
||||
will disable the client.
|
||||
"""
|
||||
|
||||
for service in REMOTE_SERVICES:
|
||||
if service in self.api_url:
|
||||
if self.enabled:
|
||||
self.log.warn(
|
||||
"remote service unreachable, disabling client", client=self.name
|
||||
)
|
||||
self.enabled = False
|
||||
if not self.api_url:
|
||||
return False
|
||||
|
||||
return True
|
||||
if self.host_is_remote(self.api_url) and self.enabled:
|
||||
self.log.warn(
|
||||
"remote service unreachable, disabling client", client=self.name
|
||||
)
|
||||
self.enabled = False
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -151,32 +228,71 @@ class ClientBase:
|
||||
- kind: the kind of generation
|
||||
"""
|
||||
|
||||
# TODO: make extensible
|
||||
if self.decensor_enabled:
|
||||
|
||||
if "narrate" in kind:
|
||||
return system_prompts.NARRATOR
|
||||
if "story" in kind:
|
||||
return system_prompts.NARRATOR
|
||||
if "director" in kind:
|
||||
return system_prompts.DIRECTOR
|
||||
if "create" in kind:
|
||||
return system_prompts.CREATOR
|
||||
if "roleplay" in kind:
|
||||
return system_prompts.ROLEPLAY
|
||||
if "conversation" in kind:
|
||||
return system_prompts.ROLEPLAY
|
||||
if "editor" in kind:
|
||||
return system_prompts.EDITOR
|
||||
if "world_state" in kind:
|
||||
return system_prompts.WORLD_STATE
|
||||
if "analyze_freeform" in kind:
|
||||
return system_prompts.ANALYST_FREEFORM
|
||||
if "analyst" in kind:
|
||||
return system_prompts.ANALYST
|
||||
if "analyze" in kind:
|
||||
return system_prompts.ANALYST
|
||||
if "summarize" in kind:
|
||||
return system_prompts.SUMMARIZE
|
||||
if "narrate" in kind:
|
||||
return system_prompts.NARRATOR
|
||||
if "story" in kind:
|
||||
return system_prompts.NARRATOR
|
||||
if "director" in kind:
|
||||
return system_prompts.DIRECTOR
|
||||
if "create" in kind:
|
||||
return system_prompts.CREATOR
|
||||
if "roleplay" in kind:
|
||||
return system_prompts.ROLEPLAY
|
||||
if "conversation" in kind:
|
||||
return system_prompts.ROLEPLAY
|
||||
if "basic" in kind:
|
||||
return system_prompts.BASIC
|
||||
if "editor" in kind:
|
||||
return system_prompts.EDITOR
|
||||
if "edit" in kind:
|
||||
return system_prompts.EDITOR
|
||||
if "world_state" in kind:
|
||||
return system_prompts.WORLD_STATE
|
||||
if "analyze_freeform" in kind:
|
||||
return system_prompts.ANALYST_FREEFORM
|
||||
if "analyst" in kind:
|
||||
return system_prompts.ANALYST
|
||||
if "analyze" in kind:
|
||||
return system_prompts.ANALYST
|
||||
if "summarize" in kind:
|
||||
return system_prompts.SUMMARIZE
|
||||
if "visualize" in kind:
|
||||
return system_prompts.VISUALIZE
|
||||
|
||||
else:
|
||||
|
||||
if "narrate" in kind:
|
||||
return system_prompts.NARRATOR_NO_DECENSOR
|
||||
if "story" in kind:
|
||||
return system_prompts.NARRATOR_NO_DECENSOR
|
||||
if "director" in kind:
|
||||
return system_prompts.DIRECTOR_NO_DECENSOR
|
||||
if "create" in kind:
|
||||
return system_prompts.CREATOR_NO_DECENSOR
|
||||
if "roleplay" in kind:
|
||||
return system_prompts.ROLEPLAY_NO_DECENSOR
|
||||
if "conversation" in kind:
|
||||
return system_prompts.ROLEPLAY_NO_DECENSOR
|
||||
if "basic" in kind:
|
||||
return system_prompts.BASIC
|
||||
if "editor" in kind:
|
||||
return system_prompts.EDITOR_NO_DECENSOR
|
||||
if "edit" in kind:
|
||||
return system_prompts.EDITOR_NO_DECENSOR
|
||||
if "world_state" in kind:
|
||||
return system_prompts.WORLD_STATE_NO_DECENSOR
|
||||
if "analyze_freeform" in kind:
|
||||
return system_prompts.ANALYST_FREEFORM_NO_DECENSOR
|
||||
if "analyst" in kind:
|
||||
return system_prompts.ANALYST_NO_DECENSOR
|
||||
if "analyze" in kind:
|
||||
return system_prompts.ANALYST_NO_DECENSOR
|
||||
if "summarize" in kind:
|
||||
return system_prompts.SUMMARIZE_NO_DECENSOR
|
||||
if "visualize" in kind:
|
||||
return system_prompts.VISUALIZE_NO_DECENSOR
|
||||
|
||||
return system_prompts.BASIC
|
||||
|
||||
@@ -205,6 +321,38 @@ class ClientBase:
|
||||
self.current_status = status
|
||||
|
||||
prompt_template_example, prompt_template_file = self.prompt_template_example()
|
||||
has_prompt_template = (
|
||||
prompt_template_file and prompt_template_file != "default.jinja2"
|
||||
)
|
||||
|
||||
if not has_prompt_template and self.auto_determine_prompt_template:
|
||||
|
||||
# only attempt to determine the prompt template once per model and
|
||||
# only if the model does not already have a prompt template
|
||||
|
||||
if self.auto_determine_prompt_template_attempt != self.model_name:
|
||||
log.info("auto_determine_prompt_template", model_name=self.model_name)
|
||||
self.auto_determine_prompt_template_attempt = self.model_name
|
||||
self.determine_prompt_template()
|
||||
prompt_template_example, prompt_template_file = (
|
||||
self.prompt_template_example()
|
||||
)
|
||||
has_prompt_template = (
|
||||
prompt_template_file and prompt_template_file != "default.jinja2"
|
||||
)
|
||||
|
||||
data = {
|
||||
"api_key": self.api_key,
|
||||
"prompt_template_example": prompt_template_example,
|
||||
"has_prompt_template": has_prompt_template,
|
||||
"template_file": prompt_template_file,
|
||||
"meta": self.Meta().model_dump(),
|
||||
"error_action": None,
|
||||
"double_coercion": self.double_coercion,
|
||||
}
|
||||
|
||||
for field_name in getattr(self.Meta(), "extra_fields", {}).keys():
|
||||
data[field_name] = getattr(self, field_name, None)
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
@@ -212,21 +360,29 @@ class ClientBase:
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data={
|
||||
"api_key": self.api_key,
|
||||
"prompt_template_example": prompt_template_example,
|
||||
"has_prompt_template": (
|
||||
prompt_template_file and prompt_template_file != "default.jinja2"
|
||||
),
|
||||
"template_file": prompt_template_file,
|
||||
"meta": self.Meta().model_dump(),
|
||||
"error_action": None,
|
||||
},
|
||||
data=data,
|
||||
)
|
||||
|
||||
if status_change:
|
||||
instance.emit_agent_status_by_client(self)
|
||||
|
||||
def populate_extra_fields(self, data: dict):
|
||||
"""
|
||||
Updates data with the extra fields from the client's Meta
|
||||
"""
|
||||
|
||||
for field_name in getattr(self.Meta(), "extra_fields", {}).keys():
|
||||
data[field_name] = getattr(self, field_name, None)
|
||||
|
||||
def determine_prompt_template(self):
|
||||
if not self.model_name:
|
||||
return
|
||||
|
||||
template = model_prompt.query_hf_for_prompt_template_suggestion(self.model_name)
|
||||
|
||||
if template:
|
||||
model_prompt.create_user_override(template, self.model_name)
|
||||
|
||||
async def get_model_name(self):
|
||||
models = await self.client.models.list()
|
||||
try:
|
||||
@@ -254,14 +410,12 @@ class ClientBase:
|
||||
self.log.warning("client status error", e=e, client=self.name)
|
||||
self.model_name = None
|
||||
self.connected = False
|
||||
self.toggle_disabled_if_remote()
|
||||
self.emit_status()
|
||||
return
|
||||
|
||||
self.connected = True
|
||||
|
||||
if not self.model_name or self.model_name == "None":
|
||||
self.log.warning("client model not loaded", client=self)
|
||||
self.emit_status()
|
||||
return
|
||||
|
||||
@@ -301,11 +455,27 @@ class ClientBase:
|
||||
f"{character}:" for character in conversation_context["other_characters"]
|
||||
]
|
||||
|
||||
dialog_stopping_strings += [
|
||||
f"{character.upper()}\n"
|
||||
for character in conversation_context["other_characters"]
|
||||
]
|
||||
|
||||
if "extra_stopping_strings" in parameters:
|
||||
parameters["extra_stopping_strings"] += dialog_stopping_strings
|
||||
else:
|
||||
parameters["extra_stopping_strings"] = dialog_stopping_strings
|
||||
|
||||
def finalize(self, parameters: dict, prompt: str):
|
||||
|
||||
prompt = util.replace_special_tokens(prompt)
|
||||
|
||||
for finalizer in self.finalizers:
|
||||
fn = getattr(self, finalizer, None)
|
||||
prompt, applied = fn(parameters, prompt)
|
||||
if applied:
|
||||
return prompt
|
||||
return prompt
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
@@ -343,6 +513,9 @@ class ClientBase:
|
||||
"""
|
||||
|
||||
try:
|
||||
self._returned_prompt_tokens = None
|
||||
self._returned_response_tokens = None
|
||||
|
||||
self.emit_status(processing=True)
|
||||
await self.status()
|
||||
|
||||
@@ -351,6 +524,9 @@ class ClientBase:
|
||||
finalized_prompt = self.prompt_template(
|
||||
self.get_system_message(kind), prompt
|
||||
).strip(" ")
|
||||
|
||||
finalized_prompt = self.finalize(prompt_param, finalized_prompt)
|
||||
|
||||
prompt_param = finalize(prompt_param)
|
||||
|
||||
token_length = self.count_tokens(finalized_prompt)
|
||||
@@ -364,9 +540,8 @@ class ClientBase:
|
||||
max_token_length=self.max_token_length,
|
||||
parameters=prompt_param,
|
||||
)
|
||||
response = await self.generate(
|
||||
self.repetition_adjustment(finalized_prompt), prompt_param, kind
|
||||
)
|
||||
prompt_sent = self.repetition_adjustment(finalized_prompt)
|
||||
response = await self.generate(prompt_sent, prompt_param, kind)
|
||||
|
||||
response, finalized_prompt = await self.auto_break_repetition(
|
||||
finalized_prompt, prompt_param, response, kind, retries
|
||||
@@ -382,21 +557,30 @@ class ClientBase:
|
||||
response = response.split(stopping_string)[0]
|
||||
break
|
||||
|
||||
agent_context = active_agent.get()
|
||||
|
||||
emit(
|
||||
"prompt_sent",
|
||||
data={
|
||||
"kind": kind,
|
||||
"prompt": finalized_prompt,
|
||||
"response": response,
|
||||
"prompt_tokens": token_length,
|
||||
"response_tokens": self.count_tokens(response),
|
||||
"time": time_end - time_start,
|
||||
},
|
||||
data=PromptData(
|
||||
kind=kind,
|
||||
prompt=prompt_sent,
|
||||
response=response,
|
||||
prompt_tokens=self._returned_prompt_tokens or token_length,
|
||||
response_tokens=self._returned_response_tokens
|
||||
or self.count_tokens(response),
|
||||
agent_stack=agent_context.agent_stack if agent_context else [],
|
||||
client_name=self.name,
|
||||
client_type=self.client_type,
|
||||
time=time_end - time_start,
|
||||
generation_parameters=prompt_param,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
finally:
|
||||
self.emit_status(processing=False)
|
||||
self._returned_prompt_tokens = None
|
||||
self._returned_response_tokens = None
|
||||
|
||||
async def auto_break_repetition(
|
||||
self,
|
||||
@@ -429,7 +613,7 @@ class ClientBase:
|
||||
- the response
|
||||
"""
|
||||
|
||||
if not self.auto_break_repetition_enabled:
|
||||
if not self.auto_break_repetition_enabled or not response.strip():
|
||||
return response, finalized_prompt
|
||||
|
||||
agent_context = active_agent.get()
|
||||
@@ -557,7 +741,6 @@ class ClientBase:
|
||||
|
||||
lines = prompt.split("\n")
|
||||
new_lines = []
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("[$REPETITION|"):
|
||||
if is_repetitive:
|
||||
|
||||
229
src/talemate/client/cohere.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
from cohere import AsyncClient
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
from talemate.util import count_tokens
|
||||
|
||||
__all__ = [
|
||||
"CohereClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"command",
|
||||
"command-r",
|
||||
"command-r-plus",
|
||||
]
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "command-r-plus"
|
||||
|
||||
|
||||
@register()
|
||||
class CohereClient(ClientBase):
|
||||
"""
|
||||
Cohere client for generating text.
|
||||
"""
|
||||
|
||||
client_type = "cohere"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
decensor_enabled = True
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "Cohere"
|
||||
title: str = "Cohere"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="command-r-plus", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def cohere_api_key(self):
|
||||
return self.config.get("cohere", {}).get("api_key")
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.cohere_api_key:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "No API key set"
|
||||
error_action = ErrorAction(
|
||||
title="Set API Key",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"cohere_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.cohere_api_key:
|
||||
self.client = AsyncClient("sk-1111")
|
||||
log.error("No cohere API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "command-r-plus"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncClient(self.cohere_api_key)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
if self.api_key_status is False:
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.api_key_status = True
|
||||
|
||||
log.info(
|
||||
"cohere set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return count_tokens(response.text)
|
||||
|
||||
def prompt_tokens(self, prompt: str):
|
||||
return count_tokens(prompt)
|
||||
|
||||
async def status(self):
|
||||
self.emit_status()
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
keys = list(parameters.keys())
|
||||
valid_keys = ["temperature", "max_tokens"]
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
# if temperature is set, it needs to be clamped between 0 and 1.0
|
||||
if "temperature" in parameters:
|
||||
parameters["temperature"] = max(0.0, min(1.0, parameters["temperature"]))
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.cohere_api_key:
|
||||
raise Exception("No cohere API key set")
|
||||
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
human_message = prompt.strip()
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat(
|
||||
model=self.model_name,
|
||||
preamble=system_message,
|
||||
message=human_message,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(prompt)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
log.debug("generated response", response=response.text)
|
||||
|
||||
response = response.text
|
||||
|
||||
if expected_response and expected_response.startswith("{"):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
# except PermissionDeniedError as e:
|
||||
# self.log.error("generate error", e=e)
|
||||
# emit("status", message="cohere API: Permission Denied", status="error")
|
||||
# return ""
|
||||
except Exception as e:
|
||||
raise
|
||||
34
src/talemate/client/custom/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import structlog
|
||||
|
||||
log = structlog.get_logger("talemate.client.custom")
|
||||
|
||||
# import every submodule in this directory
|
||||
#
|
||||
# each directory in this directory is a submodule
|
||||
|
||||
# get the current directory
|
||||
current_directory = os.path.dirname(__file__)
|
||||
|
||||
# get all subdirectories
|
||||
subdirectories = [
|
||||
os.path.join(current_directory, name)
|
||||
for name in os.listdir(current_directory)
|
||||
if os.path.isdir(os.path.join(current_directory, name))
|
||||
]
|
||||
|
||||
# import every submodule
|
||||
|
||||
for subdirectory in subdirectories:
|
||||
# get the name of the submodule
|
||||
submodule_name = os.path.basename(subdirectory)
|
||||
|
||||
if submodule_name.startswith("__"):
|
||||
continue
|
||||
|
||||
log.info("activating custom client", module=submodule_name)
|
||||
|
||||
# import the submodule
|
||||
importlib.import_module(f".{submodule_name}", __package__)
|
||||
@@ -0,0 +1,5 @@
|
||||
Each client should be in its own subdirectory.
|
||||
|
||||
The subdirectory itself must be a valid python module.
|
||||
|
||||
Check out docs/dev/client/example/test for a very simplistic custom client example.
|
||||
312
src/talemate/client/google.py
Normal file
@@ -0,0 +1,312 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
import vertexai
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
from vertexai.generative_models import (
|
||||
ChatSession,
|
||||
GenerativeModel,
|
||||
ResponseValidationError,
|
||||
SafetySetting,
|
||||
)
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.client.remote import RemoteServiceMixin
|
||||
from talemate.config import Client as BaseClientConfig
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
from talemate.util import count_tokens
|
||||
|
||||
__all__ = [
|
||||
"GoogleClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"gemini-1.0-pro",
|
||||
"gemini-1.5-pro-preview-0409",
|
||||
]
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "gemini-1.0-pro"
|
||||
disable_safety_settings: bool = False
|
||||
|
||||
|
||||
class ClientConfig(BaseClientConfig):
|
||||
disable_safety_settings: bool = False
|
||||
|
||||
|
||||
@register()
|
||||
class GoogleClient(RemoteServiceMixin, ClientBase):
|
||||
"""
|
||||
Google client for generating text.
|
||||
"""
|
||||
|
||||
client_type = "google"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
decensor_enabled = True
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "Google"
|
||||
title: str = "Google"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
extra_fields: dict[str, ExtraField] = {
|
||||
"disable_safety_settings": ExtraField(
|
||||
name="disable_safety_settings",
|
||||
type="bool",
|
||||
label="Disable Safety Settings",
|
||||
required=False,
|
||||
description="Disable Google's safety settings for responses generated by the model.",
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, model="gemini-1.0-pro", **kwargs):
|
||||
self.model_name = model
|
||||
self.setup_status = None
|
||||
self.model_instance = None
|
||||
self.disable_safety_settings = kwargs.get("disable_safety_settings", False)
|
||||
self.google_credentials_read = False
|
||||
self.google_project_id = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def google_credentials(self):
|
||||
path = self.google_credentials_path
|
||||
if not path:
|
||||
return None
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
@property
|
||||
def google_credentials_path(self):
|
||||
return self.config.get("google").get("gcloud_credentials_path")
|
||||
|
||||
@property
|
||||
def google_location(self):
|
||||
return self.config.get("google").get("gcloud_location")
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
# all google settings must be set
|
||||
return all(
|
||||
[
|
||||
self.google_credentials_path,
|
||||
self.google_location,
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def safety_settings(self):
|
||||
if not self.disable_safety_settings:
|
||||
return None
|
||||
|
||||
safety_settings = [
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
),
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
),
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
),
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
),
|
||||
SafetySetting(
|
||||
category=SafetySetting.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
|
||||
threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
|
||||
),
|
||||
]
|
||||
|
||||
return safety_settings
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.ready:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "Setup incomplete"
|
||||
error_action = ErrorAction(
|
||||
title="Setup Google API credentials",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"google_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
data = {
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
}
|
||||
|
||||
self.populate_extra_fields(data)
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data=data,
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None, **kwargs):
|
||||
if not self.ready:
|
||||
log.error("Google cloud setup incomplete")
|
||||
if self.setup_status:
|
||||
self.setup_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "gemini-1.0-pro"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
if self.google_credentials_path:
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.google_credentials_path
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.setup_status:
|
||||
if self.setup_status is False:
|
||||
project_id = self.google_credentials.get("project_id")
|
||||
self.google_project_id = project_id
|
||||
if self.google_credentials_path:
|
||||
vertexai.init(project=project_id, location=self.google_location)
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.setup_status = True
|
||||
|
||||
self.model_instance = GenerativeModel(model_name=model)
|
||||
|
||||
log.info(
|
||||
"google set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return count_tokens(response.text)
|
||||
|
||||
def prompt_tokens(self, prompt: str):
|
||||
return count_tokens(prompt)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
if "disable_safety_settings" in kwargs:
|
||||
self.disable_safety_settings = kwargs["disable_safety_settings"]
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.ready:
|
||||
raise Exception("Google cloud setup incomplete")
|
||||
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
human_message = prompt.strip()
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
disable_safety_settings=self.disable_safety_settings,
|
||||
safety_settings=self.safety_settings,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
chat = self.model_instance.start_chat()
|
||||
|
||||
response = await chat.send_message_async(
|
||||
human_message,
|
||||
safety_settings=self.safety_settings,
|
||||
)
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(prompt)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
response = response.text
|
||||
|
||||
log.debug("generated response", response=response)
|
||||
|
||||
if expected_response and expected_response.startswith("{"):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
|
||||
# except PermissionDeniedError as e:
|
||||
# self.log.error("generate error", e=e)
|
||||
# emit("status", message="google API: Permission Denied", status="error")
|
||||
# return ""
|
||||
except ResourceExhausted as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="google API: Quota Limit reached", status="error")
|
||||
return ""
|
||||
except ResponseValidationError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit(
|
||||
"status",
|
||||
message="google API: Response Validation Error",
|
||||
status="error",
|
||||
)
|
||||
if not self.disable_safety_settings:
|
||||
return "Failed to generate response. Probably due to safety settings, you can turn them off in the client settings."
|
||||
return "Failed to generate response. Please check logs."
|
||||
except Exception as e:
|
||||
raise
|
||||
235
src/talemate/client/groq.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
from groq import AsyncGroq, PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
|
||||
__all__ = [
|
||||
"GroqClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"mixtral-8x7b-32768",
|
||||
"llama3-8b-8192",
|
||||
"llama3-70b-8192",
|
||||
]
|
||||
|
||||
JSON_OBJECT_RESPONSE_MODELS = []
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 8192
|
||||
model: str = "llama3-70b-8192"
|
||||
|
||||
|
||||
@register()
|
||||
class GroqClient(ClientBase):
|
||||
"""
|
||||
OpenAI client for generating text.
|
||||
"""
|
||||
|
||||
client_type = "groq"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = True
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "Groq"
|
||||
title: str = "Groq"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="llama3-70b-8192", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def groq_api_key(self):
|
||||
return self.config.get("groq", {}).get("api_key")
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.groq_api_key:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "No API key set"
|
||||
error_action = ErrorAction(
|
||||
title="Set API Key",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"groq_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.groq_api_key:
|
||||
self.client = AsyncGroq(api_key="sk-1111")
|
||||
log.error("No groq.ai API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "llama3-70b-8192"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncGroq(api_key=self.groq_api_key)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
if self.api_key_status is False:
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.api_key_status = True
|
||||
|
||||
log.info(
|
||||
"groq.ai set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return response.usage.completion_tokens
|
||||
|
||||
def prompt_tokens(self, response: str):
|
||||
return response.usage.prompt_tokens
|
||||
|
||||
async def status(self):
|
||||
self.emit_status()
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
keys = list(parameters.keys())
|
||||
valid_keys = ["temperature", "top_p", "max_tokens"]
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.groq_api_key:
|
||||
raise Exception("No groq.ai API key set")
|
||||
|
||||
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
if expected_response.startswith("{") and supports_json_object:
|
||||
parameters["response_format"] = {"type": "json_object"}
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_message},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
response = response.choices[0].message.content
|
||||
|
||||
# older models don't support json_object response coersion
|
||||
# and often like to return the response wrapped in ```json
|
||||
# so we strip that out if the expected response is a json object
|
||||
if (
|
||||
not supports_json_object
|
||||
and expected_response
|
||||
and expected_response.startswith("{")
|
||||
):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
except PermissionDeniedError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="OpenAI API: Permission Denied", status="error")
|
||||
return ""
|
||||
except Exception as e:
|
||||
raise
|
||||
@@ -1,16 +1,201 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Union
|
||||
import re
|
||||
|
||||
import requests
|
||||
# import urljoin
|
||||
from urllib.parse import urljoin
|
||||
import httpx
|
||||
import structlog
|
||||
|
||||
import talemate.client.system_prompts as system_prompts
|
||||
import talemate.util as util
|
||||
from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.client.textgenwebui import RESTTaleMateClient
|
||||
from talemate.emit import Emission, emit
|
||||
|
||||
# NOT IMPLEMENTED AT THIS POINT
|
||||
log = structlog.get_logger("talemate.client.koboldcpp")
|
||||
|
||||
|
||||
class KoboldCppClientDefaults(Defaults):
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
@register()
|
||||
class KoboldCppClient(ClientBase):
|
||||
auto_determine_prompt_template: bool = True
|
||||
client_type = "koboldcpp"
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "KoboldCpp"
|
||||
title: str = "KoboldCpp"
|
||||
enable_api_auth: bool = True
|
||||
defaults: KoboldCppClientDefaults = KoboldCppClientDefaults()
|
||||
|
||||
@property
|
||||
def request_headers(self):
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/json"
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
return headers
|
||||
|
||||
@property
|
||||
def is_openai(self) -> bool:
|
||||
"""
|
||||
kcpp has two apis
|
||||
|
||||
open-ai implementation at /v1
|
||||
their own implenation at /api/v1
|
||||
"""
|
||||
return "/api/v1" not in self.api_url
|
||||
|
||||
@property
|
||||
def api_url_for_model(self) -> str:
|
||||
if self.is_openai:
|
||||
# join /model to url
|
||||
return urljoin(self.api_url, "models")
|
||||
else:
|
||||
# join /models to url
|
||||
return urljoin(self.api_url, "model")
|
||||
|
||||
@property
|
||||
def api_url_for_generation(self) -> str:
|
||||
if self.is_openai:
|
||||
# join /v1/completions
|
||||
return urljoin(self.api_url, "completions")
|
||||
else:
|
||||
# join /api/v1/generate
|
||||
return urljoin(self.api_url, "generate")
|
||||
|
||||
def api_endpoint_specified(self, url: str) -> bool:
|
||||
return "/v1" in self.api_url
|
||||
|
||||
def ensure_api_endpoint_specified(self):
|
||||
if not self.api_endpoint_specified(self.api_url):
|
||||
# url doesn't specify the api endpoint
|
||||
# use the koboldcpp openai api
|
||||
self.api_url = urljoin(self.api_url.rstrip("/") + "/", "/api/v1/")
|
||||
if not self.api_url.endswith("/"):
|
||||
self.api_url += "/"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.api_key = kwargs.pop("api_key", "")
|
||||
super().__init__(**kwargs)
|
||||
self.ensure_api_endpoint_specified()
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
if not self.is_openai:
|
||||
# adjustments for united api
|
||||
parameters["max_length"] = parameters.pop("max_tokens")
|
||||
parameters["max_context_length"] = self.max_token_length
|
||||
if "repetition_penalty_range" in parameters:
|
||||
parameters["rep_pen_range"] = parameters.pop("repetition_penalty_range")
|
||||
if "repetition_penalty" in parameters:
|
||||
parameters["rep_pen"] = parameters.pop("repetition_penalty")
|
||||
if parameters.get("stop_sequence"):
|
||||
parameters["stop_sequence"] = parameters.pop("stopping_strings")
|
||||
|
||||
if parameters.get("extra_stopping_strings"):
|
||||
if "stop_sequence" in parameters:
|
||||
parameters["stop_sequence"] += parameters.pop("extra_stopping_strings")
|
||||
else:
|
||||
parameters["stop_sequence"] = parameters.pop("extra_stopping_strings")
|
||||
|
||||
|
||||
allowed_params = [
|
||||
"max_length",
|
||||
"max_context_length",
|
||||
"rep_pen",
|
||||
"rep_pen_range",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"temperature",
|
||||
"stop_sequence",
|
||||
]
|
||||
else:
|
||||
# adjustments for openai api
|
||||
if "repetition_penalty" in parameters:
|
||||
parameters["presence_penalty"] = parameters.pop(
|
||||
"repetition_penalty"
|
||||
)
|
||||
|
||||
allowed_params = ["max_tokens", "presence_penalty", "top_p", "temperature"]
|
||||
|
||||
# drop unsupported params
|
||||
for param in list(parameters.keys()):
|
||||
if param not in allowed_params:
|
||||
del parameters[param]
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
self.api_key = kwargs.get("api_key", self.api_key)
|
||||
self.ensure_api_endpoint_specified()
|
||||
|
||||
async def get_model_name(self):
|
||||
self.ensure_api_endpoint_specified()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
self.api_url_for_model,
|
||||
timeout=2,
|
||||
headers=self.request_headers,
|
||||
)
|
||||
|
||||
if response.status_code == 404:
|
||||
raise KeyError(f"Could not find model info at: {self.api_url_for_model}")
|
||||
|
||||
response_data = response.json()
|
||||
if self.is_openai:
|
||||
# {"object": "list", "data": [{"id": "koboldcpp/dolphin-2.8-mistral-7b", "object": "model", "created": 1, "owned_by": "koboldcpp", "permission": [], "root": "koboldcpp"}]}
|
||||
model_name = response_data.get("data")[0].get("id")
|
||||
else:
|
||||
# {"result": "koboldcpp/dolphin-2.8-mistral-7b"}
|
||||
model_name = response_data.get("result")
|
||||
|
||||
# split by "/" and take last
|
||||
if model_name:
|
||||
model_name = model_name.split("/")[-1]
|
||||
|
||||
return model_name
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
parameters["prompt"] = prompt.strip(" ")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.api_url_for_generation,
|
||||
json=parameters,
|
||||
timeout=None,
|
||||
headers=self.request_headers,
|
||||
)
|
||||
response_data = response.json()
|
||||
|
||||
try:
|
||||
if self.is_openai:
|
||||
return response_data["choices"][0]["text"]
|
||||
else:
|
||||
return response_data["results"][0]["text"]
|
||||
except (TypeError, KeyError) as exc:
|
||||
log.error("Failed to generate text", exc=exc, response_data=response_data, response_status=response.status_code)
|
||||
return ""
|
||||
|
||||
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
|
||||
"""
|
||||
adjusts temperature and repetition_penalty
|
||||
by random values using the base value as a center
|
||||
"""
|
||||
|
||||
temp = prompt_config["temperature"]
|
||||
rep_pen = prompt_config["rep_pen"]
|
||||
|
||||
min_offset = offset * 0.3
|
||||
|
||||
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
|
||||
prompt_config["rep_pen"] = random.uniform(
|
||||
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if "api_key" in kwargs:
|
||||
self.api_key = kwargs.pop("api_key")
|
||||
|
||||
super().reconfigure(**kwargs)
|
||||
|
||||
@@ -7,10 +7,12 @@ from talemate.client.registry import register
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
api_url: str = "http://localhost:1234"
|
||||
max_token_length: int = 4096
|
||||
max_token_length: int = 8192
|
||||
|
||||
|
||||
@register()
|
||||
class LMStudioClient(ClientBase):
|
||||
auto_determine_prompt_template: bool = True
|
||||
client_type = "lmstudio"
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
|
||||
254
src/talemate/client/mistral.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
from mistralai.async_client import MistralAsyncClient
|
||||
from mistralai.exceptions import MistralAPIStatusException
|
||||
from mistralai.models.chat_completion import ChatMessage
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
|
||||
__all__ = [
|
||||
"MistralAIClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"open-mistral-7b",
|
||||
"open-mixtral-8x7b",
|
||||
"open-mixtral-8x22b",
|
||||
"mistral-small-latest",
|
||||
"mistral-medium-latest",
|
||||
"mistral-large-latest",
|
||||
]
|
||||
|
||||
JSON_OBJECT_RESPONSE_MODELS = SUPPORTED_MODELS
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "open-mixtral-8x7b"
|
||||
|
||||
|
||||
@register()
|
||||
class MistralAIClient(ClientBase):
|
||||
"""
|
||||
OpenAI client for generating text.
|
||||
"""
|
||||
|
||||
client_type = "mistral"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = True
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "MistralAI"
|
||||
title: str = "MistralAI"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="open-mixtral-8x7b", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def mistralai_api_key(self):
|
||||
return self.config.get("mistralai", {}).get("api_key")
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.mistralai_api_key:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "No API key set"
|
||||
error_action = ErrorAction(
|
||||
title="Set API Key",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"mistralai_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.mistralai_api_key:
|
||||
self.client = MistralAsyncClient(api_key="sk-1111")
|
||||
log.error("No mistral.ai API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "open-mixtral-8x7b"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = MistralAsyncClient(api_key=self.mistralai_api_key)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
if self.api_key_status is False:
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.api_key_status = True
|
||||
|
||||
log.info(
|
||||
"mistral.ai set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return response.usage.completion_tokens
|
||||
|
||||
def prompt_tokens(self, response: str):
|
||||
return response.usage.prompt_tokens
|
||||
|
||||
async def status(self):
|
||||
self.emit_status()
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
keys = list(parameters.keys())
|
||||
valid_keys = ["temperature", "top_p", "max_tokens"]
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
# clamp temperature to 0.1 and 1.0
|
||||
# Unhandled Error: Status: 422. Message: {"object":"error","message":{"detail":[{"type":"less_than_equal","loc":["body","temperature"],"msg":"Input should be less than or equal to 1","input":1.31,"ctx":{"le":1.0},"url":"https://errors.pydantic.dev/2.6/v/less_than_equal"}]},"type":"invalid_request_error","param":null,"code":null}
|
||||
|
||||
if "temperature" in parameters:
|
||||
parameters["temperature"] = min(1.0, max(0.1, parameters["temperature"]))
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.mistralai_api_key:
|
||||
raise Exception("No mistral.ai API key set")
|
||||
|
||||
supports_json_object = self.model_name in JSON_OBJECT_RESPONSE_MODELS
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
if expected_response.startswith("{") and supports_json_object:
|
||||
parameters["response_format"] = {"type": "json_object"}
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
messages = [
|
||||
ChatMessage(role="system", content=system_message),
|
||||
ChatMessage(role="user", content=prompt.strip()),
|
||||
]
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(response)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
response = response.choices[0].message.content
|
||||
|
||||
# older models don't support json_object response coersion
|
||||
# and often like to return the response wrapped in ```json
|
||||
# so we strip that out if the expected response is a json object
|
||||
if (
|
||||
not supports_json_object
|
||||
and expected_response
|
||||
and expected_response.startswith("{")
|
||||
):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
except MistralAPIStatusException as e:
|
||||
self.log.error("generate error", e=e)
|
||||
if e.http_status in [403, 401]:
|
||||
emit(
|
||||
"status",
|
||||
message="mistral.ai API: Permission Denied",
|
||||
status="error",
|
||||
)
|
||||
return ""
|
||||
except Exception as e:
|
||||
raise
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
@@ -38,7 +39,6 @@ log = structlog.get_logger("talemate.model_prompts")
|
||||
|
||||
|
||||
class ModelPrompt:
|
||||
|
||||
"""
|
||||
Will attempt to load an LLM prompt template based on the model name
|
||||
|
||||
@@ -67,14 +67,27 @@ class ModelPrompt:
|
||||
env = Environment(loader=FileSystemLoader(STD_TEMPLATE_PATH))
|
||||
return sorted(env.list_templates())
|
||||
|
||||
def __call__(self, model_name: str, system_message: str, prompt: str):
|
||||
def __call__(
|
||||
self,
|
||||
model_name: str,
|
||||
system_message: str,
|
||||
prompt: str,
|
||||
double_coercion: str = None,
|
||||
):
|
||||
template, template_file = self.get_template(model_name)
|
||||
if not template:
|
||||
template_file = "default.jinja2"
|
||||
template = self.env.get_template(template_file)
|
||||
|
||||
if not double_coercion:
|
||||
double_coercion = ""
|
||||
|
||||
if "<|BOT|>" not in prompt and double_coercion:
|
||||
prompt = f"{prompt}<|BOT|>"
|
||||
|
||||
if "<|BOT|>" in prompt:
|
||||
user_message, coercion_message = prompt.split("<|BOT|>", 1)
|
||||
coercion_message = f"{double_coercion}{coercion_message}"
|
||||
else:
|
||||
user_message = prompt
|
||||
coercion_message = ""
|
||||
@@ -83,19 +96,30 @@ class ModelPrompt:
|
||||
template.render(
|
||||
{
|
||||
"system_message": system_message,
|
||||
"prompt": prompt,
|
||||
"user_message": user_message,
|
||||
"prompt": prompt.strip(),
|
||||
"user_message": user_message.strip(),
|
||||
"coercion_message": coercion_message,
|
||||
"set_response": self.set_response,
|
||||
"set_response": lambda prompt, response_str: self.set_response(
|
||||
prompt, response_str, double_coercion
|
||||
),
|
||||
}
|
||||
),
|
||||
template_file,
|
||||
)
|
||||
|
||||
def set_response(self, prompt: str, response_str: str):
|
||||
def set_response(self, prompt: str, response_str: str, double_coercion: str = None):
|
||||
prompt = prompt.strip("\n").strip()
|
||||
|
||||
if not double_coercion:
|
||||
double_coercion = ""
|
||||
|
||||
if "<|BOT|>" not in prompt and double_coercion:
|
||||
prompt = f"{prompt}<|BOT|>"
|
||||
|
||||
if "<|BOT|>" in prompt:
|
||||
|
||||
response_str = f"{double_coercion}{response_str}"
|
||||
|
||||
if "\n<|BOT|>" in prompt:
|
||||
prompt = prompt.replace("\n<|BOT|>", response_str)
|
||||
else:
|
||||
@@ -156,11 +180,19 @@ class ModelPrompt:
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
models = list(
|
||||
api.list_models(
|
||||
filter=huggingface_hub.ModelFilter(model_name=model_name, author=author)
|
||||
)
|
||||
)
|
||||
branch_name = "main"
|
||||
|
||||
# special popular cases
|
||||
|
||||
# bartowski
|
||||
|
||||
if author == "bartowski" and "exl2" in model_name:
|
||||
# split model_name by exl2 and take the first part with "exl2" readded
|
||||
# the second part is the branch name
|
||||
model_name, branch_name = model_name.split("exl2_", 1)
|
||||
model_name = f"{model_name}exl2"
|
||||
|
||||
models = list(api.list_models(model_name=model_name, author=author))
|
||||
|
||||
if not models:
|
||||
return None
|
||||
@@ -168,9 +200,14 @@ class ModelPrompt:
|
||||
model = models[0]
|
||||
|
||||
repo_id = f"{author}/{model_name}"
|
||||
|
||||
# Check README.md
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
readme_path = huggingface_hub.hf_hub_download(
|
||||
repo_id=repo_id, filename="README.md", cache_dir=tmpdir
|
||||
repo_id=repo_id,
|
||||
filename="README.md",
|
||||
cache_dir=tmpdir,
|
||||
revision=branch_name,
|
||||
)
|
||||
if not readme_path:
|
||||
return None
|
||||
@@ -181,6 +218,24 @@ class ModelPrompt:
|
||||
if identifier(readme):
|
||||
return f"{identifier.template_str}.jinja2"
|
||||
|
||||
# Check tokenizer_config.json
|
||||
# "chat_template" key
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = huggingface_hub.hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="tokenizer_config.json",
|
||||
cache_dir=tmpdir,
|
||||
revision=branch_name,
|
||||
)
|
||||
if not config_path:
|
||||
return None
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
for identifer_cls in TEMPLATE_IDENTIFIERS:
|
||||
identifier = identifer_cls()
|
||||
if identifier(config.get("chat_template", "")):
|
||||
return f"{identifier.template_str}.jinja2"
|
||||
|
||||
|
||||
model_prompt = ModelPrompt()
|
||||
|
||||
@@ -198,6 +253,14 @@ class Llama2Identifier(TemplateIdentifier):
|
||||
return "[INST]" in content and "[/INST]" in content
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class Llama3Identifier(TemplateIdentifier):
|
||||
template_str = "Llama3"
|
||||
|
||||
def __call__(self, content: str):
|
||||
return "<|start_header_id|>" in content and "<|end_header_id|>" in content
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class ChatMLIdentifier(TemplateIdentifier):
|
||||
template_str = "ChatML"
|
||||
@@ -212,11 +275,42 @@ class ChatMLIdentifier(TemplateIdentifier):
|
||||
{{ coercion_message }}
|
||||
"""
|
||||
|
||||
return "<|im_start|>" in content and "<|im_end|>" in content
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class CommandRIdentifier(TemplateIdentifier):
|
||||
template_str = "CommandR"
|
||||
|
||||
def __call__(self, content: str):
|
||||
"""
|
||||
<BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ system_message }}
|
||||
{{ user_message }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|>
|
||||
<|CHATBOT_TOKEN|>{{ coercion_message }}
|
||||
"""
|
||||
|
||||
return (
|
||||
"<|im_start|>system" in content
|
||||
and "<|im_end|>" in content
|
||||
and "<|im_start|>user" in content
|
||||
and "<|im_start|>assistant" in content
|
||||
"<|START_OF_TURN_TOKEN|>" in content
|
||||
and "<|END_OF_TURN_TOKEN|>" in content
|
||||
and "<|SYSTEM_TOKEN|>" not in content
|
||||
)
|
||||
|
||||
|
||||
@register_template_identifier
|
||||
class CommandRPlusIdentifier(TemplateIdentifier):
|
||||
template_str = "CommandRPlus"
|
||||
|
||||
def __call__(self, content: str):
|
||||
"""
|
||||
<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ system_message }}
|
||||
<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ user_message }}
|
||||
<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{ coercion_message }}
|
||||
"""
|
||||
|
||||
return (
|
||||
"<|START_OF_TURN_TOKEN|>" in content
|
||||
and "<|END_OF_TURN_TOKEN|>" in content
|
||||
and "<|SYSTEM_TOKEN|>" in content
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,27 @@ __all__ = [
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-turbo-2024-04-09",
|
||||
"gpt-4-turbo",
|
||||
]
|
||||
|
||||
JSON_OBJECT_RESPONSE_MODELS = [
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-3.5-turbo-0125",
|
||||
]
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0613"):
|
||||
"""Return the number of tokens used by a list of messages."""
|
||||
@@ -71,7 +92,7 @@ def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "gpt-4-turbo-preview"
|
||||
model: str = "gpt-4-turbo"
|
||||
|
||||
|
||||
@register()
|
||||
@@ -83,23 +104,18 @@ class OpenAIClient(ClientBase):
|
||||
client_type = "openai"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = False
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "OpenAI"
|
||||
title: str = "OpenAI"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = [
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-4",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-4-turbo-preview",
|
||||
]
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="gpt-4-turbo-preview", **kwargs):
|
||||
def __init__(self, model="gpt-4-turbo", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
@@ -163,6 +179,9 @@ class OpenAIClient(ClientBase):
|
||||
if not self.model_name:
|
||||
self.model_name = "gpt-3.5-turbo-16k"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncOpenAI(api_key=self.openai_api_key)
|
||||
@@ -214,7 +233,7 @@ class OpenAIClient(ClientBase):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nContinue this response: ")
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
@@ -227,6 +246,15 @@ class OpenAIClient(ClientBase):
|
||||
|
||||
valid_keys = ["temperature", "top_p"]
|
||||
|
||||
# GPT-3.5 models tend to run away with the generated
|
||||
# response size so we allow talemate to set the max_tokens
|
||||
#
|
||||
# GPT-4 on the other hand seems to benefit from letting it
|
||||
# decide the generation length naturally and it will generally
|
||||
# produce reasonably sized responses
|
||||
if self.model_name.startswith("gpt-3.5-"):
|
||||
valid_keys.append("max_tokens")
|
||||
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
@@ -240,10 +268,14 @@ class OpenAIClient(ClientBase):
|
||||
raise Exception("No OpenAI API key set")
|
||||
|
||||
# only gpt-4-* supports enforcing json object
|
||||
supports_json_object = self.model_name.startswith("gpt-4-")
|
||||
supports_json_object = (
|
||||
self.model_name.startswith("gpt-4-")
|
||||
or self.model_name in JSON_OBJECT_RESPONSE_MODELS
|
||||
)
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nContinue this response: ")
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
if expected_response.startswith("{") and supports_json_object:
|
||||
parameters["response_format"] = {"type": "json_object"}
|
||||
@@ -253,7 +285,12 @@ class OpenAIClient(ClientBase):
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
system_message = {"role": "system", "content": self.get_system_message(kind)}
|
||||
|
||||
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
@@ -264,6 +301,17 @@ class OpenAIClient(ClientBase):
|
||||
|
||||
response = response.choices[0].message.content
|
||||
|
||||
# older models don't support json_object response coersion
|
||||
# and often like to return the response wrapped in ```json
|
||||
# so we strip that out if the expected response is a json object
|
||||
if (
|
||||
not supports_json_object
|
||||
and expected_response
|
||||
and expected_response.startswith("{")
|
||||
):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
|
||||
@@ -1,24 +1,36 @@
|
||||
import urllib
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase
|
||||
from talemate.client.base import ClientBase, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import Client as BaseClientConfig
|
||||
from talemate.emit import emit
|
||||
|
||||
log = structlog.get_logger("talemate.client.openai_compat")
|
||||
|
||||
EXPERIMENTAL_DESCRIPTION = """Use this client if you want to connect to a service implementing an OpenAI-compatible API. Success is going to depend on the level of compatibility. Use the actual OpenAI client if you want to connect to OpenAI's API."""
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
api_url: str = "http://localhost:5000"
|
||||
api_key: str = ""
|
||||
max_token_length: int = 4096
|
||||
max_token_length: int = 8192
|
||||
model: str = ""
|
||||
api_handles_prompt_template: bool = False
|
||||
|
||||
|
||||
class ClientConfig(BaseClientConfig):
|
||||
api_handles_prompt_template: bool = False
|
||||
|
||||
|
||||
@register()
|
||||
class OpenAICompatibleClient(ClientBase):
|
||||
client_type = "openai_compat"
|
||||
conversation_retries = 5
|
||||
conversation_retries = 0
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
title: str = "OpenAI Compatible API"
|
||||
@@ -27,18 +39,43 @@ class OpenAICompatibleClient(ClientBase):
|
||||
enable_api_auth: bool = True
|
||||
manual_model: bool = True
|
||||
defaults: Defaults = Defaults()
|
||||
extra_fields: dict[str, ExtraField] = {
|
||||
"api_handles_prompt_template": ExtraField(
|
||||
name="api_handles_prompt_template",
|
||||
type="bool",
|
||||
label="API Handles Prompt Template",
|
||||
required=False,
|
||||
description="The API handles the prompt template, meaning your choice in the UI for the prompt template below will be ignored.",
|
||||
)
|
||||
}
|
||||
|
||||
def __init__(self, model=None, **kwargs):
|
||||
def __init__(
|
||||
self, model=None, api_key=None, api_handles_prompt_template=False, **kwargs
|
||||
):
|
||||
self.model_name = model
|
||||
self.api_key = api_key
|
||||
self.api_handles_prompt_template = api_handles_prompt_template
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return EXPERIMENTAL_DESCRIPTION
|
||||
|
||||
@property
|
||||
def can_be_coerced(self):
|
||||
"""
|
||||
Determines whether or not his client can pass LLM coercion. (e.g., is able
|
||||
to predefine partial LLM output in the prompt)
|
||||
"""
|
||||
return not self.api_handles_prompt_template
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
self.api_key = kwargs.get("api_key")
|
||||
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key=self.api_key)
|
||||
self.api_key = kwargs.get("api_key", self.api_key)
|
||||
self.api_handles_prompt_template = kwargs.get(
|
||||
"api_handles_prompt_template", self.api_handles_prompt_template
|
||||
)
|
||||
url = self.api_url
|
||||
self.client = AsyncOpenAI(base_url=url, api_key=self.api_key)
|
||||
self.model_name = (
|
||||
kwargs.get("model") or kwargs.get("model_name") or self.model_name
|
||||
)
|
||||
@@ -48,32 +85,33 @@ class OpenAICompatibleClient(ClientBase):
|
||||
|
||||
keys = list(parameters.keys())
|
||||
|
||||
valid_keys = ["temperature", "top_p"]
|
||||
valid_keys = ["temperature", "top_p", "max_tokens"]
|
||||
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
|
||||
log.debug(
|
||||
"IS API HANDLING PROMPT TEMPLATE",
|
||||
api_handles_prompt_template=self.api_handles_prompt_template,
|
||||
)
|
||||
|
||||
if not self.api_handles_prompt_template:
|
||||
return super().prompt_template(system_message, prompt)
|
||||
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
async def get_model_name(self):
|
||||
try:
|
||||
model_name = await super().get_model_name()
|
||||
except NotFoundError as e:
|
||||
# api does not implement model listing
|
||||
return self.model_name
|
||||
except Exception as e:
|
||||
self.log.error("get_model_name error", e=e)
|
||||
return self.model_name
|
||||
|
||||
# model name may be a file path, so we need to extract the model name
|
||||
# the path could be windows or linux so it needs to handle both backslash and forward slash
|
||||
|
||||
is_filepath = "/" in model_name
|
||||
is_filepath_windows = "\\" in model_name
|
||||
|
||||
if is_filepath or is_filepath_windows:
|
||||
model_name = model_name.replace("\\", "/").split("/")[-1]
|
||||
|
||||
return model_name
|
||||
return self.model_name
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
@@ -106,8 +144,14 @@ class OpenAICompatibleClient(ClientBase):
|
||||
if "api_url" in kwargs:
|
||||
self.api_url = kwargs["api_url"]
|
||||
if "max_token_length" in kwargs:
|
||||
self.max_token_length = kwargs["max_token_length"]
|
||||
self.max_token_length = (
|
||||
int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192
|
||||
)
|
||||
if "api_key" in kwargs:
|
||||
self.api_auth = kwargs["api_key"]
|
||||
self.api_key = kwargs["api_key"]
|
||||
if "api_handles_prompt_template" in kwargs:
|
||||
self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]
|
||||
|
||||
log.warning("reconfigure", kwargs=kwargs)
|
||||
|
||||
self.set_client(**kwargs)
|
||||
|
||||
@@ -34,6 +34,13 @@ PRESET_LLAMA_PRECISE = {
|
||||
"repetition_penalty": 1.18,
|
||||
}
|
||||
|
||||
PRESET_DETERMINISTIC = {
|
||||
"temperature": 0.1,
|
||||
"top_p": 1,
|
||||
"top_k": 0,
|
||||
"repetition_penalty": 1.0,
|
||||
}
|
||||
|
||||
PRESET_DIVINE_INTELLECT = {
|
||||
"temperature": 1.31,
|
||||
"top_p": 0.14,
|
||||
@@ -49,6 +56,12 @@ PRESET_SIMPLE_1 = {
|
||||
"repetition_penalty": 1.15,
|
||||
}
|
||||
|
||||
PRESET_ANALYTICAL = {
|
||||
"temperature": 0.1,
|
||||
"top_p": 0.9,
|
||||
"top_k": 20,
|
||||
}
|
||||
|
||||
|
||||
def configure(config: dict, kind: str, total_budget: int):
|
||||
"""
|
||||
@@ -75,7 +88,17 @@ def set_preset(config: dict, kind: str):
|
||||
|
||||
|
||||
def preset_for_kind(kind: str):
|
||||
if kind == "conversation":
|
||||
|
||||
# tag based
|
||||
if "deterministic" in kind:
|
||||
return PRESET_DETERMINISTIC
|
||||
elif "creative" in kind:
|
||||
return PRESET_DIVINE_INTELLECT
|
||||
elif "simple" in kind:
|
||||
return PRESET_SIMPLE_1
|
||||
elif "analytical" in kind:
|
||||
return PRESET_ANALYTICAL
|
||||
elif kind == "conversation":
|
||||
return PRESET_TALEMATE_CONVERSATION
|
||||
elif kind == "conversation_old":
|
||||
return PRESET_TALEMATE_CONVERSATION # Assuming old conversation uses the same preset
|
||||
@@ -120,63 +143,87 @@ def preset_for_kind(kind: str):
|
||||
elif kind == "edit_add_detail":
|
||||
return PRESET_DIVINE_INTELLECT # Assuming adding detail uses the same preset as divine intellect
|
||||
elif kind == "edit_fix_exposition":
|
||||
return PRESET_DIVINE_INTELLECT # Assuming fixing exposition uses the same preset as divine intellect
|
||||
return PRESET_DETERMINISTIC # Assuming fixing exposition uses the same preset as divine intellect
|
||||
elif kind == "edit_fix_continuity":
|
||||
return PRESET_DETERMINISTIC
|
||||
elif kind == "visualize":
|
||||
return PRESET_SIMPLE_1
|
||||
|
||||
else:
|
||||
return PRESET_SIMPLE_1 # Default preset if none of the kinds match
|
||||
|
||||
|
||||
def max_tokens_for_kind(kind: str, total_budget: int):
|
||||
if kind == "conversation":
|
||||
return 75 # Example value, adjust as needed
|
||||
return 75
|
||||
elif kind == "conversation_old":
|
||||
return 75 # Example value, adjust as needed
|
||||
return 75
|
||||
elif kind == "conversation_long":
|
||||
return 300 # Example value, adjust as needed
|
||||
return 300
|
||||
elif kind == "conversation_select_talking_actor":
|
||||
return 30 # Example value, adjust as needed
|
||||
return 30
|
||||
elif kind == "summarize":
|
||||
return 500 # Example value, adjust as needed
|
||||
return 500
|
||||
elif kind == "analyze":
|
||||
return 500 # Example value, adjust as needed
|
||||
return 500
|
||||
elif kind == "analyze_creative":
|
||||
return 1024 # Example value, adjust as needed
|
||||
return 1024
|
||||
elif kind == "analyze_long":
|
||||
return 2048 # Example value, adjust as needed
|
||||
return 2048
|
||||
elif kind == "analyze_freeform":
|
||||
return 500 # Example value, adjust as needed
|
||||
return 500
|
||||
elif kind == "analyze_freeform_medium":
|
||||
return 192
|
||||
elif kind == "analyze_freeform_medium_short":
|
||||
return 128
|
||||
elif kind == "analyze_freeform_short":
|
||||
return 10 # Example value, adjust as needed
|
||||
return 10
|
||||
elif kind == "narrate":
|
||||
return 500 # Example value, adjust as needed
|
||||
return 500
|
||||
elif kind == "story":
|
||||
return 300 # Example value, adjust as needed
|
||||
return 300
|
||||
elif kind == "create":
|
||||
return min(
|
||||
1024, int(total_budget * 0.35)
|
||||
) # Example calculation, adjust as needed
|
||||
return min(1024, int(total_budget * 0.35))
|
||||
elif kind == "create_concise":
|
||||
return min(
|
||||
400, int(total_budget * 0.25)
|
||||
) # Example calculation, adjust as needed
|
||||
return min(400, int(total_budget * 0.25))
|
||||
elif kind == "create_precise":
|
||||
return min(
|
||||
400, int(total_budget * 0.25)
|
||||
) # Example calculation, adjust as needed
|
||||
return min(400, int(total_budget * 0.25))
|
||||
elif kind == "create_short":
|
||||
return 25
|
||||
elif kind == "director":
|
||||
return min(
|
||||
192, int(total_budget * 0.25)
|
||||
) # Example calculation, adjust as needed
|
||||
return min(192, int(total_budget * 0.25))
|
||||
elif kind == "director_short":
|
||||
return 25 # Example value, adjust as needed
|
||||
return 25
|
||||
elif kind == "director_yesno":
|
||||
return 2 # Example value, adjust as needed
|
||||
return 2
|
||||
elif kind == "edit_dialogue":
|
||||
return 100 # Example value, adjust as needed
|
||||
return 100
|
||||
elif kind == "edit_add_detail":
|
||||
return 200 # Example value, adjust as needed
|
||||
return 200
|
||||
elif kind == "edit_fix_exposition":
|
||||
return 1024 # Example value, adjust as needed
|
||||
return 1024
|
||||
elif kind == "edit_fix_continuity":
|
||||
return 512
|
||||
elif kind == "visualize":
|
||||
return 150
|
||||
# tag based
|
||||
elif "extensive" in kind:
|
||||
return 2048
|
||||
elif "long" in kind:
|
||||
return 1024
|
||||
elif "medium2" in kind:
|
||||
return 512
|
||||
elif "medium" in kind:
|
||||
return 192
|
||||
elif "short2" in kind:
|
||||
return 128
|
||||
elif "short" in kind:
|
||||
return 75
|
||||
elif "tiny2" in kind:
|
||||
return 25
|
||||
elif "tiny" in kind:
|
||||
return 10
|
||||
elif "yesno" in kind:
|
||||
return 2
|
||||
else:
|
||||
return 150 # Default value if none of the kinds match
|
||||
|
||||
35
src/talemate/client/remote.py
Normal file
@@ -0,0 +1,35 @@
|
||||
__all__ = ["RemoteServiceMixin"]
|
||||
|
||||
|
||||
class RemoteServiceMixin:
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
keys = list(parameters.keys())
|
||||
valid_keys = ["temperature", "max_tokens"]
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
async def status(self):
|
||||
self.emit_status()
|
||||
@@ -21,11 +21,13 @@ dotenv.load_dotenv()
|
||||
|
||||
runpod.api_key = load_config().get("runpod", {}).get("api_key", "")
|
||||
|
||||
TEXTGEN_IDENTIFIERS = ["textgen", "thebloke llms", "text-generation-webui"]
|
||||
|
||||
|
||||
def is_textgen_pod(pod):
|
||||
name = pod["name"].lower()
|
||||
|
||||
if "textgen" in name or "thebloke llms" in name:
|
||||
if any(identifier in name for identifier in TEXTGEN_IDENTIFIERS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -19,3 +19,29 @@ EDITOR = str(Prompt.get("editor.system"))
|
||||
WORLD_STATE = str(Prompt.get("world_state.system-analyst"))
|
||||
|
||||
SUMMARIZE = str(Prompt.get("summarizer.system"))
|
||||
|
||||
VISUALIZE = str(Prompt.get("visual.system"))
|
||||
|
||||
# CAREBEAR PROMPTS
|
||||
|
||||
ROLEPLAY_NO_DECENSOR = str(Prompt.get("conversation.system-no-decensor"))
|
||||
|
||||
NARRATOR_NO_DECENSOR = str(Prompt.get("narrator.system-no-decensor"))
|
||||
|
||||
CREATOR_NO_DECENSOR = str(Prompt.get("creator.system-no-decensor"))
|
||||
|
||||
DIRECTOR_NO_DECENSOR = str(Prompt.get("director.system-no-decensor"))
|
||||
|
||||
ANALYST_NO_DECENSOR = str(Prompt.get("world_state.system-analyst-no-decensor"))
|
||||
|
||||
ANALYST_FREEFORM_NO_DECENSOR = str(
|
||||
Prompt.get("world_state.system-analyst-freeform-no-decensor")
|
||||
)
|
||||
|
||||
EDITOR_NO_DECENSOR = str(Prompt.get("editor.system-no-decensor"))
|
||||
|
||||
WORLD_STATE_NO_DECENSOR = str(Prompt.get("world_state.system-analyst-no-decensor"))
|
||||
|
||||
SUMMARIZE_NO_DECENSOR = str(Prompt.get("summarizer.system-no-decensor"))
|
||||
|
||||
VISUALIZE_NO_DECENSOR = str(Prompt.get("visual.system-no-decensor"))
|
||||
|
||||
@@ -1,22 +1,47 @@
|
||||
import random
|
||||
import re
|
||||
|
||||
import httpx
|
||||
import structlog
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from talemate.client.base import STOPPING_STRINGS, ClientBase
|
||||
from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults, ExtraField
|
||||
from talemate.client.registry import register
|
||||
|
||||
log = structlog.get_logger("talemate.client.textgenwebui")
|
||||
|
||||
|
||||
class TextGeneratorWebuiClientDefaults(Defaults):
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
@register()
|
||||
class TextGeneratorWebuiClient(ClientBase):
|
||||
auto_determine_prompt_template: bool = True
|
||||
finalizers: list[str] = [
|
||||
"finalize_llama3",
|
||||
"finalize_YI",
|
||||
]
|
||||
|
||||
client_type = "textgenwebui"
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "TextGenWebUI"
|
||||
title: str = "Text-Generation-WebUI (ooba)"
|
||||
enable_api_auth: bool = True
|
||||
defaults: TextGeneratorWebuiClientDefaults = TextGeneratorWebuiClientDefaults()
|
||||
|
||||
@property
|
||||
def request_headers(self):
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/json"
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
return headers
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.api_key = kwargs.pop("api_key", "")
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
@@ -27,25 +52,51 @@ class TextGeneratorWebuiClient(ClientBase):
|
||||
parameters["max_new_tokens"] = parameters["max_tokens"]
|
||||
parameters["stop"] = parameters["stopping_strings"]
|
||||
|
||||
# Half temperature on -Yi- models
|
||||
if (
|
||||
self.model_name
|
||||
and "-yi-" in self.model_name.lower()
|
||||
and parameters["temperature"] > 0.1
|
||||
):
|
||||
parameters["temperature"] = parameters["temperature"] / 2
|
||||
log.debug(
|
||||
"halfing temperature for -yi- model",
|
||||
temperature=parameters["temperature"],
|
||||
)
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
self.api_key = kwargs.get("api_key", self.api_key)
|
||||
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
|
||||
|
||||
def finalize_llama3(self, parameters: dict, prompt: str) -> tuple[str, bool]:
|
||||
|
||||
if "<|eot_id|>" not in prompt:
|
||||
return prompt, False
|
||||
|
||||
# llama3 instruct models need to add "<|eot_id|>", "<|end_of_text|>" to the stopping strings
|
||||
parameters["stopping_strings"] += ["<|eot_id|>", "<|end_of_text|>"]
|
||||
|
||||
# also needs to add `skip_special_tokens`= False to the parameters
|
||||
parameters["skip_special_tokens"] = False
|
||||
log.debug("finalizing llama3 instruct parameters", parameters=parameters)
|
||||
|
||||
if prompt.endswith("<|end_header_id|>"):
|
||||
# append two linebreaks
|
||||
prompt += "\n\n"
|
||||
log.debug("adjusting llama3 instruct prompt: missing linebreaks")
|
||||
|
||||
return prompt, True
|
||||
|
||||
def finalize_YI(self, parameters: dict, prompt: str) -> tuple[str, bool]:
|
||||
model_name = self.model_name.lower()
|
||||
# regex match for yi encased by non-word characters
|
||||
if not bool(re.search(r"[\-_]yi[\-_]", model_name)):
|
||||
return prompt, False
|
||||
|
||||
parameters["smoothing_factor"] = 0.1
|
||||
# also half the temperature
|
||||
parameters["temperature"] = max(0.1, parameters["temperature"] / 2)
|
||||
log.debug(
|
||||
"finalizing YI parameters",
|
||||
parameters=parameters,
|
||||
)
|
||||
return prompt, True
|
||||
|
||||
async def get_model_name(self):
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.api_url}/v1/internal/model/info", timeout=2
|
||||
f"{self.api_url}/v1/internal/model/info",
|
||||
timeout=2,
|
||||
headers=self.request_headers,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
raise Exception("Could not find model info (wrong api version?)")
|
||||
@@ -62,9 +113,6 @@ class TextGeneratorWebuiClient(ClientBase):
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
parameters["prompt"] = prompt.strip(" ")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
@@ -72,7 +120,7 @@ class TextGeneratorWebuiClient(ClientBase):
|
||||
f"{self.api_url}/v1/completions",
|
||||
json=parameters,
|
||||
timeout=None,
|
||||
headers=headers,
|
||||
headers=self.request_headers,
|
||||
)
|
||||
response_data = response.json()
|
||||
return response_data["choices"][0]["text"]
|
||||
@@ -92,3 +140,9 @@ class TextGeneratorWebuiClient(ClientBase):
|
||||
prompt_config["repetition_penalty"] = random.uniform(
|
||||
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if "api_key" in kwargs:
|
||||
self.api_key = kwargs.pop("api_key")
|
||||
|
||||
super().reconfigure(**kwargs)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .base import TalemateCommand
|
||||
from .cmd_autocomplete import *
|
||||
from .cmd_characters import *
|
||||
from .cmd_debug_tools import *
|
||||
from .cmd_dialogue import *
|
||||
@@ -10,6 +11,7 @@ from .cmd_inject import CmdInject
|
||||
from .cmd_list_scenes import CmdListScenes
|
||||
from .cmd_memget import CmdMemget
|
||||
from .cmd_memset import CmdMemset
|
||||
from .cmd_message_tools import *
|
||||
from .cmd_narrate import *
|
||||
from .cmd_rebuild_archive import CmdRebuildArchive
|
||||
from .cmd_remove_character import CmdRemoveCharacter
|
||||
|
||||
@@ -7,6 +7,8 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pydantic
|
||||
|
||||
from talemate.emit import Emitter, emit
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -21,17 +23,23 @@ class TalemateCommand(Emitter, ABC):
|
||||
manager: CommandManager = None
|
||||
label: str = None
|
||||
sets_scene_unsaved: bool = True
|
||||
argument_cls: pydantic.BaseModel | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.scene = manager.scene
|
||||
self.manager = manager
|
||||
self.args = args
|
||||
self.setup_emitter(self.scene)
|
||||
|
||||
if self.argument_cls is not None:
|
||||
self.args = self.argument_cls(**kwargs)
|
||||
else:
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def is_command(cls, name):
|
||||
return name == cls.name or name in cls.aliases
|
||||
|
||||
81
src/talemate/commands/cmd_autocomplete.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from talemate.agents.creator.assistant import ContentGenerationContext
|
||||
from talemate.commands.base import TalemateCommand
|
||||
from talemate.commands.manager import register
|
||||
from talemate.emit import emit
|
||||
|
||||
__all__ = ["CmdAutocompleteDialogue", "CmdAutocomplete"]
|
||||
|
||||
|
||||
@register
|
||||
class CmdAutocompleteDialogue(TalemateCommand):
|
||||
"""
|
||||
Command class for the 'autocomplete_dialogue' command
|
||||
"""
|
||||
|
||||
name = "autocomplete_dialogue"
|
||||
description = "Generate dialogue for an AI selected actor"
|
||||
aliases = ["acdlg"]
|
||||
|
||||
async def run(self):
|
||||
|
||||
input = self.args[0]
|
||||
if len(self.args) > 1:
|
||||
character_name = self.args[1]
|
||||
character = self.scene.get_character(character_name)
|
||||
else:
|
||||
character = self.scene.get_player_character()
|
||||
|
||||
creator = self.scene.get_helper("creator").agent
|
||||
|
||||
await creator.autocomplete_dialogue(input, character, emit_signal=True)
|
||||
|
||||
|
||||
@register
|
||||
class CmdAutocomplete(TalemateCommand):
|
||||
"""
|
||||
Command class for the 'autocomplete' command
|
||||
"""
|
||||
|
||||
name = "autocomplete"
|
||||
description = "Generate information for an AI selected actor"
|
||||
aliases = ["ac"]
|
||||
argument_cls = ContentGenerationContext
|
||||
|
||||
async def run(self):
|
||||
|
||||
try:
|
||||
creator = self.scene.get_helper("creator").agent
|
||||
context_type, context_name = self.args.computed_context
|
||||
|
||||
if context_type == "dialogue":
|
||||
|
||||
if not self.args.character:
|
||||
character = self.scene.get_player_character()
|
||||
else:
|
||||
character = self.scene.get_character(self.args.character)
|
||||
|
||||
self.scene.log.info(
|
||||
"Running autocomplete dialogue",
|
||||
partial=self.args.partial,
|
||||
character=character,
|
||||
)
|
||||
await creator.autocomplete_dialogue(
|
||||
self.args.partial, character, emit_signal=True
|
||||
)
|
||||
return
|
||||
|
||||
# force length to 35
|
||||
self.args.length = 35
|
||||
self.scene.log.info("Running autocomplete context", args=self.args)
|
||||
completion = await creator.contextual_generate(self.args)
|
||||
self.scene.log.info("Autocomplete context complete", completion=completion)
|
||||
completion = (
|
||||
completion.replace(f"{context_name}: {self.args.partial}", "")
|
||||
.lstrip(".")
|
||||
.strip()
|
||||
)
|
||||
|
||||
emit("autocomplete_suggestion", completion)
|
||||
except Exception as e:
|
||||
self.scene.log.error("Error running autocomplete", error=str(e))
|
||||
emit("autocomplete_suggestion", "")
|
||||
@@ -79,7 +79,7 @@ class CmdDeactivateCharacter(TalemateCommand):
|
||||
)
|
||||
message = await narrator.action_to_narration(
|
||||
"narrate_character_exit",
|
||||
self.scene.get_character(character_name),
|
||||
character=self.scene.get_character(character_name),
|
||||
direction=direction,
|
||||
)
|
||||
self.narrator_message(message)
|
||||
@@ -159,7 +159,7 @@ class CmdActivateCharacter(TalemateCommand):
|
||||
)
|
||||
message = await narrator.action_to_narration(
|
||||
"narrate_character_entry",
|
||||
self.scene.get_character(character_name),
|
||||
character=self.scene.get_character(character_name),
|
||||
direction=direction,
|
||||
)
|
||||
self.narrator_message(message)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import structlog
|
||||
|
||||
from talemate.commands.base import TalemateCommand
|
||||
from talemate.commands.manager import register
|
||||
from talemate.prompts.base import set_default_sectioning_handler
|
||||
@@ -12,6 +15,8 @@ __all__ = [
|
||||
"CmdRunAutomatic",
|
||||
]
|
||||
|
||||
log = structlog.get_logger("talemate.commands.cmd_debug_tools")
|
||||
|
||||
|
||||
@register
|
||||
class CmdDebugOn(TalemateCommand):
|
||||
@@ -144,3 +149,32 @@ class CmdSetContentContext(TalemateCommand):
|
||||
self.scene.context = context
|
||||
|
||||
self.emit("system", f"Content context set to {context}")
|
||||
|
||||
|
||||
@register
|
||||
class CmdDumpHistory(TalemateCommand):
|
||||
"""
|
||||
Command class for the 'dump_history' command
|
||||
"""
|
||||
|
||||
name = "dump_history"
|
||||
description = "Dump the history of the scene"
|
||||
aliases = []
|
||||
|
||||
async def run(self):
|
||||
for entry in self.scene.history:
|
||||
log.debug("dump_history", entry=entry)
|
||||
|
||||
|
||||
@register
|
||||
class CmdDumpSceneSerialization(TalemateCommand):
|
||||
"""
|
||||
Command class for the 'dump_scene_serialization' command
|
||||
"""
|
||||
|
||||
name = "dump_scene_serialization"
|
||||
description = "Dump the scene serialization"
|
||||
aliases = []
|
||||
|
||||
async def run(self):
|
||||
log.debug("dump_scene_serialization", serialization=self.scene.json)
|
||||
|
||||