Compare commits
10 commits (author and date columns were lost in extraction):

- abdfb1abbf
- 2f07248211
- 9ae6fc822b
- 5094359c4e
- 28801b54bf
- 4d69f0e837
- d91b3f8042
- 03a0ab2fcf
- d860d62972
- add4893939
README.md (141 changes)

@@ -1,35 +1,41 @@
# Talemate

-Allows you to play roleplay scenarios with large language models.
+Roleplay with AI with a focus on strong narration and consistent world and game state tracking.

[screenshot table omitted]

-> :warning: **It does not run any large language models itself but relies on existing APIs. Currently supports OpenAI, text-generation-webui and LMStudio. 0.18.0 also adds support for generic OpenAI api implementations, but generation quality on that will vary.**
+> :warning: **It does not run any large language models itself but relies on existing APIs. Currently supports OpenAI, Anthropic, mistral.ai, self-hosted text-generation-webui and LMStudio. 0.18.0 also adds support for generic OpenAI api implementations, but generation quality on that will vary.**

-This means you need to either have:
-- an [OpenAI](https://platform.openai.com/overview) api key
-- setup local (or remote via runpod) LLM inference via:
-  - [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui)
-  - [LMStudio](https://lmstudio.ai/)
-- Any other OpenAI api implementation that implements the v1/completions endpoint
-  - tested llamacpp with the `api_like_OAI.py` wrapper
-  - let me know if you have tested any other implementations and they failed / worked or landed somewhere in between
+Officially supported APIs:
+
+- [OpenAI](https://platform.openai.com/overview)
+- [Anthropic](https://www.anthropic.com/)
+- [mistral.ai](https://mistral.ai/)
+
+Officially supported self-hosted APIs:
+
+- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (local or with runpod support)
+- [LMStudio](https://lmstudio.ai/)
+
+Generic OpenAI api implementations (tested and confirmed working; a minimal client sketch follows the list):
+
+- [DeepInfra](https://deepinfra.com/) - see [instructions](https://github.com/vegu-ai/talemate/issues/78#issuecomment-1986884304)
+- [llamacpp](https://github.com/ggerganov/llama.cpp) with the `api_like_OAI.py` wrapper
+- let me know if you have tested any other implementations and they failed / worked or landed somewhere in between
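For reference, any server that exposes a `v1/completions` endpoint can be smoke-tested with the official `openai` package before pointing Talemate at it. A minimal sketch; the URL and model name below are placeholders, not Talemate configuration:

```python
# Placeholder base_url / model: substitute whatever your local server exposes.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5000/v1", api_key="none")

completion = client.completions.create(
    model="local-model",
    prompt="Once upon a time",
    max_tokens=64,
)
print(completion.choices[0].text)
```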
## Current features

-- responive modern ui
+- responsive modern ui
- agents
  - conversation: handles character dialogue
  - narration: handles narrative exposition
-  - summarization: handles summarization to compress context while maintain history
+  - summarization: handles summarization to compress context while maintaining history
  - director: can be used to direct the story / characters
  - editor: improves AI responses (very hit and miss at the moment)
  - world state: generates world snapshot and handles passage of time (objects and characters)
  - creator: character / scenario creator
-  - tts: text to speech via elevenlabs, coqui studio, coqui local
+  - tts: text to speech via elevenlabs, OpenAI or local tts
+  - visual: stable-diffusion client for in place visual generation via AUTOMATIC1111, ComfyUI or OpenAI
- multi-client support (agents can be connected to separate APIs)
- long term memory
  - chromadb integration
@@ -54,7 +60,6 @@ Kinda making it up as i go along, but i want to lean more into gameplay through

In no particular order:

- Extension support
- modular agents and clients
- Improved world state
@@ -68,7 +73,27 @@ In no particular order:

- objectives
  - quests
  - win / lose conditions
-- stable-diffusion client for in place visual generation

# Instructions

Please read the documents in the `docs` folder for more advanced configuration and usage.

- [Quickstart](#quickstart)
- [Installation](#installation)
- [Connecting to an LLM](#connecting-to-an-llm)
- [Text-generation-webui](#text-generation-webui)
- [Recommended Models](#recommended-models)
- [OpenAI / mistral.ai / Anthropic](#openai)
- [DeepInfra via OpenAI Compatible client](#deepinfra-via-openai-compatible-client)
- [Ready to go](#ready-to-go)
- [Load the introductory scenario "Infinity Quest"](#load-the-introductory-scenario-infinity-quest)
- [Loading character cards](#loading-character-cards)
- [Text-to-Speech (TTS)](docs/tts.md)
- [Visual Generation](docs/visual.md)
- [ChromaDB (long term memory) configuration](docs/chromadb.md)
- [Runpod Integration](docs/runpod.md)
- [Prompt template overrides](docs/templates.md)

# Quickstart
@@ -99,33 +124,67 @@ There is also a [troubleshooting guide](docs/troubleshoot.md) that might help.

1. Start the backend: `python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050`.
1. Open a new terminal, navigate to the `talemate_frontend` directory, and start the frontend server by running `npm run serve`.

-## Connecting to an LLM
+# Connecting to an LLM

On the right hand side click the "Add Client" button. If there is no button, you may need to toggle the client options by clicking this button:

[screenshot: toggle client options]
-### Text-generation-webui
+## Text-generation-webui

> :warning: As of version 0.13.0 the legacy text-generator-webui API `--extension api` is no longer supported, please use their new `--extension openai` api implementation instead.

In the modal if you're planning to connect to text-generation-webui, you can likely leave everything as is and just click Save.

[screenshot: text-gen-webui client setup]
-#### Recommended Models
+### Specifying the correct prompt template

-Any of the top models in any of the size classes here should work well (i wouldn't recommend going lower than 7B):
+For good results it is **vital** that the correct prompt template is specified for whichever model you have loaded.
+
+Talemate does come with a set of pre-defined templates for some popular models, but going forward, due to the sheer number of models released every day, understanding and specifying the correct prompt template is something you should familiarize yourself with.

If the text-gen-webui client shows a yellow triangle next to it, it means that the prompt template is not set, and it is currently using the default `VICUNA` style prompt template.

[screenshot: default prompt template warning]

Click the two cogwheels to the right of the triangle to open the client settings.

[screenshot: select prompt template]

You can first try clicking the `DETERMINE VIA HUGGINGFACE` button; depending on the model's README file, it may be able to determine the correct prompt template for you (the README needs to contain an example of the template).

If that doesn't work, you can manually select the prompt template from the dropdown.

In the case of `bartowski_Nous-Hermes-2-Mistral-7B-DPO-exl2_8_0` that is `ChatML` - select it from the dropdown and click `Save`.

[screenshot: selected prompt template]
### Recommended Models

As of 2024.03.07 my personal regular drivers (the ones i test with) are:

- Kunoichi-7B
- sparsetral-16x7B
- Nous-Hermes-2-Mistral-7B-DPO
- brucethemoose_Yi-34B-200K-RPMerge
- dolphin-2.7-mixtral-8x7b
- rAIfle_Verdict-8x7B
- Mixtral-8x7B-instruct

That said, any of the top models in any of the size classes here should work well (i wouldn't recommend going lower than 7B):

https://www.reddit.com/r/LocalLLaMA/comments/18yp9u4/llm_comparisontest_api_edition_gpt4_vs_gemini_vs/
-### OpenAI
+## OpenAI / mistral.ai / Anthropic
+
+The setup is the same for all three, the example below is for OpenAI.

If you want to add an OpenAI client, just change the client type and select the appropriate model.

[screenshot: OpenAI client setup]

If you are setting this up for the first time, you should now see the client, but it will have a red dot next to it, stating that it requires an API key.
@@ -133,17 +192,33 @@ If you are setting this up for the first time, you should now see the client, bu

Click the `SET API KEY` button. This will open a modal where you can enter your API key.

[screenshot: add API key modal]

Click `Save` and after a moment the client should have a green dot next to it, indicating that it is ready to go.

[screenshot: client ready]
## DeepInfra via OpenAI Compatible client

You can use the OpenAI compatible client to connect to [DeepInfra](https://deepinfra.com/).

[screenshot: DeepInfra client setup]

```
API URL: https://api.deepinfra.com/v1/openai
```

Models on DeepInfra that work well with Talemate (a connectivity sketch follows the list):

- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://deepinfra.com/mistralai/Mixtral-8x7B-Instruct-v0.1) (max context 32k, 8k recommended)
- [cognitivecomputations/dolphin-2.6-mixtral-8x7b](https://deepinfra.com/cognitivecomputations/dolphin-2.6-mixtral-8x7b) (max context 32k, 8k recommended)
- [lizpreciatior/lzlv_70b_fp16_hf](https://deepinfra.com/lizpreciatior/lzlv_70b_fp16_hf) (max context 4k)
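Before wiring DeepInfra into a client it can help to confirm the key and endpoint work at all. A minimal connectivity check using the `openai` package; this is a standalone sketch, not Talemate code:

```python
from openai import OpenAI

client = OpenAI(
    base_url="https://api.deepinfra.com/v1/openai",
    api_key="<YOUR_DEEPINFRA_API_KEY>",  # placeholder
)

response = client.chat.completions.create(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=32,
)
print(response.choices[0].message.content)
```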
## Ready to go

You will know you are good to go when the client and all the agents have a green dot next to them.

[screenshot: everything ready]
## Load the introductory scenario "Infinity Quest"

@@ -163,14 +238,4 @@ Expand the "Load" menu in the top left corner and either click on "Upload a char

Once a character is uploaded, talemate may actually take a moment because it needs to convert it to a talemate format and will also run additional LLM prompts to generate character attributes and world state.

Make sure you save the scene after the character is loaded as it can then be loaded as a normal talemate scenario in the future.

-## Further documentation
-
-Please read the documents in the `docs` folder for more advanced configuration and usage.
-
-- [Prompt template overrides](docs/templates.md)
-- [Text-to-Speech (TTS)](docs/tts.md)
-- [ChromaDB (long term memory)](docs/chromadb.md)
-- [Runpod Integration](docs/runpod.md)
-- Creative mode
@@ -48,6 +48,7 @@ game:

# embeddings: instructor
# instructor_device: cuda
# instructor_model: hkunlp/instructor-xl
+# openai_model: text-embedding-3-small

## Remote LLMs
@@ -56,6 +56,7 @@ Then add the following to `config.yaml` for chromadb:

```yaml
chromadb:
  embeddings: openai
  openai_model: text-embedding-3-small
```

**Note**: As with everything openai, using this isn't free. It's way cheaper than their text completion though. ALSO - if you send super explicit content they may flag / ban your key, so keep that in mind (i hear they usually send warnings first though), and always monitor your usage on their dashboard.
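Roughly, the config above maps onto chromadb's bundled OpenAI embedding helper. A standalone sketch of that mapping; this is an assumption about the wiring, not Talemate's actual code:

```python
import chromadb
from chromadb.utils import embedding_functions

# Same model name as in the config.yaml snippet above; the key is a placeholder.
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
    api_key="<YOUR_OPENAI_API_KEY>",
    model_name="text-embedding-3-small",
)

client = chromadb.Client()
collection = client.create_collection("demo", embedding_function=openai_ef)
collection.add(ids=["1"], documents=["The captain enters the bridge."])
print(collection.query(query_texts=["who is on the bridge?"], n_results=1))
```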
docs/dev/agents/example/test/__init__.py (new file, 48 lines)

@@ -0,0 +1,48 @@
```python
from talemate.agents.base import Agent, AgentAction
from talemate.agents.registry import register
from talemate.events import GameLoopEvent
import talemate.emit.async_signals
from talemate.emit import emit


@register()
class TestAgent(Agent):

    agent_type = "test"
    verbose_name = "Test"

    def __init__(self, client):
        self.client = client
        self.is_enabled = True
        # a single toggleable action, shown in the agent's config UI
        self.actions = {
            "test": AgentAction(
                enabled=True,
                label="Test",
                description="Test",
            ),
        }

    @property
    def enabled(self):
        return self.is_enabled

    @property
    def has_toggle(self):
        return True

    @property
    def experimental(self):
        return True

    def connect(self, scene):
        # subscribe to the game loop signal when the agent is attached to a scene
        super().connect(scene)
        talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)

    async def on_game_loop(self, emission: GameLoopEvent):
        """
        Called on the beginning of every game loop
        """

        if not self.enabled:
            return

        emit("status", status="info", message="Annoying you with a test message every game loop.")
```
docs/dev/client/example/runpod_vllm/__init__.py (new file, 130 lines)

@@ -0,0 +1,130 @@
"""
|
||||
An attempt to write a client against the runpod serverless vllm worker.
|
||||
|
||||
This is close to functional, but since runpod serverless gpu availability is currently terrible, i have
|
||||
been unable to properly test it.
|
||||
|
||||
Putting it here for now since i think it makes a decent example of how to write a client against a new service.
|
||||
"""
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
import runpod
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from talemate.client.base import ClientBase, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.emit import emit
|
||||
from talemate.config import Client as BaseClientConfig
|
||||
|
||||
log = structlog.get_logger("talemate.client.runpod_vllm")
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 4096
|
||||
model: str = ""
|
||||
runpod_id: str = ""
|
||||
|
||||
class ClientConfig(BaseClientConfig):
|
||||
runpod_id: str = ""
|
||||
|
||||
@register()
|
||||
class RunPodVLLMClient(ClientBase):
|
||||
client_type = "runpod_vllm"
|
||||
conversation_retries = 5
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
title: str = "Runpod VLLM"
|
||||
name_prefix: str = "Runpod VLLM"
|
||||
enable_api_auth: bool = True
|
||||
manual_model: bool = True
|
||||
defaults: Defaults = Defaults()
|
||||
extra_fields: dict[str, ExtraField] = {
|
||||
"runpod_id": ExtraField(
|
||||
name="runpod_id",
|
||||
type="text",
|
||||
label="Runpod ID",
|
||||
required=True,
|
||||
description="The Runpod ID to connect to.",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def __init__(self, model=None, runpod_id=None, **kwargs):
|
||||
self.model_name = model
|
||||
self.runpod_id = runpod_id
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def experimental(self):
|
||||
return False
|
||||
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
log.debug("set_client", kwargs=kwargs, runpod_id=self.runpod_id)
|
||||
self.runpod_id = kwargs.get("runpod_id", self.runpod_id)
|
||||
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
|
||||
keys = list(parameters.keys())
|
||||
|
||||
valid_keys = ["temperature", "top_p", "max_tokens"]
|
||||
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
async def get_model_name(self):
|
||||
return self.model_name
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
prompt = prompt.strip()
|
||||
|
||||
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
|
||||
|
||||
try:
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
endpoint = runpod.AsyncioEndpoint(self.runpod_id, session)
|
||||
|
||||
run_request = await endpoint.run({
|
||||
"input": {
|
||||
"prompt": prompt,
|
||||
}
|
||||
#"parameters": parameters
|
||||
})
|
||||
|
||||
while (await run_request.status()) not in ["COMPLETED", "FAILED", "CANCELLED"]:
|
||||
status = await run_request.status()
|
||||
log.debug("generate", status=status)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
status = await run_request.status()
|
||||
|
||||
log.debug("generate", status=status)
|
||||
|
||||
response = await run_request.output()
|
||||
|
||||
log.debug("generate", response=response)
|
||||
|
||||
return response["choices"][0]["tokens"][0]
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit(
|
||||
"status", message="Error during generation (check logs)", status="error"
|
||||
)
|
||||
return ""
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
if "runpod_id" in kwargs:
|
||||
self.api_auth = kwargs["runpod_id"]
|
||||
log.warning("reconfigure", kwargs=kwargs)
|
||||
self.set_client(**kwargs)
|
||||
docs/dev/client/example/test/__init__.py (new file, 67 lines)

@@ -0,0 +1,67 @@
```python
import pydantic
from openai import AsyncOpenAI

from talemate.client.base import ClientBase
from talemate.client.registry import register


class Defaults(pydantic.BaseModel):
    api_url: str = "http://localhost:1234"
    max_token_length: int = 4096


@register()
class TestClient(ClientBase):
    client_type = "test"

    class Meta(ClientBase.Meta):
        name_prefix: str = "test"
        title: str = "Test"
        defaults: Defaults = Defaults()

    def set_client(self, **kwargs):
        self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        """
        Talemate adds a bunch of parameters to the prompt, but not all of them are valid for all clients.

        This method is called before the prompt is sent to the client, and it allows the client to remove
        any parameters that it doesn't support.
        """

        super().tune_prompt_parameters(parameters, kind)

        keys = list(parameters.keys())

        valid_keys = ["temperature", "top_p"]

        for key in keys:
            if key not in valid_keys:
                del parameters[key]

    async def get_model_name(self):
        """
        This should return the name of the model that is being used.
        """

        return "Mock test model"

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
        """
        human_message = {"role": "user", "content": prompt.strip()}

        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

        try:
            # model_name is expected to be populated by ClientBase via get_model_name()
            response = await self.client.chat.completions.create(
                model=self.model_name, messages=[human_message], **parameters
            )

            return response.choices[0].message.content
        except Exception as e:
            self.log.error("generate error", e=e)
            return ""
```
New binary files (images):

- docs/img/0.19.0/Screenshot_15.png (418 KiB)
- docs/img/0.19.0/Screenshot_16.png (413 KiB)
- docs/img/0.19.0/Screenshot_17.png (364 KiB)
- docs/img/0.20.0/comfyui-base-workflow.png (128 KiB)
- docs/img/0.20.0/visual-config-a1111.png (32 KiB)
- docs/img/0.20.0/visual-config-comfyui.png (34 KiB)
- docs/img/0.20.0/visual-config-openai.png (30 KiB)
- docs/img/0.20.0/visual-queue.png (933 KiB)
- docs/img/0.20.0/visualize-scene-tools.png (13 KiB)
- docs/img/0.20.0/visualizer-busy.png (3.5 KiB)
- docs/img/0.20.0/visualizer-ready.png (2.9 KiB)
- docs/img/0.20.0/visualze-new-images.png (1.8 KiB)
- docs/img/0.21.0/deepinfra-setup.png (56 KiB)
- docs/img/0.21.0/no-clients.png (7.1 KiB)
- docs/img/0.21.0/openai-add-api-key.png (35 KiB)
- docs/img/0.21.0/openai-setup.png (20 KiB)
- docs/img/0.21.0/prompt-template-default.png (17 KiB)
- docs/img/0.21.0/ready-to-go.png (43 KiB)
- docs/img/0.21.0/select-prompt-template.png (47 KiB)
- docs/img/0.21.0/selected-prompt-template.png (49 KiB)
- docs/img/0.21.0/text-gen-webui-setup.png (26 KiB)
docs/tts.md (15 changes)
@@ -17,21 +17,6 @@
  api_key: <YOUR_ELEVENLABS_API_KEY>
```

-## Configuring Coqui TTS
-
-To use Coqui TTS with Talemate, follow these steps:
-
-1. Visit [Coqui](https://app.coqui.ai) and sign up for an account.
-2. Go to the [account page](https://app.coqui.ai/account) and scroll to the bottom to find your API key.
-3. In the `config.yaml` file, under the `coqui` section, set the `api_key` field with your Coqui API key.
-
-Example configuration snippet:
-
-```yaml
-coqui:
-  api_key: <YOUR_COQUI_API_KEY>
-```

## Configuring Local TTS API

For running a local TTS API, Talemate requires specific dependencies to be installed.
docs/visual.md (new file, 117 lines)

@@ -0,0 +1,117 @@
# Visual Agent

The visual agent currently allows for some bare bones visual generation using various stable-diffusion APIs. This is early development and experimental.

It's important to note that the visualization agent actually specifies two clients. One is the backend for the visual generation, and the other is the text generation client to use for prompt generation.

The client for prompt generation can be assigned to the agent as you would for any other agent. The client for visual generation is assigned in the Visualizer config.

## Index

- [OpenAI](#openai)
- [AUTOMATIC1111](#automatic1111)
- [ComfyUI](#comfyui)
- [How to use](#how-to-use)

## OpenAI

Most straightforward to use, as it runs on the OpenAI API. You will need to have an API key and set it in the application config.

[screenshot: application config]

Then open the Visualizer config by clicking the agent's name in the agent list and choose `OpenAI` as the backend.

[screenshot: visualizer config, OpenAI backend]

Note: `Client` here refers to the text-generation client to use for prompt generation, while `Backend` refers to the visual generation backend. You are **NOT** required to use the OpenAI client for prompt generation even if you are using the OpenAI backend for image generation.

## AUTOMATIC1111

This requires you to set up a local instance of the AUTOMATIC1111 API. Follow the instructions from their [GitHub](https://github.com/AUTOMATIC1111/stable-diffusion-webui) to get it running.

Once you have it running, you will want to adjust the `webui-user.bat` in the AUTOMATIC1111 directory to include the following command arguments:

```bat
set COMMANDLINE_ARGS=--api --listen --port 7861
```

Then run the `webui-user.bat` to start the API.

Once your AUTOMATIC1111 API is running (check with your browser) you can set the Visualizer config to use the `AUTOMATIC1111` backend. A raw API sketch follows the configuration list below.

[screenshot: visualizer config, AUTOMATIC1111 backend]

### Extra Configuration

- `api url`: the url of the API, usually `http://localhost:7861`
- `steps`: render steps
- `model type`: sdxl or sd1.5 - this will dictate the resolution of the image generation and actually matters for the quality, so make sure this is set to the correct model type for the model you are using.
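If you want to sanity-check the API outside of Talemate, a raw call against the same txt2img endpoint the backend talks to looks roughly like this (the prompt and resolution values are illustrative):

```python
import base64

import requests

payload = {
    "prompt": "portrait of a starship captain",
    "steps": 25,
    "width": 1024,   # sdxl-ish resolution; use 512x512 for sd1.5 models
    "height": 1024,
}

resp = requests.post("http://localhost:7861/sdapi/v1/txt2img", json=payload)
resp.raise_for_status()

# the API returns base64 encoded images
with open("out.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["images"][0]))
```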
## ComfyUI

This requires you to set up a local instance of the ComfyUI API. Follow the instructions from their [GitHub](https://github.com/comfyanonymous/ComfyUI) to get it running.

Once you're set up, copy their `start.bat` file to a new `start-listen.bat` file and change the contents to:

```bat
call venv\Scripts\activate
call python main.py --port 8188 --listen 0.0.0.0
```

Then run the `start-listen.bat` to start the API.

Once your ComfyUI API is running (check with your browser) you can set the Visualizer config to use the `ComfyUI` backend.

[screenshot: visualizer config, ComfyUI backend]

### Extra Configuration

- `api url`: the url of the API, usually `http://localhost:8188`
- `workflow`: the workflow file to use. This is a comfyui api workflow file that needs to exist in `./templates/comfyui-workflows` inside the talemate directory. Talemate provides two very barebones workflows with `default-sdxl.json` and `default-sd15.json`. You can create your own workflows and place them in this directory to use them. :warning: The workflow file must be generated using the API Workflow export, not the UI export. Please refer to their documentation for more information.
- `checkpoint`: the model to use - this will load a list of all available models in your comfyui instance. Select which one you want to use for the image generation.

### Custom Workflows

When creating custom workflows for ideal compatibility with Talemate, ensure the following (a patching sketch follows the list):

- A `CheckpointLoaderSimple` node named `Talemate Load Checkpoint`
- An `EmptyLatentImage` node named `Talemate Resolution`
- A `ClipTextEncode` node named `Talemate Positive Prompt`
- A `ClipTextEncode` node named `Talemate Negative Prompt`
- A `SaveImage` node at the end of the workflow.
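For a sense of why the node titles matter, here is a speculative sketch of locating and patching those nodes in an API-export workflow JSON, where node titles live under `_meta.title`. This mirrors the idea, not Talemate's exact implementation, and the prompt and checkpoint values are placeholders:

```python
import json

with open("templates/comfyui-workflows/default-sdxl.json") as f:
    workflow = json.load(f)

def find_node(workflow: dict, title: str) -> dict:
    # API-export workflows map node ids to {class_type, inputs, _meta}
    for node in workflow.values():
        if node.get("_meta", {}).get("title") == title:
            return node
    raise KeyError(f"workflow is missing a node titled {title!r}")

find_node(workflow, "Talemate Positive Prompt")["inputs"]["text"] = "a red moon"
find_node(workflow, "Talemate Negative Prompt")["inputs"]["text"] = "blurry, low quality"
find_node(workflow, "Talemate Load Checkpoint")["inputs"]["ckpt_name"] = "sdxl.safetensors"
```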
[screenshot: comfyui base workflow]

## How to use

Once you're done setting up, the visualizer agent should have a green dot next to it and display both the selected image generation backend and the selected prompt generation client.

[screenshot: visualizer ready]

Your hotbar should then also enable the visualization menu for you to use (once you have a scene loaded).

[screenshot: scene visualization tools]

Right now you can generate a portrait for any NPC in the scene or a background image for the scene itself.

Image generation by default will actually happen in the background, allowing you to continue using Talemate while the image is being generated.

You can tell if an image is being generated by the blueish spinner next to the visualization agent.

[screenshot: visualizer busy]

Once the image is generated, it will be available for you to view via the visual queue button on top of the screen.

[screenshot: new images indicator]

Click it to open the visual queue and view the generated images.

[screenshot: visual queue]

### Character Portrait

For character portraits you can choose whether or not to replace the main portrait for the character (the one being displayed in the left sidebar when a talemate scene is active).

### Background Image

Right now there is nothing to do with the background image, other than to view it in the visual queue. More functionality will be added in the future.
poetry.lock (generated; 2482 changed lines not shown)
pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"

 [tool.poetry]
 name = "talemate"
-version = "0.18.2"
+version = "0.21.0"
 description = "AI-backed roleplay and narrative tools"
 authors = ["FinalWombat"]
 license = "GNU Affero General Public License v3.0"

@@ -20,7 +20,7 @@ jinja2 = "^3.0"
 openai = ">=1"
 requests = "^2.26"
 colorama = ">=0.4.6"
-Pillow = "^9.5"
+Pillow = ">=9.5"
 httpx = "<1"
 piexif = "^1.1"
 typing-inspect = "0.8.0"

@@ -39,6 +39,7 @@ thefuzz = ">=0.20.0"
 tiktoken = ">=0.5.1"
 nltk = ">=3.8.1"
 huggingface-hub = ">=0.20.2"
+anthropic = ">=0.19.1"

 # ChromaDB
 chromadb = ">=0.4.17,<1"

[binary file: new scene cover image, 1.6 MiB]
@@ -98,6 +98,7 @@
       }
     ],
     "immutable_save": true,
+    "experimental": true,
     "goal": null,
     "goals": [],
     "context": "an epic sci-fi adventure aimed at an adult audience.",

@@ -109,10 +110,10 @@
     "variables": {}
   },
   "assets": {
-    "cover_image": "52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df",
+    "cover_image": "e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404",
     "assets": {
-      "52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df": {
-        "id": "52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df",
+      "e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404": {
+        "id": "e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404",
         "file_type": "png",
         "media_type": "image/png"
       }
@@ -5,7 +5,7 @@
 {%- set _ = emit_system("warning", "This is a dynamic scenario generation experiment for Infinity Quest. It will likely require a strong LLM to generate something coherent. GPT-4 or 34B+ if local. Temper your expectations.") -%}

 {#- emit status update to the UX -#}
-{%- set _ = emit_status("busy", "Generating scenario ... [1/3]") -%}
+{%- set _ = emit_status("busy", "Generating scenario ... [1/3]", as_scene_message=True) -%}

 {#- thematic tags will be used to randomize generation -#}
 {%- set tags = thematic_generator.generate("color", "state_of_matter", "scifi_trope") -%}

@@ -17,17 +17,17 @@

 {#- generate introductory text -#}
-{%- set _ = emit_status("busy", "Generating scenario ... [2/3]") -%}
+{%- set _ = emit_status("busy", "Generating scenario ... [2/3]", as_scene_message=True) -%}
 {%- set tmpl__scenario_intro = render_template('generate-scenario-intro', premise=instr__premise) %}
 {%- set instr__intro = "*"+render_and_request(tmpl__scenario_intro)+"*" -%}

 {#- generate win conditions -#}
-{%- set _ = emit_status("busy", "Generating scenario ... [3/3]") -%}
+{%- set _ = emit_status("busy", "Generating scenario ... [3/3]", as_scene_message=True) -%}
 {%- set tmpl__win_conditions = render_template('generate-win-conditions', premise=instr__premise) %}
 {%- set instr__win_conditions = render_and_request(tmpl__win_conditions) -%}

 {#- emit status update to the UX -#}
-{%- set status = emit_status("info", "Scenario ready.") -%}
+{%- set status = emit_status("success", "Scenario ready.", as_scene_message=True) -%}

 {# set gamestate variables #}
 {%- set _ = game_state.set_var("instr.premise", instr__premise, commit=True) -%}

[binary file: new scene cover image, 1.7 MiB]
scenes/simulation-suite/simulation-suite.json (new file, 52 lines)

@@ -0,0 +1,52 @@
```json
{
  "name": "Simulation Suite",
  "environment": "scene",
  "immutable_save": true,
  "restore_from": "simulation-suite.json",
  "experimental": true,
  "help": "Address the computer by starting your statements with 'Computer, ' followed by an instruction.\n\nExamples:\n'Computer, i would like to experience an adventure on a derelict space station'\n'Computer, add a horrific alien creature that is chasing me.'",
  "description": "",
  "intro": "*You have entered the simulation suite. No simulation is currently active and you are in a non-descript space with paneled walls surrounding you. The control panel next to you is pulsating with a green light, indicating readiness to receive a prompt to start the simulation.*",
  "archived_history": [],
  "history": [],
  "ts": "PT1S",
  "characters": [
    {
      "name": "You",
      "gender": "unknown",
      "color": "cornflowerblue",
      "base_attributes": {},
      "is_player": true
    }
  ],
  "context": "a simulated experience",
  "game_state": {
    "ops": {
      "run_on_start": true,
      "always_direct": true
    },
    "variables": {}
  },
  "world_state": {
    "character_name_mappings": {
      "You": [
        "user",
        "player",
        "player character",
        "user character",
        "the user",
        "the player"
      ]
    }
  },
  "assets": {
    "cover_image": "4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103",
    "assets": {
      "4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103": {
        "id": "4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103",
        "file_type": "png",
        "media_type": "image/png"
      }
    }
  }
}
```
scenes/simulation-suite/templates/computer.jinja2 (new file, 118 lines)

@@ -0,0 +1,118 @@
<|SECTION:CONTEXT|>
{% set scene_history=scene.context_history(budget=1024) %}
{% for scene_context in scene_history -%}
{{ loop.index }}. {{ scene_context }}
{% endfor %}
<|CLOSE_SECTION|>
<|SECTION:FUNCTIONS|>
The player has instructed the computer to alter the current simulation.

You have access to the following functions, you can call as many as you want to fulfill the player's requests.

You must at least call one of the following functions:

- change_environment
- add_ai_character
- change_ai_character
- remove_ai_character
- set_player_persona
- set_player_name
- end_simulation
- answer_question

`add_ai_character` and `change_ai_character` are exclusive if they are targeting the same character.

Set the player persona at the beginning of a new simulation or if the player requests a change.

Only end the simulation if the player requests it explicitly.
<|CLOSE_SECTION|>
<|SECTION:EXAMPLES|>
Request: Computer, I want to be on a mountain top
```simulation-stack
change_environment("mountain top")
set_player_persona("mountain climber")
set_player_name("Hank")
```

Request: Computer, I want to be more muscular and taller
```simulation-stack
set_player_persona("make player more muscular and taller")
```

Request: Computer, the building should be on fire
```simulation-stack
change_environment("building on fire")
```

Request: Computer, a rocket hits the building and George is now injured
```simulation-stack
change_environment("building on fire")
change_ai_character("George is injured")
```

Request: Computer, I want to experience a rollercoaster ride with a friend
```simulation-stack
change_environment("theme park, riding a rollercoaster")
set_player_persona("young female experiencing rollercoaster ride")
set_player_name("Susanne")
add_ai_character("a female friend of player named Sarah")
```

Request: Computer, I want to experience the international space station
```simulation-stack
change_environment("international space station")
set_player_persona("astronaut experiencing first trip to ISS")
set_player_name("George")
add_ai_character("astronaut named Henry")
```

Request: Computer, remove the goblin and add an elven woman instead
```simulation-stack
remove_ai_character("goblin")
add_ai_character("elven woman named Elune")
```

Request: Computer, change the skiing instructor to be older.
```simulation-stack
change_ai_character("make skiing instructor older")
```

Request: Computer, change my grandma to my grandpa
```simulation-stack
remove_ai_character("grandma")
add_ai_character("grandpa named Steven")
```

Request: Computer, remove the skiing instructor and add my friend instead.
```simulation-stack
remove_ai_character("skiing instructor")
add_ai_character("player's friend named Tara")
```

Request: Computer, replace the skiing instructor with my friend.
```simulation-stack
remove_ai_character("skiing instructor")
add_ai_character("player's friend named Lisa")
```

Request: Computer, I want to end the simulation
```simulation-stack
end_simulation("simulation ended")
```

Request: Computer, shut down the simulation
```simulation-stack
end_simulation("simulation ended")
```

Request: Computer, what do you know about the game of thrones?
```simulation-stack
answer_question("what do you know about the game of thrones?")
```

<|CLOSE_SECTION|>
<|SECTION:TASK|>
Respond with the simulation stack for the following request:

Request: {{ player_instruction }}
{{ bot_token }}```simulation-stack
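The template asks the model for a fenced `simulation-stack` block of function-style calls. Illustratively, those lines can be pulled apart with a small regex like the sketch below; the companion `instructions.jinja2` template that follows instead dispatches on `call.strip().startswith(...)`, so this is a restatement of the idea, not the actual parser:

```python
import re

CALL_RE = re.compile(r'^(?P<fn>\w+)\("(?P<arg>.*)"\)$')

def parse_simulation_stack(text: str) -> list[tuple[str, str]]:
    """Extract (function, argument) pairs from a simulation-stack response."""
    calls = []
    for line in text.splitlines():
        match = CALL_RE.match(line.strip())
        if match:
            calls.append((match.group("fn"), match.group("arg")))
    return calls

stack = 'change_environment("mountain top")\nset_player_name("Hank")'
assert parse_simulation_stack(stack) == [
    ("change_environment", "mountain top"),
    ("set_player_name", "Hank"),
]
```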
scenes/simulation-suite/templates/instructions.jinja2 (new file, 177 lines)

@@ -0,0 +1,177 @@
{% set update_world_state = False %}
{% set _ = debug("HOLODECK SIMULATION") -%}
{% set player_character = scene.get_player_character() %}
{% set player_message = scene.last_player_message() %}
{% set last_processed = game_state.get_var('instr.last_processed', -1) %}
{% set player_message_is_instruction = (player_message and player_message.raw.lower().startswith("computer") and not player_message.hidden) and not player_message.raw.lower().strip() == "computer" and not last_processed >= player_message.id %}
{% set simulation_reset = False %}
{% if not game_state.has_var('instr.simulation_stopped') %}
    {# simulation NOT started #}

    {# get last player instruction #}
    {% if player_message_is_instruction %}
        {# player message exists #}

        {#% set _ = agent_action("narrator", "action_to_narration", action_name="paraphrase", narration="The computer is processing the request, please wait a moment.", emit_message=True) %#}

        {% set calls = render_and_request(render_template("computer", player_instruction=player_message.raw), dedupe_enabled=False) %}

        {% set _ = debug("HOLODECK simulation calls", calls=calls) %}
        {% set processed = make_list() %}

        {% for call in calls.split("\n") %}
            {% set _ = debug("CALL", call=call, processed=processed) %}
            {% set inject = "The computer executes the function `"+call+"`" %}
            {% if call.strip().startswith('change_environment') %}
                {# change environment #}
                {% set _ = processed.append(call) %}
            {% elif call.strip().startswith("answer_question") %}
                {# answer a query #}
                {% set _ = agent_action("narrator", "action_to_narration", action_name="progress_story", narrative_direction="The computer calls the following function:\n"+call+"\nand answers the player's question.", emit_message=True) %}
            {% elif call.strip().startswith("set_player_persona") %}
                {# transform player #}
                {% set _ = emit_status("busy", "Simulation suite altering user persona.", as_scene_message=True) %}
                {% set character_attributes = agent_action("world_state", "extract_character_sheet", name=player_character.name, text=player_message.raw) %}
                {% set _ = player_character.update(base_attributes=character_attributes) %}
                {% set character_description = agent_action("creator", "determine_character_description", character=player_character) %}
                {% set _ = player_character.update(description=character_description) %}
                {% set _ = debug("HOLODECK transform player", attributes=character_attributes, description=character_description) %}
                {% set _ = processed.append(call) %}
            {% elif call.strip().startswith("set_player_name") %}
                {# change player name #}
                {% set _ = emit_status("busy", "Simulation suite adjusting user identity.", as_scene_message=True) %}
                {% set character_name = agent_action("creator", "determine_character_name", character_name=inject+" - What is a fitting name for the player persona? Respond with the current name if it still fits.") %}
                {% set _ = debug("HOLODECK player name", character_name=character_name) %}
                {% if character_name != player_character.name %}
                    {% set _ = processed.append(call) %}
                    {% set _ = player_character.rename(character_name) %}
                {% endif %}
            {% elif call.strip().startswith("add_ai_character") %}
                {# add new npc #}
                {% set _ = emit_status("busy", "Simulation suite adding character.", as_scene_message=True) %}
                {% set character_name = agent_action("creator", "determine_character_name", character_name=inject+" - what is the name of the character to be added to the scene? If no name can be extracted from the text, extract a short descriptive name instead. Respond only with the name.") %}
                {% set _ = emit_status("busy", "Simulation suite adding character: "+character_name, as_scene_message=True) %}
                {% set _ = debug("HOLODECK add npc", name=character_name) %}
                {% set npc = agent_action("director", "persist_character", name=character_name, content=player_message.raw) %}
                {% set _ = agent_action("world_state", "manager", action_name="add_detail_reinforcement", character_name=npc.name, question="Goal", instructions="Generate a goal for "+npc.name+", based on the user's chosen simulation", interval=25, run_immediately=True) %}
                {% set _ = debug("HOLODECK added npc", npc=npc) %}
                {% set _ = processed.append(call) %}
                {% set _ = agent_action("visual", "generate_character_portrait", character_name=npc.name) %}
            {% elif call.strip().startswith("remove_ai_character") %}
                {# remove npc #}
                {% set _ = emit_status("busy", "Simulation suite removing character.", as_scene_message=True) %}
                {% set character_name = agent_action("creator", "determine_character_name", character_name=inject+" - what is the name of the character being removed?", allowed_names=scene.npc_character_names) %}
                {% set npc = scene.get_character(character_name) %}
                {% if npc %}
                    {% set _ = debug("HOLODECK remove npc", npc=npc.name) %}
                    {% set _ = agent_action("world_state", "manager", action_name="deactivate_character", character_name=npc.name) %}
                    {% set _ = processed.append(call) %}
                {% endif %}
            {% elif call.strip().startswith("change_ai_character") %}
                {# change existing npc #}
                {% set _ = emit_status("busy", "Simulation suite altering character.", as_scene_message=True) %}
                {% set character_name = agent_action("creator", "determine_character_name", character_name=inject+" - what is the name of the character receiving the changes (before the change)?", allowed_names=scene.npc_character_names) %}
                {% set character_name_after = agent_action("creator", "determine_character_name", character_name=inject+" - what is the name of the character receiving the changes (after the changes)?") %}
                {% set npc = scene.get_character(character_name) %}
                {% if npc %}
                    {% set _ = emit_status("busy", "Changing "+character_name+" -> "+character_name_after, as_scene_message=True) %}
                    {% set _ = debug("HOLODECK transform npc", npc=npc) %}
                    {% set character_attributes = agent_action("world_state", "extract_character_sheet", name=npc.name, alteration_instructions=player_message.raw) %}
                    {% set _ = npc.update(base_attributes=character_attributes) %}
                    {% set character_description = agent_action("creator", "determine_character_description", character=npc) %}
                    {% set _ = npc.update(description=character_description) %}
                    {% set _ = debug("HOLODECK transform npc", attributes=character_attributes, description=character_description) %}
                    {% set _ = processed.append(call) %}
                    {% if character_name_after != character_name %}
                        {% set _ = npc.rename(character_name_after) %}
                    {% endif %}
                {% endif %}
            {% elif call.strip().startswith("end_simulation") %}
                {# end simulation #}
                {% set explicit_command = query_text_eval("has the player explicitly asked to end the simulation?", player_message.raw) %}
                {% if explicit_command %}
                    {% set _ = emit_status("busy", "Simulation suite ending current simulation.", as_scene_message=True) %}
                    {% set _ = agent_action("narrator", "action_to_narration", action_name="progress_story", narrative_direction="The computer ends the simulation, dissolving the environment and all artificial characters, erasing all memory of it and finally returning the player to the inactive simulation suite. List of artificial characters: "+(",".join(scene.npc_character_names))+". The player is also transformed back to their normal persona.", emit_message=True) %}
                    {% set _ = scene.sync_restore() %}
                    {% set _ = agent_action("world_state", "update_world_state", force=True) %}
                    {% set simulation_reset = True %}
                {% endif %}
            {% elif "(" in call.strip() %}
                {# unknown function call, still add it to processed stack so it can be incorporated in the narration #}
                {% set _ = processed.append(call) %}
            {% endif %}
        {% endfor %}

        {% if processed and not simulation_reset %}
            {% set _ = game_state.set_var("instr.has_issued_instructions", "yes", commit=False) %}
            {% set _ = emit_status("busy", "Simulation suite altering environment.", as_scene_message=True) %}
            {% set update_world_state = True %}
            {# note: the original used processed.join("\n"), which fails at render time since lists have no join method; the join filter is the working equivalent #}
            {% set _ = agent_action("narrator", "action_to_narration", action_name="progress_story", narrative_direction="The computer calls the following functions:\n"+(processed | join("\n"))+"\nand the simulation adjusts the environment according to the user's wishes.\n\nWrite the narrative that describes the changes to the player in the context of the simulation starting up.", emit_message=True) %}
        {% endif %}

    {% elif not game_state.has_var("instr.simulation_started") %}
        {# no player message yet, start of scenario #}
        {% set _ = emit_status("busy", "Simulation suite powering up.", as_scene_message=True) %}
        {% set _ = game_state.set_var("instr.simulation_started", "yes", commit=False) %}
        {% set _ = agent_action("narrator", "action_to_narration", action_name="progress_story", narrative_direction="Narrate the computer asking the user to state the nature of their desired simulation.", emit_message=False) %}
        {% set _ = agent_action("narrator", "action_to_narration", action_name="passthrough", narration="Please state your commands by addressing the computer by stating \"Computer,\" followed by an instruction.") %}

        {# pin to make sure characters don't try to interact with the simulation #}
        {% set _ = agent_action("world_state", "manager", action_name="save_world_entry", entry_id="sim.quarantined", text="Characters in the simulation ARE NOT AWARE OF THE COMPUTER.", meta=make_dict(), pin=True) %}

        {% set _ = emit_status("success", "Simulation suite ready", as_scene_message=True) %}
    {% endif %}

{% else %}
    {# simulation ongoing #}
{% endif %}

{% if update_world_state %}
    {% set _ = emit_status("busy", "Simulation suite updating world state.", as_scene_message=True) %}
    {% set _ = agent_action("world_state", "update_world_state", force=True) %}
{% endif %}

{% if not scene.npc_character_names and not simulation_reset %}
    {# no characters in the scene, see if there are any to add #}
    {% set npcs = agent_action("director", "persist_characters_from_worldstate", exclude=["computer", "user", "player", "you"]) %}
    {% for npc in npcs %}
        {% set _ = agent_action("world_state", "manager", action_name="add_detail_reinforcement", character_name=npc.name, question="Goal", instructions="Generate a goal for the character, based on the user's chosen simulation", interval=25, run_immediately=True) %}
    {% endfor %}
    {% if npcs %}
        {% set _ = agent_action("world_state", "update_world_state", force=True) %}
    {% endif %}
{% endif %}

{% if player_message_is_instruction %}
    {# hide player message to the computer, so its not included in the scene context #}
    {% set _ = player_message.hide() %}
    {% set _ = game_state.set_var("instr.last_processed", player_message.id, commit=False) %}
    {% set _ = emit_status("success", "Simulation suite processed instructions", as_scene_message=True) %}
{% elif player_message and not game_state.has_var("instr.has_issued_instructions") %}
    {# simulation not started, but player message is not an instruction #}
    {% set _ = agent_action("narrator", "action_to_narration", action_name="paraphrase", narration="Instructions to the simulation computer are only processed if the computer is addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating \"Computer,\" followed by an instruction. For example ... \"Computer, i want to experience being on a derelict spaceship.\"", emit_message=True) %}
{% elif player_message and not scene.npc_character_names %}
    {# simulation started, player message is NOT an instruction, but there are no npcs to interact with #}
    {% set _ = agent_action("narrator", "action_to_narration", action_name="progress_story", narrative_direction="The environment reacts to the player's actions. YOU MUST NOT ACT ON BEHALF OF THE PLAYER. YOU MUST NOT INTERACT WITH THE COMPUTER.", emit_message=True) %}
{% endif %}
@@ -2,4 +2,4 @@ from .agents import Agent
 from .client import TextGeneratorWebuiClient
 from .tale_mate import *

-VERSION = "0.18.2"
+VERSION = "0.21.0"
@@ -8,4 +8,5 @@ from .narrator import NarratorAgent
 from .registry import AGENT_CLASSES, get_agent_class, register
 from .summarize import SummarizeAgent
 from .tts import TTSAgent
+from .visual import VisualAgent
 from .world_state import WorldStateAgent
@@ -4,6 +4,7 @@ import asyncio
 import dataclasses
 import re
 from abc import ABC
 from functools import wraps
 from typing import TYPE_CHECKING, Callable, List, Optional, Union

 import pydantic
@@ -19,6 +20,11 @@ from talemate.events import GameLoopStartEvent

 __all__ = [
     "Agent",
+    "AgentAction",
+    "AgentActionConditional",
+    "AgentActionConfig",
+    "AgentDetail",
+    "AgentEmission",
     "set_processing",
 ]
@@ -42,11 +48,24 @@ class AgentActionConfig(pydantic.BaseModel):
         arbitrary_types_allowed = True


+class AgentActionConditional(pydantic.BaseModel):
+    attribute: str
+    value: Union[int, float, str, bool, None] = None
+
+
 class AgentAction(pydantic.BaseModel):
     enabled: bool = True
     label: str
     description: str = ""
     config: Union[dict[str, AgentActionConfig], None] = None
+    condition: Union[AgentActionConditional, None] = None
+
+
+class AgentDetail(pydantic.BaseModel):
+    value: Union[str, None] = None
+    description: Union[str, None] = None
+    icon: Union[str, None] = None
+    color: str = "grey"


 def set_processing(fn):
@@ -58,6 +77,7 @@ def set_processing(fn):
     the function fails.
     """

+    @wraps(fn)
     async def wrapper(self, *args, **kwargs):
         with ActiveAgent(self, fn):
             try:

@@ -71,8 +91,6 @@ def set_processing(fn):
             # some concurrency error?
             log.error("error emitting agent status", exc=exc)

-    wrapper.__name__ = fn.__name__
-
     return wrapper
@@ -86,6 +104,9 @@ class Agent(ABC):
     set_processing = set_processing
     requires_llm_client = True
     auto_break_repetition = False
+    websocket_handler = None
+    essential = True
+    ready_check_error = None

     @property
     def agent_details(self):
@@ -110,13 +131,20 @@ class Agent(ABC):

     @property
     def status(self):
-        if self.ready:
-            if not self.enabled:
-                return "disabled"
-            return "idle" if getattr(self, "processing", 0) == 0 else "busy"
-        else:
+        if not self.enabled:
+            return "disabled"
+
+        if not self.ready:
+            return "uninitialized"
+
+        if getattr(self, "processing", 0) > 0:
+            return "busy"
+
+        if getattr(self, "processing_bg", 0) > 0:
+            return "busy_bg"
+
+        return "idle"

     @property
     def enabled(self):
         # by default, agents are enabled, an agent class that
@@ -160,7 +188,41 @@ class Agent(ABC):

         return config_options

-    def apply_config(self, *args, **kwargs):
+    @property
+    def meta(self):
+        return {
+            "essential": self.essential,
+        }
+
+    async def _handle_ready_check(self, fut: asyncio.Future):
+        callback_failure = getattr(self, "on_ready_check_failure", None)
+        if fut.cancelled():
+            if callback_failure:
+                await callback_failure()
+            return
+
+        if fut.exception():
+            exc = fut.exception()
+            self.ready_check_error = exc
+            log.error("agent ready check error", agent=self.agent_type, exc=exc)
+            if callback_failure:
+                await callback_failure(exc)
+            return
+
+        callback = getattr(self, "on_ready_check_success", None)
+        if callback:
+            await callback()
+
+    async def ready_check(self, task: asyncio.Task = None):
+        self.ready_check_error = None
+        if task:
+            task.add_done_callback(
+                lambda fut: asyncio.create_task(self._handle_ready_check(fut))
+            )
+            return
+        return True
+
+    async def apply_config(self, *args, **kwargs):
         if self.has_toggle and "enabled" in kwargs:
             self.is_enabled = kwargs.get("enabled", False)
@@ -228,27 +290,55 @@ class Agent(ABC):
|
||||
if getattr(self, "processing", None) is None:
|
||||
self.processing = 0
|
||||
|
||||
if not processing:
|
||||
if processing is False:
|
||||
self.processing -= 1
|
||||
self.processing = max(0, self.processing)
|
||||
else:
|
||||
elif processing is True:
|
||||
self.processing += 1
|
||||
|
||||
status = "busy" if self.processing > 0 else "idle"
|
||||
if not self.enabled:
|
||||
status = "disabled"
|
||||
|
||||
emit(
|
||||
"agent_status",
|
||||
message=self.verbose_name or "",
|
||||
id=self.agent_type,
|
||||
status=status,
|
||||
status=self.status,
|
||||
details=self.agent_details,
|
||||
meta=self.meta,
|
||||
data=self.config_options(agent=self),
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def _handle_background_processing(self, fut: asyncio.Future):
|
||||
try:
|
||||
if fut.cancelled():
|
||||
return
|
||||
|
||||
if fut.exception():
|
||||
log.error(
|
||||
"background processing error",
|
||||
agent=self.agent_type,
|
||||
exc=fut.exception(),
|
||||
)
|
||||
await self.emit_status()
|
||||
return
|
||||
|
||||
log.info("background processing done", agent=self.agent_type)
|
||||
finally:
|
||||
self.processing_bg -= 1
|
||||
await self.emit_status()
|
||||
|
||||
async def set_background_processing(self, task: asyncio.Task):
|
||||
log.info("set_background_processing", agent=self.agent_type)
|
||||
if not hasattr(self, "processing_bg"):
|
||||
self.processing_bg = 0
|
||||
|
||||
self.processing_bg += 1
|
||||
|
||||
await self.emit_status()
|
||||
task.add_done_callback(
|
||||
lambda fut: asyncio.create_task(self._handle_background_processing(fut))
|
||||
)
|
||||
|
||||
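`set_background_processing` is the counterpart for long-running work: the caller wraps a coroutine in a task, the base class bumps `processing_bg` (surfacing as the `busy_bg` status), and `_handle_background_processing` unwinds it when the task settles. A sketch of the intended call pattern, mirroring how the TTS and visual agents later in this diff use it (`render_everything` is an assumed coroutine name):

```python
# Sketch: hand a long-running generation task to the base class.
task = asyncio.create_task(self.render_everything())
await self.set_background_processing(task)
# When the task finishes, fails, or is cancelled, processing_bg is
# decremented and the agent status is re-emitted automatically.
```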
    def connect(self, scene):
        self.scene = scene
        talemate.emit.async_signals.get("game_loop_start").connect(

@@ -13,6 +13,7 @@ active_agent = contextvars.ContextVar("active_agent", default=None)

class ActiveAgentContext(pydantic.BaseModel):
    agent: object
    fn: Callable
    agent_stack: list = pydantic.Field(default_factory=list)

    class Config:
        arbitrary_types_allowed = True

@@ -21,12 +22,23 @@ class ActiveAgentContext(pydantic.BaseModel):

    def action(self):
        return self.fn.__name__

    def __str__(self):
        return f"{self.agent.verbose_name}.{self.action}"


class ActiveAgent:
    def __init__(self, agent, fn):
        self.agent = ActiveAgentContext(agent=agent, fn=fn)

    def __enter__(self):

        previous_agent = active_agent.get()

        if previous_agent:
            self.agent.agent_stack = previous_agent.agent_stack + [str(self.agent)]
        else:
            self.agent.agent_stack = [str(self.agent)]

        self.token = active_agent.set(self.agent)

    def __exit__(self, *args, **kwargs):
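`ActiveAgent` threads a stack of `"Agent.action"` labels through the `active_agent` context variable, so nested agent calls can be traced. An illustration of the nesting behavior (the agent instances and function names are invented):

```python
# Sketch: nested ActiveAgent contexts accumulate a readable call stack.
with ActiveAgent(director, direct_scene):
    # active_agent.get().agent_stack == ["Director.direct_scene"]
    with ActiveAgent(narrator, progress_story):
        # ["Director.direct_scene", "Narrator.progress_story"]
        ...
```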
@@ -78,14 +78,23 @@ class ConversationAgent(Agent):

        self.actions = {
            "generation_override": AgentAction(
                enabled=True,
                label="Generation Override",
                description="Override generation parameters",
                label="Generation Settings",
                config={
                    "format": AgentActionConfig(
                        type="text",
                        label="Format",
                        description="The format of the dialogue, as seen by the AI.",
                        choices=[
                            {"label": "Movie Script", "value": "movie_script"},
                            {"label": "Chat (legacy)", "value": "chat"},
                        ],
                        value="chat",
                    ),
                    "length": AgentActionConfig(
                        type="number",
                        label="Generation Length (tokens)",
                        description="Maximum number of tokens to generate for a conversation response.",
                        value=96,
                        value=128,
                        min=32,
                        max=512,
                        step=32,

@@ -166,6 +175,12 @@ class ConversationAgent(Agent):

            ),
        }

    @property
    def conversation_format(self):
        if self.actions["generation_override"].enabled:
            return self.actions["generation_override"].config["format"].value
        return "movie_script"

    def connect(self, scene):
        super().connect(scene)
        talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)

@@ -605,14 +620,20 @@ class ConversationAgent(Agent):

        result = result.replace(" :", ":")

        total_result = total_result.split("#")[0]
        total_result = total_result.split("#")[0].strip()

        # movie script format
        # {uppercase character name}
        # {dialogue}
        total_result = total_result.replace(f"{character.name.upper()}\n", f"")

        # chat format
        # {character name}: {dialogue}
        total_result = total_result.replace(f"{character.name}:", "")

        # Removes partial sentence at the end
        total_result = util.clean_dialogue(total_result, main_name=character.name)

        # Remove "{character.name}:" - all occurrences
        total_result = total_result.replace(f"{character.name}:", "")

        # Check if total_result starts with character name, if not, prepend it
        if not total_result.startswith(character.name):
            total_result = f"{character.name}: {total_result}"

@@ -660,4 +681,4 @@ class ConversationAgent(Agent):

    ):
        if prompt_param.get("extra_stopping_strings") is None:
            prompt_param["extra_stopping_strings"] = []
        prompt_param["extra_stopping_strings"] += ["["]
        prompt_param["extra_stopping_strings"] += ["#"]

@@ -9,13 +9,13 @@ from talemate.agents.registry import register

from talemate.emit import emit
from talemate.prompts import Prompt

from .assistant import AssistantMixin
from .character import CharacterCreatorMixin
from .scenario import ScenarioCreatorMixin


@register()
class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):

class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, AssistantMixin, Agent):
    """
    Creates characters and scenarios and other fun stuff!
    """

95 src/talemate/agents/creator/assistant.py (new file)
@@ -0,0 +1,95 @@

from typing import TYPE_CHECKING, Union

import pydantic

import talemate.util as util
from talemate.agents.base import set_processing
from talemate.prompts import Prompt

if TYPE_CHECKING:
    from talemate.tale_mate import Character, Scene


class ContentGenerationContext(pydantic.BaseModel):
    """
    A context for generating content.
    """

    context: str
    instructions: str
    length: int
    character: Union[str, None] = None
    original: Union[str, None] = None

    @property
    def computed_context(self) -> (str, str):
        typ, context = self.context.split(":", 1)
        return typ, context


class AssistantMixin:
    """
    Creator mixin that allows quick contextual generation of content.
    """

    async def contextual_generate_from_args(
        self,
        context: str,
        instructions: str,
        length: int = 100,
        character: Union[str, None] = None,
        original: Union[str, None] = None,
    ):
        """
        Request content from the assistant.
        """

        generation_context = ContentGenerationContext(
            context=context,
            instructions=instructions,
            length=length,
            character=character,
            original=original,
        )

        return await self.contextual_generate(generation_context)

    @set_processing
    async def contextual_generate(
        self,
        generation_context: ContentGenerationContext,
    ):
        """
        Request content from the assistant.
        """

        context_typ, context_name = generation_context.computed_context

        if generation_context.length < 100:
            kind = "create_short"
        elif generation_context.length < 500:
            kind = "create_concise"
        else:
            kind = "create"

        content = await Prompt.request(
            f"creator.contextual-generate",
            self.client,
            kind,
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "generation_context": generation_context,
                "context_typ": context_typ,
                "context_name": context_name,
                "character": (
                    self.scene.get_character(generation_context.character)
                    if generation_context.character
                    else None
                ),
            },
        )

        content = util.strip_partial_sentences(content)

        return content.strip()
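The mixin exposes generation in two layers: `contextual_generate_from_args` packs a `"type:name"` context string into a `ContentGenerationContext`, and `contextual_generate` picks a prompt kind (`create_short` / `create_concise` / `create`) from the requested length. A minimal usage sketch (agent wiring assumed; the character name is invented):

```python
# Sketch: asking the creator agent's assistant for a character attribute.
# "character attribute:appearance" splits into typ="character attribute",
# name="appearance" via ContentGenerationContext.computed_context.
creator = instance.get_agent("creator")
text = await creator.contextual_generate_from_args(
    context="character attribute:appearance",
    instructions="Lean into the noir tone of the scene.",
    length=100,          # < 500, so the "create_concise" prompt kind is used
    character="Elara",   # hypothetical character name
)
```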
@@ -208,6 +208,25 @@ class CharacterCreatorMixin:

        )
        return attributes

    @set_processing
    async def determine_character_name(
        self,
        character_name: str,
        allowed_names: list[str] = None,
    ) -> str:
        name = await Prompt.request(
            f"creator.determine-character-name",
            self.client,
            "analyze_freeform_short",
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "character_name": character_name,
                "allowed_names": allowed_names or [],
            },
        )
        return name.split('"', 1)[0].strip().strip(".").strip()

    @set_processing
    async def determine_character_description(
        self, character: Character, text: str = ""

@@ -7,7 +7,6 @@ from talemate.prompts import Prompt


class ScenarioCreatorMixin:

    """
    Adds scenario creation functionality to the creator agent
    """

34 src/talemate/agents/custom/__init__.py (new file)
@@ -0,0 +1,34 @@

import importlib
import os

import structlog

log = structlog.get_logger("talemate.agents.custom")

# import every submodule in this directory
#
# each directory in this directory is a submodule

# get the current directory
current_directory = os.path.dirname(__file__)

# get all subdirectories
subdirectories = [
    os.path.join(current_directory, name)
    for name in os.listdir(current_directory)
    if os.path.isdir(os.path.join(current_directory, name))
]

# import every submodule

for subdirectory in subdirectories:
    # get the name of the submodule
    submodule_name = os.path.basename(subdirectory)

    if submodule_name.startswith("__"):
        continue

    log.info("activating custom agent", module=submodule_name)

    # import the submodule
    importlib.import_module(f".{submodule_name}", __package__)

@@ -0,0 +1,5 @@

Each agent should be in its own subdirectory.

The subdirectory itself must be a valid python module.

Check out docs/dev/agents/example/test for a very simplistic custom agent example.
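Given the auto-import loader above, a custom agent is just a package dropped into the `custom` directory. A minimal sketch of what such a submodule might contain — the class name and file path here are illustrative, not taken from the repo's example:

```python
# src/talemate/agents/custom/my_agent/__init__.py (hypothetical path)
from talemate.agents.base import Agent
from talemate.agents.registry import register


@register()
class MyCustomAgent(Agent):
    # registered under this key and picked up automatically at import time
    agent_type = "my_custom_agent"
    verbose_name = "My Custom Agent"
    essential = False

    def __init__(self, client, **kwargs):
        self.client = client
        self.is_enabled = True
        self.actions = {}
```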
@@ -182,7 +182,10 @@ class DirectorAgent(Agent):

        # no character, see if there are NPC characters at all
        # if not we always want to direct narration
        always_direct = not self.scene.npc_character_names
        always_direct = (
            not self.scene.npc_character_names
            or self.scene.game_state.ops.always_direct
        )

        next_direct = self.next_direct_scene

@@ -253,6 +256,34 @@ class DirectorAgent(Agent):

        # run scene instructions
        self.scene.game_state.scene_instructions

    @set_processing
    async def persist_characters_from_worldstate(
        self, exclude: list[str] = None
    ) -> List[Character]:
        log.warning(
            "persist_characters_from_worldstate",
            world_state_characters=self.scene.world_state.characters,
            scene_characters=self.scene.character_names,
        )

        created_characters = []

        for character_name in self.scene.world_state.characters.keys():

            if exclude and character_name.lower() in exclude:
                continue

            if character_name in self.scene.character_names:
                continue

            character = await self.persist_character(name=character_name)

            created_characters.append(character)

        self.scene.emit_status()

        return created_characters

    @set_processing
    async def persist_character(
        self,

@@ -262,7 +293,10 @@ class DirectorAgent(Agent):

    ):
        world_state = instance.get_agent("world_state")
        creator = instance.get_agent("creator")

        self.scene.log.debug("persist_character", name=name)
        name = await creator.determine_character_name(name)
        self.scene.log.debug("persist_character", adjusted_name=name)

        character = self.scene.Character(name=name)
        character.color = random.choice(

@@ -40,11 +40,6 @@ class EditorAgent(Agent):

        self.client = client
        self.is_enabled = True
        self.actions = {
            "edit_dialogue": AgentAction(
                enabled=False,
                label="Edit dialogue",
                description="Will attempt to improve the quality of dialogue based on the character and scene. Runs automatically after each AI dialogue.",
            ),
            "fix_exposition": AgentAction(
                enabled=True,
                label="Fix exposition",

@@ -100,8 +95,6 @@ class EditorAgent(Agent):

        for text in emission.generation:
            edit = await self.add_detail(text, emission.character)

            edit = await self.edit_conversation(edit, emission.character)

            edit = await self.fix_exposition(edit, emission.character)

            edited.append(edit)

@@ -126,35 +119,6 @@ class EditorAgent(Agent):

        emission.generation = edited

    @set_processing
    async def edit_conversation(self, content: str, character: Character):
        """
        Edits a conversation
        """

        if not self.actions["edit_dialogue"].enabled:
            return content

        response = await Prompt.request(
            "editor.edit-dialogue",
            self.client,
            "edit_dialogue",
            vars={
                "content": content,
                "character": character,
                "scene": self.scene,
                "max_length": self.client.max_token_length,
            },
        )

        response = response.split("[end]")[0]

        response = util.replace_exposition_markers(response)
        response = util.clean_dialogue(response, main_name=character.name)
        response = util.strip_partial_sentences(response)

        return response

    @set_processing
    async def fix_exposition(self, content: str, character: Character):
        """

@@ -169,7 +133,7 @@ class EditorAgent(Agent):

        content = util.strip_partial_sentences(content)
        character_prefix = f"{character.name}: "
        message = content.split(character_prefix)[1]
        content = f"{character_prefix}*{message.strip('*')}*"
        content = f'{character_prefix}"{message.strip()}"'
        return content
        elif '"' in content:
            # silly hack to clean up some LLMs that always start with a quote

@@ -30,7 +30,7 @@ if not chromadb:

    log.info("ChromaDB not found, disabling Chroma agent")


from .base import Agent
from .base import Agent, AgentDetail


class MemoryDocument(str):

@@ -368,8 +368,30 @@ class ChromaDBMemoryAgent(MemoryAgent):

    @property
    def agent_details(self):

        details = {
            "backend": AgentDetail(
                icon="mdi-server-outline",
                value="ChromaDB",
                description="The backend to use for long-term memory",
            ).model_dump(),
            "embeddings": AgentDetail(
                icon="mdi-cube-unfolded",
                value=self.embeddings,
                description="The embeddings model.",
            ).model_dump(),
        }

        if self.embeddings == "openai" and not self.openai_api_key:
            return "No OpenAI API key set"
            # return "No OpenAI API key set"
            details["error"] = {
                "icon": "mdi-alert",
                "value": "No OpenAI API key set",
                "description": "You must provide an OpenAI API key to use OpenAI embeddings",
                "color": "error",
            }

        return details

        return f"ChromaDB: {self.embeddings}"

@@ -425,7 +447,13 @@ class ChromaDBMemoryAgent(MemoryAgent):

    def make_collection_name(self, scene):
        if self.USE_OPENAI:
            suffix = "-openai"
            model_name = self.config.get("chromadb").get(
                "openai_model", "text-embedding-3-small"
            )
            if model_name == "text-embedding-ada-002":
                suffix = "-openai"
            else:
                suffix = f"-openai-{model_name}"
        elif self.USE_INSTRUCTOR:
            suffix = "-instructor"
            model = self.config.get("chromadb").get(
@@ -472,12 +500,19 @@ class ChromaDBMemoryAgent(MemoryAgent):

                "You must provide the OpenAI API key in the config if you want to use it for chromadb embeddings"
            )

            model_name = self.config.get("chromadb").get(
                "openai_model", "text-embedding-3-small"
            )

            log.info(
                "chromadb", status="using openai", openai_key=openai_key[:5] + "..."
                "chromadb",
                status="using openai",
                openai_key=openai_key[:5] + "...",
                model=model_name,
            )
            openai_ef = embedding_functions.OpenAIEmbeddingFunction(
                api_key=openai_key,
                model_name="text-embedding-ada-002",
                model_name=model_name,
            )
            self.db = self.db_client.get_or_create_collection(
                collection_name, embedding_function=openai_ef

@@ -2,6 +2,7 @@ from __future__ import annotations

import dataclasses
import random
from functools import wraps
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import structlog

@@ -40,7 +41,8 @@ def set_processing(fn):

    """

    @_set_processing
    async def wrapper(self, *args, **kwargs):
    @wraps(fn)
    async def narration_wrapper(self, *args, **kwargs):
        response = await fn(self, *args, **kwargs)
        emission = NarratorAgentEmission(
            agent=self,

@@ -49,13 +51,11 @@ def set_processing(fn):

        await talemate.emit.async_signals.get("agent.narrator.generated").send(emission)
        return emission.generation[0]

    wrapper.__name__ = fn.__name__
    return wrapper
    return narration_wrapper


@register()
class NarratorAgent(Agent):

    """
    Handles narration of the story
    """

@@ -524,21 +524,98 @@ class NarratorAgent(Agent):

        return response

    @set_processing
    async def paraphrase(self, narration: str):
        """
        Paraphrase a narration
        """

        response = await Prompt.request(
            "narrator.paraphrase",
            self.client,
            "narrate",
            vars={
                "text": narration,
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
            },
        )

        log.info("paraphrase", narration=narration, response=response)

        response = self.clean_result(response.strip().strip("*"))
        response = f"*{response}*"

        return response

    async def passthrough(self, narration: str) -> str:
        """
        Pass through narration message as is
        """
        narration = narration.replace("*", "")
        narration = f"*{narration}*"
        narration = util.ensure_dialog_format(narration)
        return narration

    def action_to_source(
        self,
        action_name: str,
        parameters: dict,
    ) -> str:
        """
        Generate a source string for a given action and parameters

        The source string is used to identify the source of a NarratorMessage
        and will also help regenerate the action and parameters from the source string
        later on
        """

        args = []

        if action_name == "paraphrase":
            args.append(parameters.get("narration"))
        elif action_name == "narrate_character_entry":
            args.append(parameters.get("character").name)
            # args.append(parameters.get("direction"))
        elif action_name == "narrate_character_exit":
            args.append(parameters.get("character").name)
            # args.append(parameters.get("direction"))
        elif action_name == "narrate_character":
            args.append(parameters.get("character").name)
        elif action_name == "narrate_query":
            args.append(parameters.get("query"))
        elif action_name == "narrate_time_passage":
            args.append(parameters.get("duration"))
            args.append(parameters.get("time_passed"))
            args.append(parameters.get("narrative"))
        elif action_name == "progress_story":
            args.append(parameters.get("narrative_direction"))
        elif action_name == "narrate_after_dialogue":
            args.append(parameters.get("character"))

        arg_str = ";".join(args) if args else ""

        return f"{action_name}:{arg_str}".rstrip(":")

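To make the round-trip concrete, here is what `action_to_source` produces for a couple of calls — the parameter values are made up for illustration:

```python
# Sketch: example source strings produced by action_to_source.
narrator.action_to_source("narrate_query", {"query": "What lies beyond the gate?"})
# -> "narrate_query:What lies beyond the gate?"

narrator.action_to_source("passthrough", {})
# -> "passthrough"  (no args handled, so the trailing ":" is stripped)
```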
    async def action_to_narration(
        self,
        action_name: str,
        *args,
        emit_message: bool = False,
        **kwargs,
    ):
        # calls self[action_name] and returns the result as a NarratorMessage
        # that is pushed to the history

        fn = getattr(self, action_name)
        narration = await fn(*args, **kwargs)
        narrator_message = NarratorMessage(
            narration, source=f"{action_name}:{args[0] if args else ''}".rstrip(":")
        )
        narration = await fn(**kwargs)
        source = self.action_to_source(action_name, kwargs)

        narrator_message = NarratorMessage(narration, source=source)
        self.scene.push_history(narrator_message)

        if emit_message:
            emit("narrator", narrator_message)

        return narrator_message

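Since the refactor routes everything through keyword arguments, callers now invoke narrator actions like this — a sketch with an assumed narrator instance and illustrative text:

```python
# Sketch: run a narrator action, push the result to scene history,
# and emit it to the UI. source becomes "paraphrase:<narration text>".
message = await narrator.action_to_narration(
    "paraphrase",
    narration="The rain hammered the tin roof.",
    emit_message=True,
)
```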
    # LLM client related methods. These are called during or after the client

@@ -262,9 +262,11 @@ class SummarizeAgent(Agent):

                "dialogue": text,
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "summarization_method": self.actions["archive"].config["method"].value
                if method is None
                else method,
                "summarization_method": (
                    self.actions["archive"].config["method"].value
                    if method is None
                    else method
                ),
                "extra_context": extra_context or "",
                "extra_instructions": extra_instructions or "",
            },

@@ -15,6 +15,7 @@ import nltk

import pydantic
import structlog
from nltk.tokenize import sent_tokenize
from openai import AsyncOpenAI

import talemate.config as config
import talemate.emit.async_signals

@@ -24,7 +25,14 @@ from talemate.emit.signals import handlers

from talemate.events import GameLoopNewMessageEvent
from talemate.scene_message import CharacterMessage, NarratorMessage

from .base import Agent, AgentAction, AgentActionConfig, set_processing
from .base import (
    Agent,
    AgentAction,
    AgentActionConditional,
    AgentActionConfig,
    AgentDetail,
    set_processing,
)
from .registry import register

try:

@@ -109,7 +117,6 @@ class VoiceLibrary(pydantic.BaseModel):

@register()
class TTSAgent(Agent):

    """
    Text to speech agent
    """

@@ -117,6 +124,7 @@ class TTSAgent(Agent):

    agent_type = "tts"
    verbose_name = "Voice"
    requires_llm_client = False
    essential = False

    @classmethod
    def config_options(cls, agent=None):

@@ -135,11 +143,12 @@ class TTSAgent(Agent):

        self.voices = {
            "elevenlabs": VoiceLibrary(api="elevenlabs"),
            "coqui": VoiceLibrary(api="coqui"),
            "tts": VoiceLibrary(api="tts"),
            "openai": VoiceLibrary(api="openai"),
        }
        self.config = config.load_config()
        self.playback_done_event = asyncio.Event()
        self.preselect_voice = None
        self.actions = {
            "_config": AgentAction(
                enabled=True,

@@ -149,10 +158,9 @@ class TTSAgent(Agent):

                "api": AgentActionConfig(
                    type="text",
                    choices=[
                        # TODO at local TTS support
                        {"value": "tts", "label": "TTS (Local)"},
                        {"value": "elevenlabs", "label": "Eleven Labs"},
                        {"value": "coqui", "label": "Coqui Studio"},
                        {"value": "openai", "label": "OpenAI"},
                    ],
                    value="tts",
                    label="API",

@@ -192,6 +200,25 @@ class TTSAgent(Agent):

                ),
            },
        ),
            "openai": AgentAction(
                enabled=True,
                condition=AgentActionConditional(
                    attribute="_config.config.api", value="openai"
                ),
                label="OpenAI Settings",
                config={
                    "model": AgentActionConfig(
                        type="text",
                        value="tts-1",
                        choices=[
                            {"value": "tts-1", "label": "TTS 1"},
                            {"value": "tts-1-hd", "label": "TTS 1 HD"},
                        ],
                        label="Model",
                        description="TTS model to use",
                    ),
                },
            ),
        }

        self.actions["_config"].model_dump()

@@ -230,27 +257,45 @@ class TTSAgent(Agent):

    @property
    def agent_details(self):
        suffix = ""

        if not self.ready:
            suffix = f" - {self.not_ready_reason}"
        else:
            suffix = f" - {self.voice_id_to_label(self.default_voice_id)}"
        details = {
            "api": AgentDetail(
                icon="mdi-server-outline",
                value=self.api_label,
                description="The backend to use for TTS",
            ).model_dump(),
        }

        api = self.api
        choices = self.actions["_config"].config["api"].choices
        api_label = api
        for choice in choices:
            if choice["value"] == api:
                api_label = choice["label"]
                break
        if self.ready and self.enabled:
            details["voice"] = AgentDetail(
                icon="mdi-account-voice",
                value=self.voice_id_to_label(self.default_voice_id) or "",
                description="The voice to use for TTS",
                color="info",
            ).model_dump()
        elif self.enabled:
            details["error"] = AgentDetail(
                icon="mdi-alert",
                value=self.not_ready_reason,
                description=self.not_ready_reason,
                color="error",
            ).model_dump()

        return f"{api_label}{suffix}"
        return details

    @property
    def api(self):
        return self.actions["_config"].config["api"].value

    @property
    def api_label(self):
        choices = self.actions["_config"].config["api"].choices
        api = self.api
        for choice in choices:
            if choice["value"] == api:
                return choice["label"]
        return api

    @property
    def token(self):
        api = self.api

@@ -278,6 +323,8 @@ class TTSAgent(Agent):

        if not self.enabled:
            return "disabled"
        if self.ready:
            if getattr(self, "processing_bg", 0) > 0:
                return "busy_bg" if not getattr(self, "processing", False) else "busy"
            return "active" if not getattr(self, "processing", False) else "busy"
        if self.requires_token and not self.token:
            return "error"

@@ -295,7 +342,11 @@ class TTSAgent(Agent):

        return 250

    def apply_config(self, *args, **kwargs):
    @property
    def openai_api_key(self):
        return self.config.get("openai", {}).get("api_key")

    async def apply_config(self, *args, **kwargs):
        try:
            api = kwargs["actions"]["_config"]["config"]["api"]["value"]
        except KeyError:

@@ -304,10 +355,22 @@ class TTSAgent(Agent):

        api_changed = api != self.api

        log.debug(
            "apply_config", api=api, api_changed=api != self.api, current_api=self.api
            "apply_config",
            api=api,
            api_changed=api != self.api,
            current_api=self.api,
            args=args,
            kwargs=kwargs,
        )

        super().apply_config(*args, **kwargs)
        try:
            self.preselect_voice = kwargs["actions"]["_config"]["config"]["voice_id"][
                "value"
            ]
        except KeyError:
            self.preselect_voice = self.default_voice_id

        await super().apply_config(*args, **kwargs)

        if api_changed:
            try:

@@ -400,6 +463,11 @@ class TTSAgent(Agent):

        library.voices = await list_fn()
        library.last_synced = time.time()

        if self.preselect_voice:
            if self.voice(self.preselect_voice):
                self.actions["_config"].config["voice_id"].value = self.preselect_voice
            self.preselect_voice = None

        # if the current voice cannot be found, reset it
        if not self.voice(self.default_voice_id):
            self.actions["_config"].config["voice_id"].value = ""

@@ -425,9 +493,10 @@ class TTSAgent(Agent):

        # Start generating audio chunks in the background
        generation_task = asyncio.create_task(self.generate_chunks(generate_fn, chunks))
        await self.set_background_processing(generation_task)

        # Wait for both tasks to complete
        await asyncio.gather(generation_task)
        # await asyncio.gather(generation_task)

    async def generate_chunks(self, generate_fn, chunks):
        for chunk in chunks:

@@ -552,96 +621,32 @@ class TTSAgent(Agent):

        return voices

    # COQUI STUDIO
    # OPENAI

    async def _generate_coqui(self, text: str) -> Union[bytes, None]:
        api_key = self.token
        if not api_key:
            return
    async def _generate_openai(self, text: str, chunk_size: int = 1024):

        async with httpx.AsyncClient() as client:
            url = "https://app.coqui.ai/api/v2/samples/xtts/render/"
            headers = {
                "Accept": "application/json",
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            }
            data = {
                "voice_id": self.default_voice_id,
                "text": text,
                "language": "en",  # Assuming English language for simplicity; this could be parameterized
            }
        client = AsyncOpenAI(api_key=self.openai_api_key)

            # Make the POST request to Coqui API
            response = await client.post(url, json=data, headers=headers, timeout=300)
            if response.status_code in [200, 201]:
                # Parse the JSON response to get the audio URL
                response_data = response.json()
                audio_url = response_data.get("audio_url")
                if audio_url:
                    # Make a GET request to download the audio file
                    audio_response = await client.get(audio_url)
                    if audio_response.status_code == 200:
                        # delete the sample from Coqui Studio
                        # await self._cleanup_coqui(response_data.get('id'))
                        return audio_response.content
                    else:
                        log.error(f"Error downloading audio: {audio_response.text}")
                else:
                    log.error("No audio URL in response")
            else:
                log.error(f"Error generating audio: {response.text}")
        model = self.actions["openai"].config["model"].value

    async def _cleanup_coqui(self, sample_id: str):
        api_key = self.token
        if not api_key or not sample_id:
            return
        response = await client.audio.speech.create(
            model=model, voice=self.default_voice_id, input=text
        )

        async with httpx.AsyncClient() as client:
            url = f"https://app.coqui.ai/api/v2/samples/xtts/{sample_id}"
            headers = {"Authorization": f"Bearer {api_key}"}
        bytes_io = io.BytesIO()
        for chunk in response.iter_bytes(chunk_size=chunk_size):
            if chunk:
                bytes_io.write(chunk)

            # Make the DELETE request to Coqui API
            response = await client.delete(url, headers=headers)
        # Put the audio data in the queue for playback
        return bytes_io.getvalue()

        if response.status_code == 204:
            log.info(f"Successfully deleted sample with ID: {sample_id}")
        else:
            log.error(
                f"Error deleting sample with ID: {sample_id}: {response.text}"
            )

    async def _list_voices_coqui(self) -> dict[str, str]:
        url_speakers = "https://app.coqui.ai/api/v2/speakers"
        url_custom_voices = "https://app.coqui.ai/api/v2/voices"

        voices = []

        async with httpx.AsyncClient() as client:
            headers = {"Authorization": f"Bearer {self.token}"}
            response = await client.get(
                url_speakers, headers=headers, params={"per_page": 1000}
            )
            speakers = response.json()["result"]
            voices.extend(
                [
                    Voice(value=speaker["id"], label=speaker["name"])
                    for speaker in speakers
                ]
            )

            response = await client.get(
                url_custom_voices, headers=headers, params={"per_page": 1000}
            )
            custom_voices = response.json()["result"]
            voices.extend(
                [
                    Voice(value=voice["id"], label=voice["name"])
                    for voice in custom_voices
                ]
            )

            # sort by name
            voices.sort(key=lambda x: x.label)

            return voices
    async def _list_voices_openai(self) -> dict[str, str]:
        return [
            Voice(value="alloy", label="Alloy"),
            Voice(value="echo", label="Echo"),
            Voice(value="fable", label="Fable"),
            Voice(value="onyx", label="Onyx"),
            Voice(value="nova", label="Nova"),
            Voice(value="shimmer", label="Shimmer"),
        ]
452 src/talemate/agents/visual/__init__.py (new file)
@@ -0,0 +1,452 @@

import asyncio
import traceback

import structlog

import talemate.agents.visual.automatic1111
import talemate.agents.visual.comfyui
import talemate.agents.visual.openai_image
from talemate.agents.base import (
    Agent,
    AgentAction,
    AgentActionConditional,
    AgentActionConfig,
    AgentDetail,
    set_processing,
)
from talemate.agents.registry import register
from talemate.client.base import ClientBase
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers as signal_handlers
from talemate.prompts.base import Prompt

from .commands import *  # noqa
from .context import VIS_TYPES, VisualContext, visual_context
from .handlers import HANDLERS
from .schema import RESOLUTION_MAP, RenderSettings
from .style import MAJOR_STYLES, STYLE_MAP, Style, combine_styles
from .websocket_handler import VisualWebsocketHandler

__all__ = [
    "VisualAgent",
]

BACKENDS = [
    {"value": mixin_backend, "label": mixin["label"]}
    for mixin_backend, mixin in HANDLERS.items()
]

log = structlog.get_logger("talemate.agents.visual")


class VisualBase(Agent):
    """
    The visual agent
    """

    agent_type = "visual"
    verbose_name = "Visualizer"
    essential = False
    websocket_handler = VisualWebsocketHandler

    ACTIONS = {}

    def __init__(self, client: ClientBase, *kwargs):
        self.client = client
        self.is_enabled = False
        self.backend_ready = False
        self.initialized = False
        self.config = load_config()
        self.actions = {
            "_config": AgentAction(
                enabled=True,
                label="Configure",
                description="Visual agent configuration",
                config={
                    "backend": AgentActionConfig(
                        type="text",
                        choices=BACKENDS,
                        value="automatic1111",
                        label="Backend",
                        description="The backend to use for visual processing",
                    ),
                    "default_style": AgentActionConfig(
                        type="text",
                        value="ink_illustration",
                        choices=MAJOR_STYLES,
                        label="Default Style",
                        description="The default style to use for visual processing",
                    ),
                },
            ),
            "automatic_generation": AgentAction(
                enabled=False,
                label="Automatic Generation",
                description="Allow automatic generation of visual content",
            ),
            "process_in_background": AgentAction(
                enabled=True,
                label="Process in Background",
                description="Process renders in the background",
            ),
        }

        for action_name, action in self.ACTIONS.items():
            self.actions[action_name] = action

        signal_handlers["config_saved"].connect(self.on_config_saved)

    @property
    def enabled(self):
        return self.is_enabled

    @property
    def has_toggle(self):
        return True

    @property
    def experimental(self):
        return False

    @property
    def backend(self):
        return self.actions["_config"].config["backend"].value

    @property
    def backend_name(self):
        key = self.actions["_config"].config["backend"].value

        for backend in BACKENDS:
            if backend["value"] == key:
                return backend["label"]

    @property
    def default_style(self):
        return STYLE_MAP.get(
            self.actions["_config"].config["default_style"].value, Style()
        )

    @property
    def ready(self):
        return self.backend_ready

    @property
    def api_url(self):
        try:
            return self.actions[self.backend].config["api_url"].value
        except KeyError:
            return None

    @property
    def agent_details(self):
        details = {
            "backend": AgentDetail(
                icon="mdi-server-outline",
                value=self.backend_name,
                description="The backend to use for visual processing",
            ).model_dump(),
            "client": AgentDetail(
                icon="mdi-network-outline",
                value=self.client.name if self.client else None,
                description="The client to use for prompt generation",
            ).model_dump(),
        }

        if not self.ready and self.enabled:
            details["status"] = AgentDetail(
                icon="mdi-alert",
                value=f"{self.backend_name} not ready",
                color="error",
                description=self.ready_check_error
                or f"{self.backend_name} is not ready for processing",
            ).model_dump()

        return details

    @property
    def process_in_background(self):
        return self.actions["process_in_background"].enabled

    @property
    def allow_automatic_generation(self):
        return self.actions["automatic_generation"].enabled

    def on_config_saved(self, event):
        config = event.data
        self.config = config
        asyncio.create_task(self.emit_status())

    async def on_ready_check_success(self):
        prev_ready = self.backend_ready
        self.backend_ready = True
        if not prev_ready:
            await self.emit_status()

    async def on_ready_check_failure(self, error):
        prev_ready = self.backend_ready
        self.backend_ready = False
        self.ready_check_error = str(error)
        if prev_ready:
            await self.emit_status()

    async def ready_check(self):
        if not self.enabled:
            return
        backend = self.backend
        fn = getattr(self, f"{backend.lower()}_ready", None)
        task = asyncio.create_task(fn())
        await super().ready_check(task)

    async def apply_config(self, *args, **kwargs):

        try:
            backend = kwargs["actions"]["_config"]["config"]["backend"]["value"]
        except KeyError:
            backend = self.backend

        backend_changed = backend != self.backend

        if backend_changed:
            self.backend_ready = False

        log.info(
            "apply_config",
            backend=backend,
            backend_changed=backend_changed,
            old_backend=self.backend,
        )

        await super().apply_config(*args, **kwargs)
        backend_fn = getattr(self, f"{self.backend.lower()}_apply_config", None)
        if backend_fn:
            task = asyncio.create_task(
                backend_fn(backend_changed=backend_changed, *args, **kwargs)
            )
            await self.set_background_processing(task)

        if not self.backend_ready:
            await self.ready_check()

        self.initialized = True

    def resolution_from_format(self, format: str, model_type: str = "sdxl"):
        if model_type not in RESOLUTION_MAP:
            raise ValueError(f"Model type {model_type} not found in resolution map")
        return RESOLUTION_MAP[model_type].get(
            format, RESOLUTION_MAP[model_type]["portrait"]
        )

    def prepare_prompt(self, prompt: str, styles: list[Style] = None) -> Style:

        prompt_style = Style()
        prompt_style.load(prompt)

        if styles:
            prompt_style.prepend(*styles)

        return prompt_style

    def vis_type_styles(self, vis_type: str):
        if vis_type == VIS_TYPES.CHARACTER:
            portrait_style = STYLE_MAP["character_portrait"].copy()
            return portrait_style
        elif vis_type == VIS_TYPES.ENVIRONMENT:
            environment_style = STYLE_MAP["environment"].copy()
            return environment_style
        return Style()

    async def apply_image(self, image: str):
        context = visual_context.get()

        log.debug("apply_image", image=image[:100], context=context)

        if context.vis_type == VIS_TYPES.CHARACTER:
            await self.apply_image_character(image, context.character_name)

    async def apply_image_character(self, image: str, character_name: str):
        character = self.scene.get_character(character_name)

        if not character:
            log.error("character not found", character_name=character_name)
            return

        if character.cover_image:
            log.info("character cover image already set", character_name=character_name)
            return

        asset = self.scene.assets.add_asset_from_image_data(
            f"data:image/png;base64,{image}"
        )
        character.cover_image = asset.id
        self.scene.assets.cover_image = asset.id
        self.scene.emit_status()

    async def emit_image(self, image: str):
        context = visual_context.get()
        await self.apply_image(image)
        emit(
            "image_generated",
            websocket_passthrough=True,
            data={
                "base64": image,
                "context": context.model_dump() if context else None,
            },
        )

    @set_processing
    async def generate(
        self, format: str = "portrait", prompt: str = None, automatic: bool = False
    ):

        context = visual_context.get()

        if not self.enabled:
            log.warning("generate", skipped="Visual agent not enabled")
            return

        if automatic and not self.allow_automatic_generation:
            log.warning(
                "generate",
                skipped="Automatic generation disabled",
                prompt=prompt,
                format=format,
                context=context,
            )
            return

        if not context and not prompt:
            log.error("generate", error="No context or prompt provided")
            return

        # Handle prompt generation based on context

        if not prompt and context.prompt:
            prompt = context.prompt

        if context.vis_type == VIS_TYPES.ENVIRONMENT and not prompt:
            prompt = await self.generate_environment_prompt(
                instructions=context.instructions
            )
        elif context.vis_type == VIS_TYPES.CHARACTER and not prompt:
            prompt = await self.generate_character_prompt(
                context.character_name, instructions=context.instructions
            )
        else:
            prompt = prompt or context.prompt

        initial_prompt = prompt

        # Augment the prompt with styles based on context

        thematic_style = self.default_style
        vis_type_styles = self.vis_type_styles(context.vis_type)
        prompt = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])

        if not prompt:
            log.error(
                "generate", error="No prompt provided and no context to generate from"
            )
            return

        context.prompt = initial_prompt
        context.prepared_prompt = str(prompt)

        # Handle format (can either come from context or be passed in)

        if not format and context.format:
            format = context.format
        elif not format:
            format = "portrait"

        context.format = format

        # Call the backend specific generate function

        backend = self.backend
        fn = f"{backend.lower()}_generate"

        log.info(
            "generate", backend=backend, prompt=prompt, format=format, context=context
        )

        if not hasattr(self, fn):
            log.error("generate", error=f"Backend {backend} does not support generate")

        # add the function call to the asyncio task queue

        if self.process_in_background:
            task = asyncio.create_task(getattr(self, fn)(prompt=prompt, format=format))
            await self.set_background_processing(task)
        else:
            await getattr(self, fn)(prompt=prompt, format=format)

    @set_processing
    async def generate_environment_prompt(self, instructions: str = None):

        response = await Prompt.request(
            "visual.generate-environment-prompt",
            self.client,
            "visualize",
            {
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
            },
        )

        return response.strip()

    @set_processing
    async def generate_character_prompt(
        self, character_name: str, instructions: str = None
    ):

        character = self.scene.get_character(character_name)

        response = await Prompt.request(
            "visual.generate-character-prompt",
            self.client,
            "visualize",
            {
                "scene": self.scene,
                "character_name": character_name,
                "character": character,
                "max_tokens": self.client.max_token_length,
                "instructions": instructions or "",
            },
        )

        return response.strip()

    async def generate_environment_background(self, instructions: str = None):
        with VisualContext(vis_type=VIS_TYPES.ENVIRONMENT, instructions=instructions):
            await self.generate(format="landscape")

    async def generate_character_portrait(
        self,
        character_name: str,
        instructions: str = None,
    ):
        with VisualContext(
            vis_type=VIS_TYPES.CHARACTER,
            character_name=character_name,
            instructions=instructions,
        ):
            await self.generate(format="portrait")


# apply mixins to the agent (from HANDLERS dict[str, cls])

for mixin_backend, mixin in HANDLERS.items():
    mixin_cls = mixin["cls"]
    VisualBase = type("VisualAgent", (mixin_cls, VisualBase), {})

    extend_actions = getattr(mixin_cls, "EXTEND_ACTIONS", {})

    for action_name, action in extend_actions.items():
        VisualBase.ACTIONS[action_name] = action


@register()
class VisualAgent(VisualBase):
    pass
117 src/talemate/agents/visual/automatic1111.py (new file)
@@ -0,0 +1,117 @@

import base64
import io

import httpx
import structlog
from PIL import Image

from talemate.agents.base import (
    Agent,
    AgentAction,
    AgentActionConditional,
    AgentActionConfig,
    AgentDetail,
    set_processing,
)

from .handlers import register
from .schema import RenderSettings, Resolution
from .style import STYLE_MAP, Style

log = structlog.get_logger("talemate.agents.visual.automatic1111")


@register(backend_name="automatic1111", label="AUTOMATIC1111")
class Automatic1111Mixin:

    automatic1111_default_render_settings = RenderSettings()

    EXTEND_ACTIONS = {
        "automatic1111": AgentAction(
            enabled=True,
            condition=AgentActionConditional(
                attribute="_config.config.backend", value="automatic1111"
            ),
            label="Automatic1111 Settings",
            description="Setting overrides for the automatic1111 backend",
            config={
                "api_url": AgentActionConfig(
                    type="text",
                    value="http://localhost:7860",
                    label="API URL",
                    description="The URL of the backend API",
                ),
                "steps": AgentActionConfig(
                    type="number",
                    value=40,
                    label="Steps",
                    min=5,
                    max=150,
                    step=1,
                    description="number of render steps",
                ),
                "model_type": AgentActionConfig(
                    type="text",
                    value="sdxl",
                    choices=[
                        {"value": "sdxl", "label": "SDXL"},
                        {"value": "sd15", "label": "SD1.5"},
                    ],
                    label="Model Type",
                    description="Right now just differentiates between sdxl and sd15 - affects generation resolution",
                ),
            },
        )
    }

    @property
    def automatic1111_render_settings(self):
        if self.actions["automatic1111"].enabled:
            return RenderSettings(
                steps=self.actions["automatic1111"].config["steps"].value,
                type_model=self.actions["automatic1111"].config["model_type"].value,
            )
        else:
            return self.automatic1111_default_render_settings

    async def automatic1111_generate(self, prompt: Style, format: str):
        url = self.api_url
        resolution = self.resolution_from_format(
            format, self.automatic1111_render_settings.type_model
        )
        render_settings = self.automatic1111_render_settings
        payload = {
            "prompt": prompt.positive_prompt,
            "negative_prompt": prompt.negative_prompt,
            "steps": render_settings.steps,
            "width": resolution.width,
            "height": resolution.height,
        }

        log.info("automatic1111_generate", payload=payload, url=url)

        async with httpx.AsyncClient() as client:
            response = await client.post(
                url=f"{url}/sdapi/v1/txt2img", json=payload, timeout=90
            )

        r = response.json()

        # image = Image.open(io.BytesIO(base64.b64decode(r['images'][0])))
        # image.save('a1111-test.png')

        # log.info("automatic1111_generate", saved_to="a1111-test.png")

        for image in r["images"]:
            await self.emit_image(image)

    async def automatic1111_ready(self) -> bool:
        """
        Will send a GET to /sdapi/v1/memory and on 200 will return True
        """

        async with httpx.AsyncClient() as client:
            response = await client.get(
                url=f"{self.api_url}/sdapi/v1/memory", timeout=2
            )
            return response.status_code == 200
324 src/talemate/agents/visual/comfyui.py (new file)
@@ -0,0 +1,324 @@

import asyncio
import base64
import io
import json
import os
import random
import time
import urllib.parse

import httpx
import pydantic
import structlog
from PIL import Image

from talemate.agents.base import AgentAction, AgentActionConditional, AgentActionConfig

from .handlers import register
from .schema import RenderSettings, Resolution
from .style import STYLE_MAP, Style

log = structlog.get_logger("talemate.agents.visual.comfyui")


class Workflow(pydantic.BaseModel):
    nodes: dict

    def set_resolution(self, resolution: Resolution):

        # will collect all latent image nodes
        # if there is multiple will look for the one with the
        # title "Talemate Resolution"

        # if there is no latent image node with the title "Talemate Resolution"
        # the first latent image node will be used

        # resolution will be updated on the selected node

        # if no latent image node is found a warning will be logged

        latent_image_node = None

        for node_id, node in self.nodes.items():
            if node["class_type"] == "EmptyLatentImage":
                if not latent_image_node:
                    latent_image_node = node
                elif node["_meta"]["title"] == "Talemate Resolution":
                    latent_image_node = node
                    break

        if not latent_image_node:
            log.warning("set_resolution", error="No latent image node found")
            return

        latent_image_node["inputs"]["width"] = resolution.width
        latent_image_node["inputs"]["height"] = resolution.height

    def set_prompt(self, prompt: str, negative_prompt: str = None):

        # will collect all CLIPTextEncode nodes

        # if there is multiple will look for the one with the
        # title "Talemate Positive Prompt" and "Talemate Negative Prompt"
        #
        # if there is no CLIPTextEncode node with the title "Talemate Positive Prompt"
        # the first CLIPTextEncode node will be used
        #
        # if there is no CLIPTextEncode node with the title "Talemate Negative Prompt"
        # the second CLIPTextEncode node will be used
        #
        # prompt will be updated on the selected node

        # if no CLIPTextEncode node is found an exception will be raised for
        # the positive prompt

        # if no CLIPTextEncode node is found an exception will be raised for
        # the negative prompt if it is not None

        positive_prompt_node = None
        negative_prompt_node = None

        for node_id, node in self.nodes.items():

            if node["class_type"] == "CLIPTextEncode":
                if not positive_prompt_node:
                    positive_prompt_node = node
                elif node["_meta"]["title"] == "Talemate Positive Prompt":
                    positive_prompt_node = node
                elif not negative_prompt_node:
                    negative_prompt_node = node
                elif node["_meta"]["title"] == "Talemate Negative Prompt":
                    negative_prompt_node = node

        if not positive_prompt_node:
            raise ValueError("No positive prompt node found")

        positive_prompt_node["inputs"]["text"] = prompt

        if negative_prompt and not negative_prompt_node:
            raise ValueError("No negative prompt node found")

        if negative_prompt:
            negative_prompt_node["inputs"]["text"] = negative_prompt

    def set_checkpoint(self, checkpoint: str):

        # will collect all CheckpointLoaderSimple nodes
        # if there is multiple will look for the one with the
        # title "Talemate Load Checkpoint"

        # if there is no CheckpointLoaderSimple node with the title "Talemate Load Checkpoint"
        # the first CheckpointLoaderSimple node will be used

        # checkpoint will be updated on the selected node

        # if no CheckpointLoaderSimple node is found a warning will be logged

        checkpoint_node = None

        for node_id, node in self.nodes.items():
            if node["class_type"] == "CheckpointLoaderSimple":
                if not checkpoint_node:
                    checkpoint_node = node
                elif node["_meta"]["title"] == "Talemate Load Checkpoint":
                    checkpoint_node = node
                    break

        if not checkpoint_node:
            log.warning("set_checkpoint", error="No checkpoint node found")
            return

        checkpoint_node["inputs"]["ckpt_name"] = checkpoint

    def set_seeds(self):
        for node in self.nodes.values():
            for field in node.get("inputs", {}).keys():
                if field == "noise_seed":
                    node["inputs"]["noise_seed"] = random.randint(0, 999999999999999)

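The `Workflow` helpers above rely on a titling convention inside the exported API-format ComfyUI JSON: when a graph contains several nodes of the same `class_type`, the one whose `_meta.title` carries the Talemate label wins. A sketch of the shape they expect (node ids, link values, and the `Resolution` constructor are assumptions):

```python
# Sketch: minimal API-format workflow fragment the helpers operate on.
nodes = {
    "5": {
        "class_type": "EmptyLatentImage",
        "_meta": {"title": "Talemate Resolution"},  # selected by set_resolution
        "inputs": {"width": 1024, "height": 1024, "batch_size": 1},
    },
    "6": {
        "class_type": "CLIPTextEncode",
        "_meta": {"title": "Talemate Positive Prompt"},  # selected by set_prompt
        "inputs": {"text": "", "clip": ["4", 1]},
    },
}
workflow = Workflow(nodes=nodes)
workflow.set_prompt("a rain-slicked alley at night")  # fills node "6"
```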
@register(backend_name="comfyui", label="ComfyUI")
class ComfyUIMixin:

    comfyui_default_render_settings = RenderSettings()

    EXTEND_ACTIONS = {
        "comfyui": AgentAction(
            enabled=True,
            condition=AgentActionConditional(
                attribute="_config.config.backend", value="comfyui"
            ),
            label="ComfyUI Settings",
            description="Setting overrides for the comfyui backend",
            config={
                "api_url": AgentActionConfig(
                    type="text",
                    value="http://localhost:8188",
                    label="API URL",
                    description="The URL of the backend API",
                ),
                "workflow": AgentActionConfig(
                    type="text",
                    value="default-sdxl.json",
                    label="Workflow",
                    description="The workflow to use for comfyui (workflow file name inside ./templates/comfyui-workflows)",
                ),
                "checkpoint": AgentActionConfig(
                    type="text",
                    value="default",
                    label="Checkpoint",
                    choices=[],
                    description="The main checkpoint to use.",
                ),
            },
        )
    }

    @property
    def comfyui_workflow_filename(self):
        base_name = self.actions["comfyui"].config["workflow"].value

        # make absolute path
        abs_path = os.path.join(
            os.path.dirname(__file__),
            "..",
            "..",
            "..",
            "..",
            "templates",
            "comfyui-workflows",
            base_name,
        )

        return abs_path

    @property
    def comfyui_workflow_is_sdxl(self) -> bool:
        """
        Returns true if `sdxl` is in the workflow file name (case insensitive)
        """

        return "sdxl" in self.comfyui_workflow_filename.lower()

    @property
    def comfyui_workflow(self) -> Workflow:
        workflow = self.comfyui_workflow_filename
        if not workflow:
            raise ValueError("No comfyui workflow file specified")

        with open(workflow, "r") as f:
            return Workflow(nodes=json.load(f))

    @property
    async def comfyui_object_info(self):
        if hasattr(self, "_comfyui_object_info"):
            return self._comfyui_object_info

        async with httpx.AsyncClient() as client:
            response = await client.get(url=f"{self.api_url}/object_info")
            self._comfyui_object_info = response.json()

        return self._comfyui_object_info

    @property
    async def comfyui_checkpoints(self):
        loader_node = (await self.comfyui_object_info)["CheckpointLoaderSimple"]
        _checkpoints = loader_node["input"]["required"]["ckpt_name"][0]
        return [
            {"label": checkpoint, "value": checkpoint} for checkpoint in _checkpoints
        ]

    async def comfyui_get_image(self, filename: str, subfolder: str, folder_type: str):
        data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        url_values = urllib.parse.urlencode(data)

        async with httpx.AsyncClient() as client:
            response = await client.get(url=f"{self.api_url}/view?{url_values}")
            return response.content

    async def comfyui_get_history(self, prompt_id: str):
        async with httpx.AsyncClient() as client:
            response = await client.get(url=f"{self.api_url}/history/{prompt_id}")
            return response.json()

    async def comfyui_get_images(self, prompt_id: str, max_wait: int = 60.0):
        output_images = {}
        history = {}

        start = time.time()

        while not history:
            log.info(
                "comfyui_get_images", waiting_for_history=True, prompt_id=prompt_id
            )
            history = await self.comfyui_get_history(prompt_id)
            await asyncio.sleep(1.0)
            if time.time() - start > max_wait:
                raise TimeoutError("Max wait time exceeded")

        for node_id, node_output in history[prompt_id]["outputs"].items():
            if "images" in node_output:
                images_output = []
                for image in node_output["images"]:
                    image_data = await self.comfyui_get_image(
                        image["filename"], image["subfolder"], image["type"]
                    )
                    images_output.append(image_data)
                output_images[node_id] = images_output

        return output_images

    async def comfyui_generate(self, prompt: Style, format: str):
        url = self.api_url
        workflow = self.comfyui_workflow
        is_sdxl = self.comfyui_workflow_is_sdxl

        resolution = self.resolution_from_format(format, "sdxl" if is_sdxl else "sd15")

        workflow.set_resolution(resolution)
        workflow.set_prompt(prompt.positive_prompt, prompt.negative_prompt)
        workflow.set_seeds()
        workflow.set_checkpoint(self.actions["comfyui"].config["checkpoint"].value)

        payload = {"prompt": workflow.model_dump().get("nodes")}

        log.info("comfyui_generate", payload=payload, url=url)

        async with httpx.AsyncClient() as client:
            response = await client.post(url=f"{url}/prompt", json=payload, timeout=90)

        log.info("comfyui_generate", response=response.text)

        r = response.json()

        prompt_id = r["prompt_id"]

        images = await self.comfyui_get_images(prompt_id)
        for node_id, node_images in images.items():
            for i, image in enumerate(node_images):
                await self.emit_image(base64.b64encode(image).decode("utf-8"))
                # image = Image.open(io.BytesIO(image))
                # image.save(f'comfyui-test.png')

    async def comfyui_apply_config(
        self, backend_changed: bool = False, *args, **kwargs
    ):
        log.debug(
            "comfyui_apply_config",
            backend_changed=backend_changed,
|
||||
enabled=self.enabled,
|
||||
)
|
||||
if (not self.initialized or backend_changed) and self.enabled:
|
||||
checkpoints = await self.comfyui_checkpoints
|
||||
selected_checkpoint = self.actions["comfyui"].config["checkpoint"].value
|
||||
self.actions["comfyui"].config["checkpoint"].choices = checkpoints
|
||||
self.actions["comfyui"].config["checkpoint"].value = selected_checkpoint
|
||||
|
||||
async def comfyui_ready(self) -> bool:
|
||||
"""
|
||||
Will send a GET to /system_stats and on 200 will return True
|
||||
"""
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url=f"{self.api_url}/system_stats", timeout=2)
|
||||
return response.status_code == 200
|
||||
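The Workflow helpers above operate on ComfyUI's API-format graph, which is a plain dict mapping node ids to node descriptions. A minimal sketch of the shape `set_checkpoint` and `set_seeds` traverse (node ids, titles and values here are illustrative, not taken from the repo's workflow templates):

```python
# Illustrative ComfyUI API-format graph; real workflows exported from
# ComfyUI contain many more nodes and input fields.
nodes = {
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "_meta": {"title": "Talemate Load Checkpoint"},
        "inputs": {"ckpt_name": "default"},
    },
    "10": {
        "class_type": "KSamplerAdvanced",
        "_meta": {"title": "KSampler"},
        "inputs": {"noise_seed": 0, "steps": 40},
    },
}

# set_checkpoint prefers the loader titled "Talemate Load Checkpoint",
# falling back to the first CheckpointLoaderSimple it encounters, and
# set_seeds randomizes every "noise_seed" input it finds anywhere in
# the graph.
```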
68  src/talemate/agents/visual/commands.py  Normal file
@@ -0,0 +1,68 @@
from talemate.agents.visual.context import VIS_TYPES, VisualContext
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.instance import get_agent

__all__ = [
    "CmdVisualizeTestGenerate",
]


@register
class CmdVisualizeTestGenerate(TalemateCommand):
    """
    Generates a visual test
    """

    name = "visual_test_generate"
    description = "Will generate a visual test"
    aliases = ["vis_test", "vtg"]

    label = "Visualize test"

    async def run(self):
        visual = get_agent("visual")
        prompt = self.args[0]
        with VisualContext(vis_type=VIS_TYPES.UNSPECIFIED):
            await visual.generate(prompt)
        return True


@register
class CmdVisualizeEnvironment(TalemateCommand):
    """
    Shows the environment
    """

    name = "visual_environment"
    description = "Will show the environment"
    aliases = ["vis_env"]

    label = "Visualize environment"

    async def run(self):
        visual = get_agent("visual")
        await visual.generate_environment_background(
            instructions=self.args[0] if len(self.args) > 0 else None
        )
        return True


@register
class CmdVisualizeCharacter(TalemateCommand):
    """
    Shows a character
    """

    name = "visual_character"
    description = "Will show a character"
    aliases = ["vis_char"]

    label = "Visualize character"

    async def run(self):
        visual = get_agent("visual")
        character_name = self.args[0]
        instructions = self.args[1] if len(self.args) > 1 else None
        await visual.generate_character_portrait(character_name, instructions)
        return True
55  src/talemate/agents/visual/context.py  Normal file
@@ -0,0 +1,55 @@
import contextvars
import enum
from typing import Union

import pydantic

__all__ = [
    "VIS_TYPES",
    "visual_context",
    "VisualContext",
]


class VIS_TYPES(str, enum.Enum):
    UNSPECIFIED = "UNSPECIFIED"
    ENVIRONMENT = "ENVIRONMENT"
    CHARACTER = "CHARACTER"
    ITEM = "ITEM"


visual_context = contextvars.ContextVar("visual_context", default=None)


class VisualContextState(pydantic.BaseModel):
    character_name: Union[str, None] = None
    instructions: Union[str, None] = None
    vis_type: VIS_TYPES = VIS_TYPES.ENVIRONMENT
    prompt: Union[str, None] = None
    prepared_prompt: Union[str, None] = None
    format: Union[str, None] = None


class VisualContext:
    def __init__(
        self,
        character_name: Union[str, None] = None,
        instructions: Union[str, None] = None,
        vis_type: VIS_TYPES = VIS_TYPES.ENVIRONMENT,
        prompt: Union[str, None] = None,
        **kwargs,
    ):
        self.state = VisualContextState(
            character_name=character_name,
            instructions=instructions,
            vis_type=vis_type,
            prompt=prompt,
            **kwargs,
        )

    def __enter__(self):
        self.token = visual_context.set(self.state)

    def __exit__(self, *args, **kwargs):
        visual_context.reset(self.token)
        return False
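VisualContext is a thin context manager over a contextvar, so downstream code (the visual agent, the commands above) can read the active VisualContextState without it being threaded through every call. A minimal usage sketch, with a made-up character name:

```python
from talemate.agents.visual.context import (
    VIS_TYPES,
    VisualContext,
    visual_context,
)

with VisualContext(vis_type=VIS_TYPES.CHARACTER, character_name="Elara"):
    state = visual_context.get()  # VisualContextState set by __enter__
    print(state.vis_type, state.character_name)

# outside the block the contextvar is reset to its default (None)
print(visual_context.get())
```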
17  src/talemate/agents/visual/handlers.py  Normal file
@@ -0,0 +1,17 @@
__all__ = [
    "HANDLERS",
    "register",
]

HANDLERS = {}


class register:

    def __init__(self, backend_name: str, label: str):
        self.backend_name = backend_name
        self.label = label

    def __call__(self, mixin_cls):
        HANDLERS[self.backend_name] = {"label": self.label, "cls": mixin_cls}
        return mixin_cls
127  src/talemate/agents/visual/openai_image.py  Normal file
@@ -0,0 +1,127 @@
import base64
import io

import httpx
import structlog
from openai import AsyncOpenAI
from PIL import Image

from talemate.agents.base import (
    Agent,
    AgentAction,
    AgentActionConditional,
    AgentActionConfig,
    AgentDetail,
    set_processing,
)

from .handlers import register
from .schema import RenderSettings, Resolution
from .style import STYLE_MAP, Style

log = structlog.get_logger("talemate.agents.visual.openai_image")


@register(backend_name="openai_image", label="OpenAI")
class OpenAIImageMixin:

    openai_image_default_render_settings = RenderSettings()

    EXTEND_ACTIONS = {
        "openai_image": AgentAction(
            enabled=False,
            condition=AgentActionConditional(
                attribute="_config.config.backend", value="openai_image"
            ),
            label="OpenAI Image Generation Advanced Settings",
            description="Setting overrides for the openai backend",
            config={
                "model_type": AgentActionConfig(
                    type="text",
                    value="dall-e-3",
                    choices=[
                        {"value": "dall-e-3", "label": "DALL-E 3"},
                        {"value": "dall-e-2", "label": "DALL-E 2"},
                    ],
                    label="Model Type",
                    description="Image generation model",
                ),
                "quality": AgentActionConfig(
                    type="text",
                    value="standard",
                    choices=[
                        {"value": "standard", "label": "Standard"},
                        {"value": "hd", "label": "HD"},
                    ],
                    label="Quality",
                    description="Image generation quality",
                ),
            },
        )
    }

    @property
    def openai_api_key(self):
        return self.config.get("openai", {}).get("api_key")

    @property
    def openai_model_type(self):
        return self.actions["openai_image"].config["model_type"].value

    @property
    def openai_quality(self):
        return self.actions["openai_image"].config["quality"].value

    async def openai_image_generate(self, prompt: Style, format: str):
        """
        Reference example from the OpenAI docs:

        from openai import OpenAI
        client = OpenAI()

        response = client.images.generate(
            model="dall-e-3",
            prompt="a white siamese cat",
            size="1024x1024",
            quality="standard",
            n=1,
        )

        image_url = response.data[0].url
        """

        client = AsyncOpenAI(api_key=self.openai_api_key)

        # When using DALL·E 3, images can have a size of 1024x1024, 1024x1792 or 1792x1024 pixels.

        if format == "portrait":
            resolution = Resolution(width=1024, height=1792)
        elif format == "landscape":
            resolution = Resolution(width=1792, height=1024)
        else:
            resolution = Resolution(width=1024, height=1024)

        response = await client.images.generate(
            model=self.openai_model_type,
            prompt=prompt.positive_prompt,
            size=f"{resolution.width}x{resolution.height}",
            quality=self.openai_quality,
            n=1,
        )

        download_url = response.data[0].url

        async with httpx.AsyncClient() as client:
            response = await client.get(download_url, timeout=90)
            # bytes to base64-encoded string
            image = base64.b64encode(response.content).decode("utf-8")
            await self.emit_image(image)

    async def openai_image_ready(self) -> bool:
        """
        Returns True when an OpenAI API key is configured; raises otherwise.
        """

        if not self.openai_api_key:
            raise ValueError("OpenAI API Key not set")

        return True
32  src/talemate/agents/visual/schema.py  Normal file
@@ -0,0 +1,32 @@
import pydantic

__all__ = [
    "RenderSettings",
    "Resolution",
    "RESOLUTION_MAP",
]

RESOLUTION_MAP = {}


class RenderSettings(pydantic.BaseModel):
    type_model: str = "sdxl"
    steps: int = 40


class Resolution(pydantic.BaseModel):
    width: int
    height: int


RESOLUTION_MAP["sdxl"] = {
    "portrait": Resolution(width=832, height=1216),
    "landscape": Resolution(width=1216, height=832),
    "square": Resolution(width=1024, height=1024),
}

RESOLUTION_MAP["sd15"] = {
    "portrait": Resolution(width=512, height=768),
    "landscape": Resolution(width=768, height=512),
    "square": Resolution(width=768, height=768),
}
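A quick sketch of how a backend resolves the target size from a format keyword and model family, mirroring the `resolution_from_format(format, "sdxl" if is_sdxl else "sd15")` call in the ComfyUI mixin (that helper itself is defined elsewhere in the agent):

```python
from talemate.agents.visual.schema import RESOLUTION_MAP

res = RESOLUTION_MAP["sdxl"]["portrait"]
print(res.width, res.height)  # 832 1216
```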
112  src/talemate/agents/visual/style.py  Normal file
@@ -0,0 +1,112 @@
import pydantic

__all__ = [
    "Style",
    "STYLE_MAP",
    "THEME_MAP",
    "MAJOR_STYLES",
    "combine_styles",
]

STYLE_MAP = {}
THEME_MAP = {}
MAJOR_STYLES = {}


class Style(pydantic.BaseModel):
    keywords: list[str] = pydantic.Field(default_factory=list)
    negative_keywords: list[str] = pydantic.Field(default_factory=list)

    @property
    def positive_prompt(self):
        return ", ".join(self.keywords)

    @property
    def negative_prompt(self):
        return ", ".join(self.negative_keywords)

    def __str__(self):
        return f"POSITIVE: {self.positive_prompt}\nNEGATIVE: {self.negative_prompt}"

    def load(self, prompt: str, negative_prompt: str = ""):
        self.keywords = prompt.split(", ")
        self.negative_keywords = negative_prompt.split(", ")
        return self

    def prepend(self, *styles):
        for style in styles:
            for idx in range(len(style.keywords) - 1, -1, -1):
                kw = style.keywords[idx]
                if kw not in self.keywords:
                    self.keywords.insert(0, kw)

            for idx in range(len(style.negative_keywords) - 1, -1, -1):
                kw = style.negative_keywords[idx]
                if kw not in self.negative_keywords:
                    self.negative_keywords.insert(0, kw)

        return self

    def append(self, *styles):
        for style in styles:
            for kw in style.keywords:
                if kw not in self.keywords:
                    self.keywords.append(kw)

            for kw in style.negative_keywords:
                if kw not in self.negative_keywords:
                    self.negative_keywords.append(kw)

        return self

    def copy(self):
        return Style(
            keywords=self.keywords.copy(),
            negative_keywords=self.negative_keywords.copy(),
        )


# Almost taken straight from some of the fooocus style presets, credit goes to the original author

STYLE_MAP["digital_art"] = Style(
    keywords="digital artwork, masterpiece, best quality, high detail".split(", "),
    negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
)

STYLE_MAP["concept_art"] = Style(
    keywords="concept art, conceptual sketch, masterpiece, best quality, high detail".split(
        ", "
    ),
    negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
)

STYLE_MAP["ink_illustration"] = Style(
    keywords="ink illustration, painting, masterpiece, best quality".split(", "),
    negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
)

STYLE_MAP["anime"] = Style(
    keywords="anime, masterpiece, best quality, illustration".split(", "),
    negative_keywords="text, watermark, low quality, blurry, photo, 3d".split(", "),
)

STYLE_MAP["character_portrait"] = Style(keywords="solo, looking at viewer".split(", "))

STYLE_MAP["environment"] = Style(
    keywords="scenery, environment, background, postcard".split(", "),
    negative_keywords="character, portrait, looking at viewer, people".split(", "),
)

MAJOR_STYLES = [
    {"value": "digital_art", "label": "Digital Art"},
    {"value": "concept_art", "label": "Concept Art"},
    {"value": "ink_illustration", "label": "Ink Illustration"},
    {"value": "anime", "label": "Anime"},
]


def combine_styles(*styles):
    keywords = []
    for style in styles:
        keywords.extend(style.keywords)
    return Style(keywords=list(set(keywords)))
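`prepend` and `append` keep the keyword lists deduplicated while controlling priority order. A short sketch combining a major style with the character portrait preset:

```python
from talemate.agents.visual.style import STYLE_MAP

style = STYLE_MAP["character_portrait"].copy()
style.prepend(STYLE_MAP["anime"])

print(style.positive_prompt)
# anime, masterpiece, best quality, illustration, solo, looking at viewer
print(style.negative_prompt)
# text, watermark, low quality, blurry, photo, 3d
```

Because prepend walks the incoming keywords in reverse and inserts each at position 0, the prepended style ends up at the front in its original order, which is why the major style's keywords lead the final prompt.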
84  src/talemate/agents/visual/websocket_handler.py  Normal file
@@ -0,0 +1,84 @@
from typing import Union

import pydantic
import structlog

from talemate.instance import get_agent
from talemate.server.websocket_plugin import Plugin

from .context import VisualContext, VisualContextState

__all__ = [
    "VisualWebsocketHandler",
]

log = structlog.get_logger("talemate.server.visual")


class SetCoverImagePayload(pydantic.BaseModel):
    base64: str
    context: Union[VisualContextState, None] = None


class RegeneratePayload(pydantic.BaseModel):
    context: Union[VisualContextState, None] = None


class VisualWebsocketHandler(Plugin):
    router = "visual"

    async def handle_regenerate(self, data: dict):
        """
        Regenerate the image based on the context.
        """

        payload = RegeneratePayload(**data)

        context = payload.context

        visual = get_agent("visual")

        with VisualContext(**context.model_dump()):
            await visual.generate(format="")

    async def handle_cover_image(self, data: dict):
        """
        Sets the cover image for a character and the scene.
        """

        payload = SetCoverImagePayload(**data)

        context = payload.context
        scene = self.scene

        if context and context.character_name:

            character = scene.get_character(context.character_name)

            if not character:
                log.error("character not found", character_name=context.character_name)
                return

            asset = scene.assets.add_asset_from_image_data(payload.base64)

            log.info("setting scene cover image", character_name=context.character_name)
            scene.assets.cover_image = asset.id

            log.info(
                "setting character cover image", character_name=context.character_name
            )
            character.cover_image = asset.id

            scene.emit_status()
            self.websocket_handler.request_scene_assets([asset.id])

            self.websocket_handler.queue_put(
                {
                    "type": "scene_asset_character_cover_image",
                    "asset_id": asset.id,
                    "asset": self.scene.assets.get_asset_bytes_as_base64(asset.id),
                    "media_type": asset.media_type,
                    "character": character.name,
                }
            )
            return
@@ -187,7 +187,7 @@ class WorldStateAgent(Agent):

         await self.check_pin_conditions()

-    async def update_world_state(self):
+    async def update_world_state(self, force: bool = False):
         if not self.enabled:
             return

@@ -206,7 +206,7 @@ class WorldStateAgent(Agent):
             self.next_update % self.actions["update_world_state"].config["turns"].value
             != 0
             or self.next_update == 0
-        ):
+        ) and not force:
             self.next_update += 1
             return

@@ -349,11 +349,15 @@ class WorldStateAgent(Agent):
         self,
         text: str,
         instruction: str,
+        short: bool = False,
     ):

+        kind = "analyze_freeform_short" if short else "analyze_freeform"

         response = await Prompt.request(
             "world_state.analyze-text-and-follow-instruction",
             self.client,
-            "analyze_freeform",
+            kind,
             vars={
                 "scene": self.scene,
                 "max_tokens": self.client.max_token_length,

@@ -376,11 +380,13 @@ class WorldStateAgent(Agent):
         self,
         text: str,
         query: str,
+        short: bool = False,
     ):
+        kind = "analyze_freeform_short" if short else "analyze_freeform"
         response = await Prompt.request(
             "world_state.analyze-text-and-answer-question",
             self.client,
-            "analyze_freeform",
+            kind,
             vars={
                 "scene": self.scene,
                 "max_tokens": self.client.max_token_length,

@@ -439,6 +445,7 @@ class WorldStateAgent(Agent):
         self,
         name: str,
         text: str = None,
+        alteration_instructions: str = None,
     ):
         """
         Attempts to extract a character sheet from the given text.

@@ -453,6 +460,8 @@ class WorldStateAgent(Agent):
                 "max_tokens": self.client.max_token_length,
                 "text": text,
                 "name": name,
+                "character": self.scene.get_character(name),
+                "alteration_instructions": alteration_instructions or "",
             },
         )

@@ -518,23 +527,37 @@ class WorldStateAgent(Agent):
         if reset and reinforcement.insert == "sequential":
             self.scene.pop_history(typ="reinforcement", source=source, all=True)

+        if reinforcement.insert == "sequential":
+            kind = "analyze_freeform_medium_short"
+        else:
+            kind = "analyze_freeform"

         answer = await Prompt.request(
             "world_state.update-reinforcements",
             self.client,
-            "analyze_freeform",
+            kind,
             vars={
                 "scene": self.scene,
                 "max_tokens": self.client.max_token_length,
                 "question": reinforcement.question,
                 "instructions": reinforcement.instructions or "",
-                "character": self.scene.get_character(reinforcement.character)
-                if reinforcement.character
-                else None,
+                "character": (
+                    self.scene.get_character(reinforcement.character)
+                    if reinforcement.character
+                    else None
+                ),
                 "answer": (reinforcement.answer if not reset else None) or "",
+                "reinforcement": reinforcement,
             },
         )

+        # sequential reinforcement should be single sentence so we
+        # split on line breaks and take the first line in case the
+        # LLM did not understand the request and returned a longer response

+        if reinforcement.insert == "sequential":
+            answer = answer.split("\n")[0]

         reinforcement.answer = answer
         reinforcement.due = reinforcement.interval

@@ -735,3 +758,28 @@ class WorldStateAgent(Agent):
         )

         return is_leaving.lower().startswith("y")

+    @set_processing
+    async def manager(self, action_name: str, *args, **kwargs):
+        """
+        Executes a world state manager action through self.scene.world_state_manager
+        """
+
+        manager = self.scene.world_state_manager
+
+        try:
+            fn = getattr(manager, action_name, None)
+
+            if not fn:
+                raise ValueError(f"Unknown action: {action_name}")
+
+            return await fn(*args, **kwargs)
+        except Exception as e:
+            log.error(
+                "worldstate.manager",
+                action_name=action_name,
+                args=args,
+                kwargs=kwargs,
+                error=e,
+            )
+            raise
@@ -3,6 +3,8 @@ import os

 import talemate.client.runpod
 from talemate.client.lmstudio import LMStudioClient
 from talemate.client.openai import OpenAIClient
+from talemate.client.mistral import MistralAIClient
+from talemate.client.anthropic import AnthropicClient
 from talemate.client.openai_compat import OpenAICompatibleClient
 from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
 from talemate.client.textgenwebui import TextGeneratorWebuiClient
224  src/talemate/client/anthropic.py  Normal file
@@ -0,0 +1,224 @@
import pydantic
import structlog
from anthropic import AsyncAnthropic, PermissionDeniedError

from talemate.client.base import ClientBase, ErrorAction
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers

__all__ = [
    "AnthropicClient",
]
log = structlog.get_logger("talemate")

# Edit this to add new models / remove old models
SUPPORTED_MODELS = [
    "claude-3-sonnet-20240229",
    "claude-3-opus-20240229",
]


class Defaults(pydantic.BaseModel):
    max_token_length: int = 16384
    model: str = "claude-3-sonnet-20240229"


@register()
class AnthropicClient(ClientBase):
    """
    Anthropic client for generating text.
    """

    client_type = "anthropic"
    conversation_retries = 0
    auto_break_repetition_enabled = False
    # TODO: make this configurable?
    decensor_enabled = False

    class Meta(ClientBase.Meta):
        name_prefix: str = "Anthropic"
        title: str = "Anthropic"
        manual_model: bool = True
        manual_model_choices: list[str] = SUPPORTED_MODELS
        requires_prompt_template: bool = False
        defaults: Defaults = Defaults()

    def __init__(self, model="claude-3-sonnet-20240229", **kwargs):
        self.model_name = model
        self.api_key_status = None
        self.config = load_config()
        super().__init__(**kwargs)

        handlers["config_saved"].connect(self.on_config_saved)

    @property
    def anthropic_api_key(self):
        return self.config.get("anthropic", {}).get("api_key")

    def emit_status(self, processing: bool = None):
        error_action = None
        if processing is not None:
            self.processing = processing

        if self.anthropic_api_key:
            status = "busy" if self.processing else "idle"
            model_name = self.model_name
        else:
            status = "error"
            model_name = "No API key set"
            error_action = ErrorAction(
                title="Set API Key",
                action_name="openAppConfig",
                icon="mdi-key-variant",
                arguments=[
                    "application",
                    "anthropic_api",
                ],
            )

        if not self.model_name:
            status = "error"
            model_name = "No model loaded"

        self.current_status = status

        emit(
            "client_status",
            message=self.client_type,
            id=self.name,
            details=model_name,
            status=status,
            data={
                "error_action": error_action.model_dump() if error_action else None,
                "meta": self.Meta().model_dump(),
            },
        )

    def set_client(self, max_token_length: int = None):
        if not self.anthropic_api_key:
            self.client = AsyncAnthropic(api_key="sk-1111")
            log.error("No anthropic API key set")
            if self.api_key_status:
                self.api_key_status = False
                emit("request_client_status")
                emit("request_agent_status")
            return

        if not self.model_name:
            self.model_name = "claude-3-opus-20240229"

        if max_token_length and not isinstance(max_token_length, int):
            max_token_length = int(max_token_length)

        model = self.model_name

        self.client = AsyncAnthropic(api_key=self.anthropic_api_key)
        self.max_token_length = max_token_length or 16384

        if not self.api_key_status:
            if self.api_key_status is False:
                emit("request_client_status")
                emit("request_agent_status")
            self.api_key_status = True

        log.info(
            "anthropic set client",
            max_token_length=self.max_token_length,
            provided_max_token_length=max_token_length,
            model=model,
        )

    def reconfigure(self, **kwargs):
        if kwargs.get("model"):
            self.model_name = kwargs["model"]
            self.set_client(kwargs.get("max_token_length"))

    def on_config_saved(self, event):
        config = event.data
        self.config = config
        self.set_client(max_token_length=self.max_token_length)

    def response_tokens(self, response):
        return response.usage.output_tokens

    def prompt_tokens(self, response):
        return response.usage.input_tokens

    async def status(self):
        self.emit_status()

    def prompt_template(self, system_message: str, prompt: str):
        if "<|BOT|>" in prompt:
            _, right = prompt.split("<|BOT|>", 1)
            if right:
                prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
            else:
                prompt = prompt.replace("<|BOT|>", "")

        return prompt

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        super().tune_prompt_parameters(parameters, kind)
        keys = list(parameters.keys())
        valid_keys = ["temperature", "top_p", "max_tokens"]
        for key in keys:
            if key not in valid_keys:
                del parameters[key]

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
        """

        if not self.anthropic_api_key:
            raise Exception("No anthropic API key set")

        right = None
        expected_response = None
        try:
            _, right = prompt.split("\nStart your response with: ")
            expected_response = right.strip()
        except (IndexError, ValueError):
            pass

        human_message = {"role": "user", "content": prompt.strip()}
        system_message = self.get_system_message(kind)

        self.log.debug(
            "generate",
            prompt=prompt[:128] + " ...",
            parameters=parameters,
            system_message=system_message,
        )

        try:
            response = await self.client.messages.create(
                model=self.model_name,
                system=system_message,
                messages=[human_message],
                **parameters,
            )

            self._returned_prompt_tokens = self.prompt_tokens(response)
            self._returned_response_tokens = self.response_tokens(response)

            log.debug("generated response", response=response.content)

            response = response.content[0].text

            if expected_response and expected_response.startswith("{"):
                if response.startswith("```json") and response.endswith("```"):
                    response = response[7:-3].strip()

            if right and response.startswith(right):
                response = response[len(right) :].strip()

            return response
        except PermissionDeniedError as e:
            self.log.error("generate error", e=e)
            emit("status", message="anthropic API: Permission Denied", status="error")
            return ""
        except Exception as e:
            raise
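The `<|BOT|>` marker is how Talemate's prompt templates seed the start of the model's reply; for message-based APIs like this one it gets rewritten into a plain-language instruction, and `generate()` later strips the echoed prefix (and any ```json fencing) back off the response. A small illustration of that rewrite, using made-up prompt text:

```python
prompt = 'Describe the scene.<|BOT|>{"description":'

# with text after the marker it becomes an instruction to the model...
rewritten = prompt.replace("<|BOT|>", "\nStart your response with: ")
assert rewritten == 'Describe the scene.\nStart your response with: {"description":'

# ...and since the expected response starts with "{", generate() will
# also unwrap a ```json fenced reply before returning it.
```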
@@ -1,6 +1,7 @@
 """
 A unified client base, based on the openai API
 """

 import logging
 import random
 import time

@@ -32,6 +33,19 @@ REMOTE_SERVICES = [
 STOPPING_STRINGS = ["<|im_end|>", "</s>"]


+class PromptData(pydantic.BaseModel):
+    kind: str
+    prompt: str
+    response: str
+    prompt_tokens: int
+    response_tokens: int
+    client_name: str
+    client_type: str
+    time: Union[float, int]
+    agent_stack: list[str] = pydantic.Field(default_factory=list)
+    generation_parameters: dict = pydantic.Field(default_factory=dict)
+
+
 class ErrorAction(pydantic.BaseModel):
     title: str
     action_name: str

@@ -44,6 +58,14 @@ class Defaults(pydantic.BaseModel):
     max_token_length: int = 4096


+class ExtraField(pydantic.BaseModel):
+    name: str
+    type: str
+    label: str
+    required: bool
+    description: str
+
+
 class ClientBase:
     api_url: str
     model_name: str

@@ -77,7 +99,9 @@ class ClientBase:
         self.name = name or self.client_type
         self.log = structlog.get_logger(f"client.{self.client_type}")
         if "max_token_length" in kwargs:
-            self.max_token_length = kwargs["max_token_length"]
+            self.max_token_length = (
+                int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 4096
+            )
         self.set_client(max_token_length=self.max_token_length)

     def __str__(self):

@@ -121,7 +145,7 @@ class ClientBase:
             self.api_url = kwargs["api_url"]

         if kwargs.get("max_token_length"):
-            self.max_token_length = kwargs["max_token_length"]
+            self.max_token_length = int(kwargs["max_token_length"])

         if "enabled" in kwargs:
             self.enabled = bool(kwargs["enabled"])

@@ -154,7 +178,7 @@ class ClientBase:
         """

         if self.decensor_enabled:

             if "narrate" in kind:
                 return system_prompts.NARRATOR
             if "story" in kind:

@@ -179,9 +203,11 @@ class ClientBase:
                 return system_prompts.ANALYST
             if "summarize" in kind:
                 return system_prompts.SUMMARIZE
+            if "visualize" in kind:
+                return system_prompts.VISUALIZE

         else:

             if "narrate" in kind:
                 return system_prompts.NARRATOR_NO_DECENSOR
             if "story" in kind:

@@ -206,7 +232,9 @@ class ClientBase:
                 return system_prompts.ANALYST_NO_DECENSOR
             if "summarize" in kind:
                 return system_prompts.SUMMARIZE_NO_DECENSOR
+            if "visualize" in kind:
+                return system_prompts.VISUALIZE_NO_DECENSOR

         return system_prompts.BASIC

     def emit_status(self, processing: bool = None):

@@ -235,22 +263,27 @@ class ClientBase:

         prompt_template_example, prompt_template_file = self.prompt_template_example()

+        data = {
+            "api_key": self.api_key,
+            "prompt_template_example": prompt_template_example,
+            "has_prompt_template": (
+                prompt_template_file and prompt_template_file != "default.jinja2"
+            ),
+            "template_file": prompt_template_file,
+            "meta": self.Meta().model_dump(),
+            "error_action": None,
+        }
+
+        for field_name in getattr(self.Meta(), "extra_fields", {}).keys():
+            data[field_name] = getattr(self, field_name, None)
+
         emit(
             "client_status",
             message=self.client_type,
             id=self.name,
             details=model_name,
             status=status,
-            data={
-                "api_key": self.api_key,
-                "prompt_template_example": prompt_template_example,
-                "has_prompt_template": (
-                    prompt_template_file and prompt_template_file != "default.jinja2"
-                ),
-                "template_file": prompt_template_file,
-                "meta": self.Meta().model_dump(),
-                "error_action": None,
-            },
+            data=data,
         )

         if status_change:

@@ -330,6 +363,11 @@ class ClientBase:
             f"{character}:" for character in conversation_context["other_characters"]
         ]

+        dialog_stopping_strings += [
+            f"{character.upper()}\n"
+            for character in conversation_context["other_characters"]
+        ]
+
         if "extra_stopping_strings" in parameters:
             parameters["extra_stopping_strings"] += dialog_stopping_strings
         else:

@@ -372,6 +410,9 @@ class ClientBase:
         """

         try:
+            self._returned_prompt_tokens = None
+            self._returned_response_tokens = None
+
             self.emit_status(processing=True)
             await self.status()

@@ -411,21 +452,30 @@ class ClientBase:
                     response = response.split(stopping_string)[0]
                     break

+            agent_context = active_agent.get()
+
             emit(
                 "prompt_sent",
-                data={
-                    "kind": kind,
-                    "prompt": finalized_prompt,
-                    "response": response,
-                    "prompt_tokens": token_length,
-                    "response_tokens": self.count_tokens(response),
-                    "time": time_end - time_start,
-                },
+                data=PromptData(
+                    kind=kind,
+                    prompt=finalized_prompt,
+                    response=response,
+                    prompt_tokens=self._returned_prompt_tokens or token_length,
+                    response_tokens=self._returned_response_tokens
+                    or self.count_tokens(response),
+                    agent_stack=agent_context.agent_stack if agent_context else [],
+                    client_name=self.name,
+                    client_type=self.client_type,
+                    time=time_end - time_start,
+                    generation_parameters=prompt_param,
+                ).model_dump(),
             )

             return response
         finally:
             self.emit_status(processing=False)
+            self._returned_prompt_tokens = None
+            self._returned_response_tokens = None

     async def auto_break_repetition(
         self,
34  src/talemate/client/custom/__init__.py  Normal file
@@ -0,0 +1,34 @@
import importlib
import os

import structlog

log = structlog.get_logger("talemate.client.custom")

# import every submodule in this directory
#
# each directory in this directory is a submodule

# get the current directory
current_directory = os.path.dirname(__file__)

# get all subdirectories
subdirectories = [
    os.path.join(current_directory, name)
    for name in os.listdir(current_directory)
    if os.path.isdir(os.path.join(current_directory, name))
]

# import every submodule

for subdirectory in subdirectories:
    # get the name of the submodule
    submodule_name = os.path.basename(subdirectory)

    if submodule_name.startswith("__"):
        continue

    log.info("activating custom client", module=submodule_name)

    # import the submodule
    importlib.import_module(f".{submodule_name}", __package__)
@@ -0,0 +1,5 @@
Each client should be in its own subdirectory.

The subdirectory itself must be a valid python module.

Check out docs/dev/client/example/test for a very simplistic custom client example.
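Going by the loader above and the registration pattern the other clients in this diff use, a hypothetical minimal custom client could look like the sketch below (the directory and class names are illustrative, not from the repo; the real example lives at docs/dev/client/example/test):

```python
# src/talemate/client/custom/my_client/__init__.py  (hypothetical path)
from talemate.client.base import ClientBase
from talemate.client.registry import register


@register()
class MyCustomClient(ClientBase):
    # client_type keys the client in the registry
    client_type = "my_custom"

    class Meta(ClientBase.Meta):
        name_prefix: str = "MyCustom"
        title: str = "My Custom Client"
```

Because it is a package directory under src/talemate/client/custom, the import loop above picks it up and registers it automatically at startup.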
@@ -9,6 +9,7 @@ class Defaults(pydantic.BaseModel):
     api_url: str = "http://localhost:1234"
     max_token_length: int = 4096


 @register()
 class LMStudioClient(ClientBase):
     client_type = "lmstudio"
232  src/talemate/client/mistral.py  Normal file
@@ -0,0 +1,232 @@
import json

import pydantic
import structlog
import tiktoken
from openai import AsyncOpenAI, PermissionDeniedError

from talemate.client.base import ClientBase, ErrorAction
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers

__all__ = [
    "MistralAIClient",
]
log = structlog.get_logger("talemate")

# Edit this to add new models / remove old models
SUPPORTED_MODELS = [
    "open-mistral-7b",
    "open-mixtral-8x7b",
    "mistral-small-latest",
    "mistral-medium-latest",
    "mistral-large-latest",
]


class Defaults(pydantic.BaseModel):
    max_token_length: int = 16384
    model: str = "open-mixtral-8x7b"


@register()
class MistralAIClient(ClientBase):
    """
    Mistral.ai client for generating text.
    """

    client_type = "mistral"
    conversation_retries = 0
    auto_break_repetition_enabled = False
    # TODO: make this configurable?
    decensor_enabled = False

    class Meta(ClientBase.Meta):
        name_prefix: str = "MistralAI"
        title: str = "MistralAI"
        manual_model: bool = True
        manual_model_choices: list[str] = SUPPORTED_MODELS
        requires_prompt_template: bool = False
        defaults: Defaults = Defaults()

    def __init__(self, model="open-mixtral-8x7b", **kwargs):
        self.model_name = model
        self.api_key_status = None
        self.config = load_config()
        super().__init__(**kwargs)

        handlers["config_saved"].connect(self.on_config_saved)

    @property
    def mistralai_api_key(self):
        return self.config.get("mistralai", {}).get("api_key")

    def emit_status(self, processing: bool = None):
        error_action = None
        if processing is not None:
            self.processing = processing

        if self.mistralai_api_key:
            status = "busy" if self.processing else "idle"
            model_name = self.model_name
        else:
            status = "error"
            model_name = "No API key set"
            error_action = ErrorAction(
                title="Set API Key",
                action_name="openAppConfig",
                icon="mdi-key-variant",
                arguments=[
                    "application",
                    "mistralai_api",
                ],
            )

        if not self.model_name:
            status = "error"
            model_name = "No model loaded"

        self.current_status = status

        emit(
            "client_status",
            message=self.client_type,
            id=self.name,
            details=model_name,
            status=status,
            data={
                "error_action": error_action.model_dump() if error_action else None,
                "meta": self.Meta().model_dump(),
            },
        )

    def set_client(self, max_token_length: int = None):
        if not self.mistralai_api_key:
            self.client = AsyncOpenAI(api_key="sk-1111")
            log.error("No mistral.ai API key set")
            if self.api_key_status:
                self.api_key_status = False
                emit("request_client_status")
                emit("request_agent_status")
            return

        if not self.model_name:
            self.model_name = "open-mixtral-8x7b"

        if max_token_length and not isinstance(max_token_length, int):
            max_token_length = int(max_token_length)

        model = self.model_name

        self.client = AsyncOpenAI(
            api_key=self.mistralai_api_key, base_url="https://api.mistral.ai/v1/"
        )
        self.max_token_length = max_token_length or 16384

        if not self.api_key_status:
            if self.api_key_status is False:
                emit("request_client_status")
                emit("request_agent_status")
            self.api_key_status = True

        log.info(
            "mistral.ai set client",
            max_token_length=self.max_token_length,
            provided_max_token_length=max_token_length,
            model=model,
        )

    def reconfigure(self, **kwargs):
        if kwargs.get("model"):
            self.model_name = kwargs["model"]
            self.set_client(kwargs.get("max_token_length"))

    def on_config_saved(self, event):
        config = event.data
        self.config = config
        self.set_client(max_token_length=self.max_token_length)

    def response_tokens(self, response):
        return response.usage.completion_tokens

    def prompt_tokens(self, response):
        return response.usage.prompt_tokens

    async def status(self):
        self.emit_status()

    def prompt_template(self, system_message: str, prompt: str):
        if "<|BOT|>" in prompt:
            _, right = prompt.split("<|BOT|>", 1)
            if right:
                prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
            else:
                prompt = prompt.replace("<|BOT|>", "")

        return prompt

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        super().tune_prompt_parameters(parameters, kind)
        keys = list(parameters.keys())
        valid_keys = ["temperature", "top_p", "max_tokens"]
        for key in keys:
            if key not in valid_keys:
                del parameters[key]

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
        """

        if not self.mistralai_api_key:
            raise Exception("No mistral.ai API key set")

        right = None
        expected_response = None
        try:
            _, right = prompt.split("\nStart your response with: ")
            expected_response = right.strip()
        except (IndexError, ValueError):
            pass

        human_message = {"role": "user", "content": prompt.strip()}
        system_message = {"role": "system", "content": self.get_system_message(kind)}

        self.log.debug(
            "generate",
            prompt=prompt[:128] + " ...",
            parameters=parameters,
            system_message=system_message,
        )

        try:
            response = await self.client.chat.completions.create(
                model=self.model_name,
                messages=[system_message, human_message],
                **parameters,
            )

            self._returned_prompt_tokens = self.prompt_tokens(response)
            self._returned_response_tokens = self.response_tokens(response)

            response = response.choices[0].message.content

            # older models don't support json_object response coercion
            # and often like to return the response wrapped in ```json
            # so we strip that out if the expected response is a json object
            if expected_response and expected_response.startswith("{"):
                if response.startswith("```json") and response.endswith("```"):
                    response = response[7:-3].strip()

            if right and response.startswith(right):
                response = response[len(right) :].strip()

            return response
        except PermissionDeniedError as e:
            self.log.error("generate error", e=e)
            emit("status", message="mistral.ai API: Permission Denied", status="error")
            return ""
        except Exception as e:
            raise
@@ -38,7 +38,6 @@ log = structlog.get_logger("talemate.model_prompts")


 class ModelPrompt:
-
     """
     Will attempt to load an LLM prompt template based on the model name
@@ -16,6 +16,25 @@ __all__ = [
 ]
 log = structlog.get_logger("talemate")

+# Edit this to add new models / remove old models
+SUPPORTED_MODELS = [
+    "gpt-3.5-turbo-0613",
+    "gpt-3.5-turbo-0125",
+    "gpt-3.5-turbo-16k",
+    "gpt-3.5-turbo",
+    "gpt-4",
+    "gpt-4-1106-preview",
+    "gpt-4-0125-preview",
+    "gpt-4-turbo-preview",
+]
+
+JSON_OBJECT_RESPONSE_MODELS = [
+    "gpt-4-1106-preview",
+    "gpt-4-0125-preview",
+    "gpt-4-turbo-preview",
+    "gpt-3.5-turbo-0125",
+]


 def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0613"):
     """Return the number of tokens used by a list of messages."""

@@ -90,14 +109,7 @@ class OpenAIClient(ClientBase):
         name_prefix: str = "OpenAI"
         title: str = "OpenAI"
         manual_model: bool = True
-        manual_model_choices: list[str] = [
-            "gpt-3.5-turbo",
-            "gpt-3.5-turbo-16k",
-            "gpt-4",
-            "gpt-4-1106-preview",
-            "gpt-4-0125-preview",
-            "gpt-4-turbo-preview",
-        ]
+        manual_model_choices: list[str] = SUPPORTED_MODELS
         requires_prompt_template: bool = False
         defaults: Defaults = Defaults()

@@ -165,6 +177,9 @@ class OpenAIClient(ClientBase):
         if not self.model_name:
             self.model_name = "gpt-3.5-turbo-16k"

+        if max_token_length and not isinstance(max_token_length, int):
+            max_token_length = int(max_token_length)
+
         model = self.model_name

         self.client = AsyncOpenAI(api_key=self.openai_api_key)

@@ -216,7 +231,7 @@ class OpenAIClient(ClientBase):
         if "<|BOT|>" in prompt:
             _, right = prompt.split("<|BOT|>", 1)
             if right:
-                prompt = prompt.replace("<|BOT|>", "\nContinue this response: ")
+                prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
             else:
                 prompt = prompt.replace("<|BOT|>", "")

@@ -229,6 +244,15 @@ class OpenAIClient(ClientBase):

         valid_keys = ["temperature", "top_p"]

+        # GPT-3.5 models tend to run away with the generated
+        # response size so we allow talemate to set the max_tokens
+        #
+        # GPT-4 on the other hand seems to benefit from letting it
+        # decide the generation length naturally and it will generally
+        # produce reasonably sized responses
+        if self.model_name.startswith("gpt-3.5-"):
+            valid_keys.append("max_tokens")
+
         for key in keys:
             if key not in valid_keys:
                 del parameters[key]

@@ -242,10 +266,14 @@ class OpenAIClient(ClientBase):
             raise Exception("No OpenAI API key set")

         # only gpt-4-* supports enforcing json object
-        supports_json_object = self.model_name.startswith("gpt-4-")
+        supports_json_object = (
+            self.model_name.startswith("gpt-4-")
+            or self.model_name in JSON_OBJECT_RESPONSE_MODELS
+        )
         right = None
         expected_response = None
         try:
-            _, right = prompt.split("\nContinue this response: ")
+            _, right = prompt.split("\nStart your response with: ")
             expected_response = right.strip()
             if expected_response.startswith("{") and supports_json_object:
                 parameters["response_format"] = {"type": "json_object"}

@@ -255,8 +283,13 @@ class OpenAIClient(ClientBase):
         human_message = {"role": "user", "content": prompt.strip()}
         system_message = {"role": "system", "content": self.get_system_message(kind)}

-        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters, system_message=system_message)
-
+        self.log.debug(
+            "generate",
+            prompt=prompt[:128] + " ...",
+            parameters=parameters,
+            system_message=system_message,
+        )

         try:
             response = await self.client.chat.completions.create(
                 model=self.model_name,

@@ -266,6 +299,17 @@ class OpenAIClient(ClientBase):

             response = response.choices[0].message.content

+            # older models don't support json_object response coercion
+            # and often like to return the response wrapped in ```json
+            # so we strip that out if the expected response is a json object
+            if (
+                not supports_json_object
+                and expected_response
+                and expected_response.startswith("{")
+            ):
+                if response.startswith("```json") and response.endswith("```"):
+                    response = response[7:-3].strip()
+
             if right and response.startswith(right):
                 response = response[len(right) :].strip()
@@ -1,9 +1,14 @@
 import pydantic
 import structlog
+import urllib
 from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError

-from talemate.client.base import ClientBase
+from talemate.client.base import ClientBase, ExtraField
 from talemate.client.registry import register
 from talemate.emit import emit
+from talemate.config import Client as BaseClientConfig

 log = structlog.get_logger("talemate.client.openai_compat")

+EXPERIMENTAL_DESCRIPTION = """Use this client if you want to connect to a service implementing an OpenAI-compatible API. Success is going to depend on the level of compatibility. Use the actual OpenAI client if you want to connect to OpenAI's API."""

@@ -13,12 +18,18 @@ class Defaults(pydantic.BaseModel):
     api_key: str = ""
     max_token_length: int = 4096
     model: str = ""
+    api_handles_prompt_template: bool = False


+class ClientConfig(BaseClientConfig):
+    api_handles_prompt_template: bool = False


 @register()
 class OpenAICompatibleClient(ClientBase):
     client_type = "openai_compat"
     conversation_retries = 5
+    config_cls = ClientConfig

     class Meta(ClientBase.Meta):
         title: str = "OpenAI Compatible API"

@@ -27,9 +38,22 @@ class OpenAICompatibleClient(ClientBase):
         enable_api_auth: bool = True
         manual_model: bool = True
         defaults: Defaults = Defaults()
+        extra_fields: dict[str, ExtraField] = {
+            "api_handles_prompt_template": ExtraField(
+                name="api_handles_prompt_template",
+                type="bool",
+                label="API Handles Prompt Template",
+                required=False,
+                description="The API handles the prompt template, meaning your choice in the UI for the prompt template below will be ignored.",
+            )
+        }

-    def __init__(self, model=None, **kwargs):
+    def __init__(
+        self, model=None, api_key=None, api_handles_prompt_template=False, **kwargs
+    ):
         self.model_name = model
+        self.api_key = api_key
+        self.api_handles_prompt_template = api_handles_prompt_template
         super().__init__(**kwargs)

     @property

@@ -37,8 +61,12 @@ class OpenAICompatibleClient(ClientBase):
         return EXPERIMENTAL_DESCRIPTION

     def set_client(self, **kwargs):
-        self.api_key = kwargs.get("api_key")
-        self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key=self.api_key)
+        self.api_key = kwargs.get("api_key", self.api_key)
+        self.api_handles_prompt_template = kwargs.get(
+            "api_handles_prompt_template", self.api_handles_prompt_template
+        )
+        url = self.api_url
+        self.client = AsyncOpenAI(base_url=url, api_key=self.api_key)
         self.model_name = (
             kwargs.get("model") or kwargs.get("model_name") or self.model_name
         )

@@ -48,32 +76,33 @@ class OpenAICompatibleClient(ClientBase):

         keys = list(parameters.keys())

-        valid_keys = ["temperature", "top_p"]
+        valid_keys = ["temperature", "top_p", "max_tokens"]

         for key in keys:
             if key not in valid_keys:
                 del parameters[key]

     def prompt_template(self, system_message: str, prompt: str):

+        log.debug(
+            "IS API HANDLING PROMPT TEMPLATE",
+            api_handles_prompt_template=self.api_handles_prompt_template,
+        )
+
+        if not self.api_handles_prompt_template:
+            return super().prompt_template(system_message, prompt)
+
         if "<|BOT|>" in prompt:
             _, right = prompt.split("<|BOT|>", 1)
             if right:
                 prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
             else:
                 prompt = prompt.replace("<|BOT|>", "")

         return prompt

     async def get_model_name(self):
         try:
             model_name = await super().get_model_name()
         except NotFoundError as e:
             # api does not implement model listing
             return self.model_name
         except Exception as e:
             self.log.error("get_model_name error", e=e)
             return self.model_name

-        # model name may be a file path, so we need to extract the model name
-        # the path could be windows or linux so it needs to handle both backslash and forward slash
-
-        is_filepath = "/" in model_name
-        is_filepath_windows = "\\" in model_name
-
-        if is_filepath or is_filepath_windows:
-            model_name = model_name.replace("\\", "/").split("/")[-1]
-
-        return model_name
+        return self.model_name

     async def generate(self, prompt: str, parameters: dict, kind: str):
         """

@@ -106,8 +135,14 @@ class OpenAICompatibleClient(ClientBase):
         if "api_url" in kwargs:
             self.api_url = kwargs["api_url"]
         if "max_token_length" in kwargs:
-            self.max_token_length = kwargs["max_token_length"]
+            self.max_token_length = (
+                int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 4096
+            )
+        if "api_key" in kwargs:
+            self.api_auth = kwargs["api_key"]
+        if "api_handles_prompt_template" in kwargs:
+            self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]

         log.warning("reconfigure", kwargs=kwargs)

         self.set_client(**kwargs)
@@ -121,62 +121,62 @@ def preset_for_kind(kind: str):
|
||||
return PRESET_DIVINE_INTELLECT # Assuming adding detail uses the same preset as divine intellect
|
||||
elif kind == "edit_fix_exposition":
|
||||
return PRESET_DIVINE_INTELLECT # Assuming fixing exposition uses the same preset as divine intellect
|
||||
elif kind == "visualize":
|
||||
return PRESET_SIMPLE_1
|
||||
else:
|
||||
return PRESET_SIMPLE_1 # Default preset if none of the kinds match
|
||||
|
||||
|
||||
def max_tokens_for_kind(kind: str, total_budget: int):
|
||||
if kind == "conversation":
|
||||
return 75 # Example value, adjust as needed
|
||||
return 75
|
||||
elif kind == "conversation_old":
|
||||
return 75 # Example value, adjust as needed
|
||||
return 75
|
||||
elif kind == "conversation_long":
|
||||
return 300 # Example value, adjust as needed
|
||||
return 300
|
||||
elif kind == "conversation_select_talking_actor":
|
||||
return 30 # Example value, adjust as needed
|
||||
return 30
|
||||
elif kind == "summarize":
|
||||
return 500 # Example value, adjust as needed
|
||||
return 500
|
||||
elif kind == "analyze":
|
||||
return 500 # Example value, adjust as needed
|
||||
return 500
|
||||
elif kind == "analyze_creative":
|
||||
return 1024 # Example value, adjust as needed
|
||||
return 1024
|
||||
elif kind == "analyze_long":
|
||||
return 2048 # Example value, adjust as needed
|
||||
return 2048
|
||||
elif kind == "analyze_freeform":
|
||||
return 500 # Example value, adjust as needed
|
||||
return 500
|
||||
elif kind == "analyze_freeform_medium":
|
||||
return 192
|
||||
elif kind == "analyze_freeform_medium_short":
|
||||
return 128
|
||||
elif kind == "analyze_freeform_short":
|
||||
return 10 # Example value, adjust as needed
|
||||
return 10
|
||||
elif kind == "narrate":
|
||||
return 500 # Example value, adjust as needed
|
||||
return 500
|
||||
elif kind == "story":
|
||||
return 300 # Example value, adjust as needed
|
||||
return 300
|
||||
elif kind == "create":
|
||||
return min(
|
||||
1024, int(total_budget * 0.35)
|
||||
) # Example calculation, adjust as needed
|
||||
return min(1024, int(total_budget * 0.35))
|
||||
elif kind == "create_concise":
|
||||
return min(
|
||||
400, int(total_budget * 0.25)
|
||||
) # Example calculation, adjust as needed
|
||||
return min(400, int(total_budget * 0.25))
|
||||
elif kind == "create_precise":
|
||||
return min(
|
||||
400, int(total_budget * 0.25)
|
||||
) # Example calculation, adjust as needed
|
||||
return min(400, int(total_budget * 0.25))
|
||||
elif kind == "create_short":
|
||||
return 25
|
||||
elif kind == "director":
|
||||
return min(
|
||||
192, int(total_budget * 0.25)
|
||||
) # Example calculation, adjust as needed
|
||||
return min(192, int(total_budget * 0.25))
|
||||
elif kind == "director_short":
|
||||
return 25 # Example value, adjust as needed
|
||||
return 25
|
||||
elif kind == "director_yesno":
|
||||
return 2 # Example value, adjust as needed
|
||||
return 2
|
||||
elif kind == "edit_dialogue":
|
||||
return 100 # Example value, adjust as needed
|
||||
return 100
|
||||
elif kind == "edit_add_detail":
|
||||
return 200 # Example value, adjust as needed
|
||||
return 200
|
||||
elif kind == "edit_fix_exposition":
|
||||
return 1024 # Example value, adjust as needed
|
||||
return 1024
|
||||
elif kind == "visualize":
|
||||
return 150
|
||||
else:
|
||||
return 150 # Default value if none of the kinds match
|
||||
|
||||
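To make the proportional branches concrete, here is how they resolve under a hypothetical 4096-token budget (the values mirror the code above):

```python
total_budget = 4096
# "create": min(1024, int(total_budget * 0.35)) -> the 1024 cap wins
assert min(1024, int(total_budget * 0.35)) == 1024
# "director": min(192, int(total_budget * 0.25)) -> the 192 cap wins
assert min(192, int(total_budget * 0.25)) == 192
# with a smaller 2048-token budget, the ratio wins for "create"
assert min(1024, int(2048 * 0.35)) == 716
```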
@@ -20,6 +20,8 @@ WORLD_STATE = str(Prompt.get("world_state.system-analyst"))

SUMMARIZE = str(Prompt.get("summarizer.system"))

VISUALIZE = str(Prompt.get("visual.system"))

# CAREBEAR PROMPTS

ROLEPLAY_NO_DECENSOR = str(Prompt.get("conversation.system-no-decensor"))

@@ -32,10 +34,14 @@ DIRECTOR_NO_DECENSOR = str(Prompt.get("director.system-no-decensor"))

ANALYST_NO_DECENSOR = str(Prompt.get("world_state.system-analyst-no-decensor"))

ANALYST_FREEFORM_NO_DECENSOR = str(Prompt.get("world_state.system-analyst-freeform-no-decensor"))
ANALYST_FREEFORM_NO_DECENSOR = str(
    Prompt.get("world_state.system-analyst-freeform-no-decensor")
)

EDITOR_NO_DECENSOR = str(Prompt.get("editor.system-no-decensor"))

WORLD_STATE_NO_DECENSOR = str(Prompt.get("world_state.system-analyst-no-decensor"))

SUMMARIZE_NO_DECENSOR = str(Prompt.get("summarizer.system-no-decensor"))

VISUALIZE_NO_DECENSOR = str(Prompt.get("visual.system-no-decensor"))

@@ -1,4 +1,5 @@
import random
import re

import httpx
import structlog

@@ -28,20 +29,23 @@ class TextGeneratorWebuiClient(ClientBase):
        parameters["stop"] = parameters["stopping_strings"]

        # Half temperature on -Yi- models
        if (
            self.model_name
            and "-yi-" in self.model_name.lower()
            and parameters["temperature"] > 0.1
        ):
            parameters["temperature"] = parameters["temperature"] / 2
        if self.model_name and self.is_yi_model():
            parameters["smoothing_factor"] = 0.3
            # also halve the temperature
            parameters["temperature"] = max(0.1, parameters["temperature"] / 2)
            log.debug(
                "halving temperature for -yi- model",
                temperature=parameters["temperature"],
                "applying temperature smoothing for Yi model",
            )

    def set_client(self, **kwargs):
        self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")

    def is_yi_model(self):
        model_name = self.model_name.lower()
        # regex match for yi encased by non-word characters

        return bool(re.search(r"[\-_]yi[\-_]", model_name))

    async def get_model_name(self):
        async with httpx.AsyncClient() as client:
            response = await client.get(
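The `is_yi_model` regex above only treats "yi" as a match when it is fenced by `-` or `_`; a small check against hypothetical model names (note that a bare leading "yi" would not match):

```python
import re

pattern = r"[\-_]yi[\-_]"  # same pattern as is_yi_model above
assert re.search(pattern, "nous-capybara_yi_34b")
assert re.search(pattern, "dolphin-yi-6b")
assert not re.search(pattern, "mythomax-l2-13b")  # no "yi" substring at all
assert not re.search(pattern, "yi-34b-chat")      # leading "yi" is not fenced
```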
@@ -79,7 +79,7 @@ class CmdDeactivateCharacter(TalemateCommand):
        )
        message = await narrator.action_to_narration(
            "narrate_character_exit",
            self.scene.get_character(character_name),
            character=self.scene.get_character(character_name),
            direction=direction,
        )
        self.narrator_message(message)

@@ -159,7 +159,7 @@ class CmdActivateCharacter(TalemateCommand):
        )
        message = await narrator.action_to_narration(
            "narrate_character_entry",
            self.scene.get_character(character_name),
            character=self.scene.get_character(character_name),
            direction=direction,
        )
        self.narrator_message(message)

@@ -1,6 +1,9 @@
import asyncio
import json
import logging

import structlog

from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.prompts.base import set_default_sectioning_handler

@@ -12,6 +15,8 @@ __all__ = [
    "CmdRunAutomatic",
]

log = structlog.get_logger("talemate.commands.cmd_debug_tools")


@register
class CmdDebugOn(TalemateCommand):

@@ -144,3 +149,32 @@ class CmdSetContentContext(TalemateCommand):
        self.scene.context = context

        self.emit("system", f"Content context set to {context}")


@register
class CmdDumpHistory(TalemateCommand):
    """
    Command class for the 'dump_history' command
    """

    name = "dump_history"
    description = "Dump the history of the scene"
    aliases = []

    async def run(self):
        for entry in self.scene.history:
            log.debug("dump_history", entry=entry)


@register
class CmdDumpSceneSerialization(TalemateCommand):
    """
    Command class for the 'dump_scene_serialization' command
    """

    name = "dump_scene_serialization"
    description = "Dump the scene serialization"
    aliases = []

    async def run(self):
        log.debug("dump_scene_serialization", serialization=self.scene.json)

@@ -36,7 +36,11 @@ class CmdAIDialogue(TalemateCommand):
        if conversation_agent.actions["natural_flow"].enabled:
            await conversation_agent.apply_natural_flow(force=True, npcs_only=True)
            character_name = self.scene.next_actor
        actor = self.scene.get_character(character_name).actor
        try:
            actor = self.scene.get_character(character_name).actor
        except AttributeError:
            return

        if actor.character.is_player:
            actor = random.choice(list(self.scene.get_npc_characters())).actor
        else:

@@ -26,6 +26,12 @@ class CmdRebuildArchive(TalemateCommand):
            ah for ah in self.scene.archived_history if ah.get("end") is None
        ]

        self.scene.ts = (
            self.scene.archived_history[-1].ts
            if self.scene.archived_history
            else "PT0S"
        )

        while True:
            more = await summarizer.agent.build_archive(self.scene)
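For context, the `"PT0S"` fallback above is an ISO-8601 duration string meaning zero elapsed time. A stdlib-only sketch of reading the seconds component, assuming only the simple `PT<n>S` form (real scene timestamps may be richer):

```python
def parse_simple_duration(ts: str) -> float:
    # Handles only "PT<n>S" style durations, e.g. "PT0S" or "PT90S".
    assert ts.startswith("PT") and ts.endswith("S")
    return float(ts[2:-1] or 0)

assert parse_simple_duration("PT0S") == 0.0
assert parse_simple_duration("PT90S") == 90.0
```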
@@ -50,7 +50,6 @@ class CmdWorldState(TalemateCommand):

@register
class CmdPersistCharacter(TalemateCommand):

    """
    Will attempt to create an actual character from a currently non
    tracked character in the scene, by name.

@@ -177,7 +176,6 @@ class CmdPersistCharacter(TalemateCommand):

@register
class CmdAddReinforcement(TalemateCommand):

    """
    Will attempt to create an actual character from a currently non
    tracked character in the scene, by name.

@@ -204,7 +202,6 @@ class CmdAddReinforcement(TalemateCommand):

@register
class CmdRemoveReinforcement(TalemateCommand):

    """
    Will attempt to create an actual character from a currently non
    tracked character in the scene, by name.

@@ -236,7 +233,6 @@ class CmdRemoveReinforcement(TalemateCommand):

@register
class CmdUpdateReinforcements(TalemateCommand):

    """
    Will attempt to create an actual character from a currently non
    tracked character in the scene, by name.

@@ -258,7 +254,6 @@ class CmdUpdateReinforcements(TalemateCommand):

@register
class CmdCheckPinConditions(TalemateCommand):

    """
    Will attempt to create an actual character from a currently non
    tracked character in the scene, by name.

@@ -277,7 +272,6 @@ class CmdCheckPinConditions(TalemateCommand):

@register
class CmdApplyWorldStateTemplate(TalemateCommand):

    """
    Will apply a world state template setting up
    automatic state tracking.

@@ -337,7 +331,6 @@ class CmdApplyWorldStateTemplate(TalemateCommand):

@register
class CmdSummarizeAndPin(TalemateCommand):

    """
    Will take a message index and then walk back N messages
    summarizing the scene and pinning it to the context.

@@ -1,12 +1,16 @@
import datetime
import os
from typing import TYPE_CHECKING, ClassVar, Dict, Optional, Union
import copy
from typing import TYPE_CHECKING, ClassVar, Dict, Optional, TypeVar, Union, Any
from typing_extensions import Annotated

import pydantic
import structlog
import yaml
from pydantic import BaseModel, Field

from talemate.agents.registry import get_agent_class
from talemate.client.registry import get_client_class
from talemate.emit import emit
from talemate.scene_assets import Asset

@@ -16,6 +20,16 @@ if TYPE_CHECKING:
log = structlog.get_logger("talemate.config")


def scenes_dir():
    relative_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "..",
        "..",
        "scenes",
    )
    return os.path.abspath(relative_path)


class Client(BaseModel):
    type: str
    name: str

@@ -28,6 +42,9 @@ class Client(BaseModel):
        extra = "ignore"


ClientType = TypeVar("ClientType", bound=Client)


class AgentActionConfig(BaseModel):
    value: Union[int, float, str, bool, None] = None

@@ -66,6 +83,7 @@ class GamePlayerCharacter(BaseModel):
class General(BaseModel):
    auto_save: bool = True
    auto_progress: bool = True
    max_backscroll: int = 512


class StateReinforcementTemplate(BaseModel):

@@ -87,6 +105,9 @@ class WorldStateTemplates(BaseModel):
        default_factory=dict
    )

    def get_template(self, name: str) -> Union[StateReinforcementTemplate, None]:
        return self.state_reinforcement.get(name)


class WorldState(BaseModel):
    templates: WorldStateTemplates = WorldStateTemplates()

@@ -111,6 +132,14 @@ class OpenAIConfig(BaseModel):
    api_key: Union[str, None] = None


class MistralAIConfig(BaseModel):
    api_key: Union[str, None] = None


class AnthropicConfig(BaseModel):
    api_key: Union[str, None] = None


class RunPodConfig(BaseModel):
    api_key: Union[str, None] = None

@@ -138,6 +167,7 @@ class TTSConfig(BaseModel):
class ChromaDB(BaseModel):
    instructor_device: str = "cpu"
    instructor_model: str = "default"
    openai_model: str = "text-embedding-3-small"
    embeddings: str = "default"


@@ -149,8 +179,56 @@ class RecentScene(BaseModel):
    cover_image: Union[Asset, None] = None


def gnerate_intro_scenes():
    """
    When there are no recent scenes, generate from a set of introductory scenes
    """

    scenes = [
        RecentScene(
            name="Simulation Suite",
            path=os.path.join(
                scenes_dir(), "simulation-suite", "simulation-suite.json"
            ),
            filename="simulation-suite.json",
            date=datetime.datetime.now().isoformat(),
            cover_image=Asset(
                id="4b157dccac2ba71adb078a9d591f9900d6d62f3e86168a5e0e5e1e9faf6dc103",
                file_type="png",
                media_type="image/png",
            ),
        ),
        RecentScene(
            name="Infinity Quest",
            path=os.path.join(scenes_dir(), "infinity-quest", "infinity-quest.json"),
            filename="infinity-quest.json",
            date=datetime.datetime.now().isoformat(),
            cover_image=Asset(
                id="52b1388ed6f77a43981bd27e05df54f16e12ba8de1c48f4b9bbcb138fa7367df",
                file_type="png",
                media_type="image/png",
            ),
        ),
        RecentScene(
            name="Infinity Quest Dynamic Scenario",
            path=os.path.join(
                scenes_dir(), "infinity-quest-dynamic-scenario", "infinity-quest.json"
            ),
            filename="infinity-quest.json",
            date=datetime.datetime.now().isoformat(),
            cover_image=Asset(
                id="e7c712a0b276342d5767ba23806b03912d10c7c4b82dd1eec0056611e2cd5404",
                file_type="png",
                media_type="image/png",
            ),
        ),
    ]

    return scenes


class RecentScenes(BaseModel):
    scenes: list[RecentScene] = pydantic.Field(default_factory=list)
    scenes: list[RecentScene] = pydantic.Field(default_factory=gnerate_intro_scenes)
    max_entries: int = 10

    def push(self, scene: "Scene"):

@@ -175,9 +253,11 @@ class RecentScenes(BaseModel):
                path=scene.full_path,
                filename=scene.filename,
                date=now.isoformat(),
                cover_image=scene.assets.assets[scene.assets.cover_image]
                if scene.assets.cover_image
                else None,
                cover_image=(
                    scene.assets.assets[scene.assets.cover_image]
                    if scene.assets.cover_image
                    else None
                ),
            ),
        )

@@ -192,8 +272,44 @@ class RecentScenes(BaseModel):
        self.scenes = [s for s in self.scenes if os.path.exists(s.path)]


def validate_client_type(
    v: Any,
    handler: pydantic.ValidatorFunctionWrapHandler,
    info: pydantic.ValidationInfo,
):
    # clients can specify a custom config model in
    # client_cls.config_cls so we need to convert the
    # client config to the correct model

    # v is dict
    if isinstance(v, dict):
        client_cls = get_client_class(v.get("type"))
        if client_cls:
            config_cls = getattr(client_cls, "config_cls", None)
            if config_cls:
                return config_cls(**v)
            else:
                return handler(v)
    # v is Client instance
    elif isinstance(v, Client):
        client_cls = get_client_class(v.type)
        if client_cls:
            config_cls = getattr(client_cls, "config_cls", None)
            if config_cls:
                return config_cls(**v.model_dump())
            else:
                return handler(v)


AnnotatedClient = Annotated[
    ClientType,
    pydantic.WrapValidator(validate_client_type),
]
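A self-contained sketch of what the wrap validator above accomplishes; the registry and config class here are stand-ins, not Talemate's real ones. A raw client dict is coerced into the client class's own `config_cls` when one is registered, and falls back to plain `Client` validation otherwise:

```python
import pydantic

class Client(pydantic.BaseModel):
    type: str
    name: str

class OpenAICompatConfig(Client):  # hypothetical client-specific config model
    api_handles_prompt_template: bool = False

REGISTRY = {"openai_compat": OpenAICompatConfig}  # stand-in for get_client_class

def coerce(v: dict) -> Client:
    # mirrors the dict branch of validate_client_type
    config_cls = REGISTRY.get(v.get("type"))
    return config_cls(**v) if config_cls else Client(**v)

c = coerce({"type": "openai_compat", "name": "local", "api_handles_prompt_template": True})
assert isinstance(c, OpenAICompatConfig) and c.api_handles_prompt_template
```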
class Config(BaseModel):
    clients: Dict[str, Client] = {}
    clients: Dict[str, AnnotatedClient] = {}

    game: Game

    agents: Dict[str, Agent] = {}
@@ -202,6 +318,10 @@ class Config(BaseModel):

    openai: OpenAIConfig = OpenAIConfig()

    mistralai: MistralAIConfig = MistralAIConfig()

    anthropic: AnthropicConfig = AnthropicConfig()

    runpod: RunPodConfig = RunPodConfig()

    chromadb: ChromaDB = ChromaDB()

@@ -240,7 +360,6 @@ def load_config(
    Should cache the config and only reload if the file modification time
    has changed since the last load
    """

    with open(file_path, "r") as file:
        config_data = yaml.safe_load(file)

@@ -279,3 +398,44 @@ def save_config(config, file_path: str = "./config.yaml"):
        yaml.dump(config, file)

    emit("config_saved", data=config)


def cleanup():

    log.info("cleaning up config")

    config = load_config(as_model=True)

    cleanup_removed_clients(config)
    cleanup_removed_agents(config)

    save_config(config)


def cleanup_removed_clients(config: Config):
    """
    Will remove any clients that are no longer present
    """

    if not config:
        return

    for client_in_config in list(config.clients.keys()):
        client_config = config.clients[client_in_config]
        if not get_client_class(client_config.type):
            log.info("removing client from config", client=client_in_config)
            del config.clients[client_in_config]


def cleanup_removed_agents(config: Config):
    """
    Will remove any agents that are no longer present
    """

    if not config:
        return

    for agent_in_config in list(config.agents.keys()):
        if not get_agent_class(agent_in_config):
            log.info("removing agent from config", agent=agent_in_config)
            del config.agents[agent_in_config]
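One detail worth noting in both cleanup helpers above: iterating over `list(config.clients.keys())` snapshots the keys first, so entries can be deleted mid-loop without a `RuntimeError`:

```python
clients = {"kept": 1, "removed": 2}  # hypothetical config.clients stand-in
for name in list(clients.keys()):    # snapshot; safe to mutate below
    if name == "removed":
        del clients[name]
assert clients == {"kept": 1}
```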
@@ -38,6 +38,8 @@ class Emission:
    id: str = None
    details: str = None
    data: dict = None
    websocket_passthrough: bool = False
    meta: dict = dataclasses.field(default_factory=dict)


def emit(

@@ -125,8 +127,9 @@ class Receiver:
    def handle(self, emission: Emission):
        fn = getattr(self, f"handle_{emission.typ}", None)
        if not fn:
            return
            return False
        fn(emission)
        return True

    def connect(self):
        for typ in handlers:
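The `handle` method above resolves a handler by emission type and, after this change, reports whether one actually ran. A minimal sketch of that dispatch pattern (the receiver and emission types here are hypothetical):

```python
class DemoReceiver:
    def handle_system(self, emission):
        print("system:", emission)

    def handle(self, typ, emission):
        fn = getattr(self, f"handle_{typ}", None)  # same lookup as above
        if not fn:
            return False
        fn(emission)
        return True

r = DemoReceiver()
assert r.handle("system", "hello") is True     # a handler exists and ran
assert r.handle("unknown", "hello") is False   # explicit False instead of None
```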
@@ -34,6 +34,8 @@ MessageEdited = signal("message_edited")

ConfigSaved = signal("config_saved")

ImageGenerated = signal("image_generated")

handlers = {
    "system": SystemMessage,
    "narrator": NarratorMessage,

@@ -60,4 +62,5 @@ handlers = {
    "audio_queue": AudioQueue,
    "config_saved": ConfigSaved,
    "status": StatusMessage,
    "image_generated": ImageGenerated,
}

@@ -29,6 +29,7 @@ class Instructions(pydantic.BaseModel):

class Ops(pydantic.BaseModel):
    run_on_start: bool = False
    always_direct: bool = False


class GameState(pydantic.BaseModel):

@@ -95,8 +96,8 @@ class GameState(pydantic.BaseModel):
    def has_var(self, key: str) -> bool:
        return key in self.variables

    def get_var(self, key: str) -> Any:
        return self.variables[key]
    def get_var(self, key: str, default: Any = None) -> Any:
        return self.variables.get(key, default)

    def get_or_set_var(self, key: str, value: Any, commit: bool = False) -> Any:
        if not self.has_var(key):
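Behaviorally, the `get_var` change above swaps a raising lookup for `dict.get` semantics:

```python
variables = {"chapter": 2}  # hypothetical game-state variables
# old get_var: variables["act"] would raise KeyError
# new get_var: a default is returned instead
assert variables.get("chapter", 0) == 2
assert variables.get("act", 1) == 1
```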
@@ -1,6 +1,7 @@
"""
Keep track of clients and agents
"""

import asyncio

import structlog

@@ -162,14 +163,9 @@ def emit_agent_status(cls, agent=None):
            data=cls.config_options(),
        )
    else:
        emit(
            "agent_status",
            message=agent.verbose_name or "",
            status=agent.status,
            id=agent.agent_type,
            details=agent.agent_details,
            data=cls.config_options(agent=agent),
        )
        asyncio.create_task(agent.emit_status())
        # loop = asyncio.get_event_loop()
        # loop.run_until_complete(agent.emit_status())


def emit_agents_status(*args, **kwargs):

@@ -177,9 +173,17 @@ def emit_agents_status(*args, **kwargs):
    Will emit status of all agents
    """
    # log.debug("emit", type="agent status")
    for typ, cls in agents.AGENT_CLASSES.items():
    for typ, cls in sorted(
        agents.AGENT_CLASSES.items(), key=lambda x: x[1].verbose_name
    ):
        agent = AGENTS.get(typ)
        emit_agent_status(cls, agent)


handlers["request_agent_status"].connect(emit_agents_status)


async def agent_ready_checks():
    for agent in AGENTS.values():
        if agent and agent.enabled:
            await agent.ready_check()

@@ -174,6 +174,9 @@ async def load_scene_from_data(
    scene.filename = None
    scene.goals = scene_data.get("goals", [])
    scene.immutable_save = scene_data.get("immutable_save", False)
    scene.experimental = scene_data.get("experimental", False)
    scene.help = scene_data.get("help", "")
    scene.restore_from = scene_data.get("restore_from", "")

    # reset = True

@@ -240,9 +243,13 @@ async def load_scene_from_data(
            actor = Actor(character, agent)
        else:
            actor = Player(character, None)
        # Add the TestCharacter actor to the scene
        await scene.add_actor(actor)

    # if there is no player character, add the default player character

    if not scene.get_player_character():
        await scene.add_actor(default_player_character())

    # the scene has been saved before (since we just loaded it), so we set the saved flag to True
    # as long as the scene has a memory_id.
    scene.saved = "memory_id" in scene_data
@@ -205,9 +205,14 @@ class LoopedPrompt:
        self._current_item = None


class JoinableList(list):

    def join(self, separator: str = "\n"):
        return separator.join(self)


@dataclasses.dataclass
class Prompt:

    """
    Base prompt class.
    """

@@ -355,6 +360,7 @@ class Prompt:
        env.globals["query_scene"] = self.query_scene
        env.globals["query_memory"] = self.query_memory
        env.globals["query_text"] = self.query_text
        env.globals["query_text_eval"] = self.query_text_eval
        env.globals["instruct_text"] = self.instruct_text
        env.globals["agent_action"] = self.agent_action
        env.globals["retrieve_memories"] = self.retrieve_memories

@@ -364,6 +370,8 @@ class Prompt:
        env.globals["len"] = lambda x: len(x)
        env.globals["max"] = lambda x, y: max(x, y)
        env.globals["min"] = lambda x, y: min(x, y)
        env.globals["make_list"] = lambda: JoinableList()
        env.globals["make_dict"] = lambda: {}
        env.globals["count_tokens"] = lambda x: count_tokens(
            dedupe_string(x, debug=False)
        )

@@ -372,6 +380,7 @@ class Prompt:
        env.globals["emit_system"] = lambda status, message: emit(
            "system", status=status, message=message
        )
        env.globals["emit_narrator"] = lambda message: emit("system", message=message)
        env.filters["condensed"] = condensed
        ctx.update(self.vars)

@@ -439,10 +448,14 @@ class Prompt:
        vars.update(kwargs)
        return Prompt.get(uid, vars=vars)

    def render_and_request(self, prompt: "Prompt", kind: str = "create") -> str:
    def render_and_request(
        self, prompt: "Prompt", kind: str = "create", dedupe_enabled: bool = True
    ) -> str:
        if not self.client:
            raise ValueError("Prompt has no client set.")

        prompt.dedupe_enabled = dedupe_enabled

        loop = asyncio.get_event_loop()
        return loop.run_until_complete(prompt.send(self.client, kind=kind))

@@ -483,9 +496,15 @@ class Prompt:
            ]
        )

    def query_text(self, query: str, text: str, as_question_answer: bool = True):
    def query_text(
        self,
        query: str,
        text: str,
        as_question_answer: bool = True,
        short: bool = False,
    ):
        loop = asyncio.get_event_loop()
        summarizer = instance.get_agent("world_state")
        world_state = instance.get_agent("world_state")
        query = query.format(**self.vars)

        if isinstance(text, list):

@@ -493,7 +512,7 @@ class Prompt:

        if not as_question_answer:
            return loop.run_until_complete(
                summarizer.analyze_text_and_answer_question(text, query)
                world_state.analyze_text_and_answer_question(text, query, short=short)
            )

        return "\n".join(

@@ -501,11 +520,18 @@ class Prompt:
                f"Question: {query}",
                f"Answer: "
                + loop.run_until_complete(
                    summarizer.analyze_text_and_answer_question(text, query)
                    world_state.analyze_text_and_answer_question(
                        text, query, short=short
                    )
                ),
            ]
        )

    def query_text_eval(self, query: str, text: str):
        query = f"{query} Answer with a yes or no."
        response = self.query_text(query, text, as_question_answer=False, short=True)
        return response.strip().lower().startswith("y")

    def query_memory(self, query: str, as_question_answer: bool = True, **kwargs):
        loop = asyncio.get_event_loop()
        memory = instance.get_agent("memory")
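`query_text_eval` above turns a free-form model answer into a boolean by checking the first letter, so any answer starting with "y" (case-insensitive) counts as a yes:

```python
# Hypothetical model answers, run through the same normalization:
for answer, expected in [("Yes, she is.", True), ("yes.", True), ("No.", False)]:
    assert answer.strip().lower().startswith("y") is expected
```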
@@ -551,14 +577,17 @@ class Prompt:
            world_state.analyze_text_and_extract_context("\n".join(lines), goal=goal)
        )

    def agent_action(self, agent_name: str, action_name: str, **kwargs):
    def agent_action(self, agent_name: str, _action_name: str, **kwargs):
        loop = asyncio.get_event_loop()
        agent = instance.get_agent(agent_name)
        action = getattr(agent, action_name)
        action = getattr(agent, _action_name)
        return loop.run_until_complete(action(**kwargs))

    def emit_status(self, status: str, message: str):
        emit("status", status=status, message=message)
    def emit_status(self, status: str, message: str, **kwargs):
        if kwargs:
            emit("status", status=status, message=message, data=kwargs)
        else:
            emit("status", status=status, message=message)

    def set_prepared_response(self, response: str, prepend: str = ""):
        """

@@ -37,9 +37,20 @@ Based on {{ talking_character.name}}'s example dialogue style, create a continua

You may choose to have {{ talking_character.name}} respond to the conversation, or you may choose to have {{ talking_character.name}} perform a new action that is in line with {{ talking_character.name}}'s character.

{% if scene.conversation_format == "movie_script" -%}
The format is a movie script, so you should write the character's name in all caps followed by a line break and then the character's dialogue. For example:

CHARACTER NAME
I'm so glad you're here.

Emotions and actions should be written in italics. For example:

CHARACTER NAME
*smiles* I'm so glad you're here.
{% else -%}
Always contain actions in asterisks. For example, *{{ talking_character.name}} smiles*.
Always contain dialogue in quotation marks. For example, {{ talking_character.name}}: "Hello!"

{% endif -%}
{{ extra_instructions }}

{% if scene.count_character_messages(talking_character) >= 5 %}Use an informal and colloquial register with a conversational tone. Overall, {{ talking_character.name }}'s dialog is Informal, conversational, natural, and spontaneous, with a sense of immediacy.

@@ -84,19 +95,30 @@ Always contain dialogue in quotation marks. For example, {{ talking_character.na
<|SECTION:SCENE|>
{% endblock -%}
{% block scene_history -%}
{% for scene_context in scene.context_history(budget=max_tokens-200-count_tokens(self.rendered_context()), min_dialogue=15, sections=False, keep_director=talking_character.name) -%}
{{ scene_context }}
{% set scene_context = scene.context_history(budget=max_tokens-200-count_tokens(self.rendered_context()), min_dialogue=15, sections=False, keep_director=talking_character.name) -%}
{%- if talking_character.dialogue_instructions -%}
{% set _ = scene_context.insert(-3, "# Internal acting instructions for "+talking_character.name+": "+talking_character.dialogue_instructions) %}
{% endif -%}
{% for scene_line in scene_context -%}
{{ scene_line }}
{% endfor %}
{% endblock -%}
<|CLOSE_SECTION|>
{% if scene.count_character_messages(talking_character) < 5 %}Use an informal and colloquial register with a conversational tone. Overall, {{ talking_character.name }}'s dialog is Informal, conversational, natural, and spontaneous, with a sense of immediacy. Flesh out additional details by describing {{ talking_character.name }}'s actions and mannerisms within asterisks, e.g. *{{ talking_character.name }} smiles*.
{% if scene.count_character_messages(talking_character) < 5 %}(Use an informal and colloquial register with a conversational tone. Overall, {{ talking_character.name }}'s dialog is Informal, conversational, natural, and spontaneous, with a sense of immediacy.)
{% endif -%}
{% if rerun_context and rerun_context.direction -%}
{% if rerun_context.method == 'replace' -%}
Final instructions for generating the next line of dialogue: {{ rerun_context.direction }}
# Final instructions for generating the next line of dialogue: {{ rerun_context.direction }}
{% elif rerun_context.method == 'edit' and rerun_context.message -%}
Edit and respond with your changed version of the following line of dialogue: {{ rerun_context.message }}
Requested changes: {{ rerun_context.direction }}
# Edit and respond with your changed version of the following line of dialogue: {{ rerun_context.message|condensed }}

# Requested changes: {{ rerun_context.direction }}
{% endif -%}
{% endif -%}
{{ bot_token}}{{ talking_character.name }}:{{ partial_message }}
{% if scene.conversation_format == 'movie_script' -%}
{{ bot_token }}{{ talking_character.name.upper() }}{% if partial_message %}
{{ partial_message }}
{% endif %}
{% else -%}
{{ bot_token }}{{ talking_character.name }}:{{ partial_message }}
{% endif -%}
@@ -0,0 +1,71 @@
{% block rendered_context %}
{% include "extra-context.jinja2" %}
{% if character %}
<|SECTION:CHARACTER|>
{% if context_typ == 'character attribute' -%}
{{ character.sheet_filtered(context_name) }}
{% else -%}
{{ character.sheet }}
{% endif -%}
<|CLOSE_SECTION|>
{% endif %}
{% endblock %}
<|SECTION:SCENE|>
{% for scene_context in scene.context_history(budget=max_tokens-1024-count_tokens(self.rendered_context())) -%}
{{ scene_context }}
{% endfor %}
<|CLOSE_SECTION|>
<|SECTION:TASK|>
{#- SET TASK ACTION -#}
{% if not generation_context.original %}
{%- set action_task = "Generate the" -%}
{% else %}
{%- set action_task = "Rewrite the existing" -%}
Original {{ context_name }}: {{ generation_context.original }}
{% endif %}
{#- CHARACTER ATTRIBUTE -#}
{% if context_typ == "character attribute" %}
{{ action_task }} "{{ context_name }}" attribute for {{ character.name }}. This must be a general description and not a continuation of the current narrative.
{#- CHARACTER DETAIL -#}
{% elif context_typ == "character detail" %}
{% if context_name.endswith("?") -%}
{{ action_task }} answer to "{{ context_name }}" for {{ character.name }}. This must be a general description and not a continuation of the current narrative.
{% else -%}
{{ action_task }} "{{ context_name }}" detail for {{ character.name }}. This must be a general description and not a continuation of the current narrative. Use paragraphs to separate different details.
{% endif -%}
Use a simple, easy to read writing format.
{#- CHARACTER EXAMPLE DIALOGUE -#}
{% elif context_typ == "character dialogue" %}
Generate a new line of example dialogue for {{ character.name }}.

Existing Dialogue Examples:
{% for line in character.example_dialogue %}
{{ line }}
{% endfor %}

You must only respond with the generated dialogue example.
Always contain actions in asterisks. For example, *{{ character.name}} smiles*.
Always contain dialogue in quotation marks. For example, {{ character.name}}: "Hello!"

{%- if character.dialogue_instructions -%}
Dialogue instructions for {{ character.name }}: {{ character.dialogue_instructions }}
{% endif -%}
{#- GENERAL CONTEXT -#}
{% else %}
{% if context_name.endswith("?") -%}
{{ action_task }} answer to the question "{{ context_name }}". This must be a general description and not a continuation of the current narrative.
{%- else -%}
{{ action_task }} new narrative content for {{ context_name }}
Use a simple, easy to read writing format.
{%- endif -%}
{% endif %}
{% if generation_context.instructions %}Additional instructions: {{ generation_context.instructions }}{% endif %}
<|CLOSE_SECTION|>
{{ bot_token }}
{%- if context_typ == 'character attribute' -%}
{{ character.name }}'s {{ context_name }}:
{%- elif context_typ == 'character dialogue' -%}
{{ character.name }}:
{%- else -%}
{{ context_name }}:
{%- endif -%}

@@ -0,0 +1,21 @@
{% block rendered_context %}
{% include "extra-context.jinja2" %}
{% endblock %}
<|SECTION:SCENE|>
{% for scene_context in scene.context_history(budget=max_tokens-1024-count_tokens(self.rendered_context())) -%}
{{ scene_context }}
{% endfor %}
<|CLOSE_SECTION|>
<|SECTION:TASK|>
Determine character name based on the following sentence: {{ character_name }}

{% if not allowed_names -%}
If the character already has a distinct name, respond with the character's name.
If the name is currently a description, give the character a distinct name.
If we don't know the character's actual name, you must decide one.
YOU MUST ONLY RESPOND WITH THE CHARACTER NAME, NOTHING ELSE.
{% else %}
Pick the most fitting name from the following list: {{ allowed_names|join(', ') }}. If none of the names fit, respond with the most accurate name based on the sentence.
{%- endif %}
<|CLOSE_SECTION|>
{{ bot_token }}The character's name is "
src/talemate/prompts/templates/creator/extra-context.jinja2 (new file, 29 lines)
@@ -0,0 +1,29 @@
Scenario Premise:
{{ scene.description }}

Content Context: This is a specific scene from {{ scene.context }}

{% block rendered_context_static %}
{# GENERAL REINFORCEMENTS #}
{% set general_reinforcements = scene.world_state.filter_reinforcements(insert=['all-context']) %}
{%- for reinforce in general_reinforcements %}
{{ reinforce.as_context_line|condensed }}

{% endfor %}
{# END GENERAL REINFORCEMENTS #}
{# ACTIVE PINS #}
{%- for pin in scene.active_pins %}
{{ pin.time_aware_text|condensed }}

{% endfor %}
{# END ACTIVE PINS #}
{% endblock %}

{# MEMORY #}
{%- if memory_query %}
{%- for memory in query_memory(memory_query, as_question_answer=False, max_tokens=max_tokens-500-count_tokens(self.rendered_context_static()), iterate=10) -%}
{{ memory|condensed }}

{% endfor -%}
{% endif -%}
{# END MEMORY #}

@@ -14,6 +14,8 @@ Player Character: {{ player_character.name }}
{% endfor %}
<|CLOSE_SECTION|>
<|SECTION:TASK|>
YOU MUST WRITE FROM THE PERSPECTIVE OF THE NARRATOR.

Continue the current dialogue by narrating the progression of the scene.

If the scene is over, narrate the beginning of the next scene.

@@ -28,6 +30,8 @@ Use an informal and colloquial register with a conversational tone. Overall, the

Narration style should be that of a 90s point and click adventure game. You are omniscient and can describe the scene in detail.

YOU MUST WRITE FROM THE PERSPECTIVE OF THE NARRATOR.

Only generate new narration. Avoid including any character's internal thoughts or dialogue.

{% if narrative_direction %}

@@ -36,5 +40,4 @@ Directions for new narration: {{ narrative_direction }}

Write 2 to 4 sentences. {{ extra_instructions }}
{% include "rerun-context.jinja2" -%}
<|CLOSE_SECTION|>
{{ set_prepared_response("*") }}
<|CLOSE_SECTION|>

@@ -3,6 +3,11 @@
{%- with memory_query=query -%}
{% include "extra-context.jinja2" %}
{% endwith -%}
{% set related_character = scene.parse_character_from_line(query) -%}
{% if related_character -%}
<|SECTION:{{ related_character.name|upper }}|>
{{ related_character.sheet}}
{% endif %}
<|CLOSE_SECTION|>
{% endblock %}
{% set scene_history=scene.context_history(budget=max_tokens-200-count_tokens(self.rendered_context())) %}
src/talemate/prompts/templates/narrator/paraphrase.jinja2 (new file, 20 lines)
@@ -0,0 +1,20 @@
{% block rendered_context -%}
<|SECTION:CONTEXT|>
{% include "extra-context.jinja2" %}
<|CLOSE_SECTION|>
{% endblock -%}
<|SECTION:SCENE|>
{% for scene_context in scene.context_history(budget=max_tokens-300-count_tokens(self.rendered_context())) -%}
{{ scene_context }}
{% endfor %}
<|CLOSE_SECTION|>
<|SECTION:TASK|>
Paraphrase the following text to fit the narrative thus far. Keep the information and the meaning the same, but change the wording and sentence structure.

Text to paraphrase:

"{{ text }}"

{{ extra_instructions }}
{% include "rerun-context.jinja2" -%}
<|CLOSE_SECTION|>
src/talemate/prompts/templates/visual/extra-context.jinja2 (new file, 29 lines)
@@ -0,0 +1,29 @@
Scenario Premise:
{{ scene.description }}

Content Context: This is a specific scene from {{ scene.context }}

{% block rendered_context_static %}
{# GENERAL REINFORCEMENTS #}
{% set general_reinforcements = scene.world_state.filter_reinforcements(insert=['all-context']) %}
{%- for reinforce in general_reinforcements %}
{{ reinforce.as_context_line|condensed }}

{% endfor %}
{# END GENERAL REINFORCEMENTS #}
{# ACTIVE PINS #}
{%- for pin in scene.active_pins %}
{{ pin.time_aware_text|condensed }}

{% endfor %}
{# END ACTIVE PINS #}
{% endblock %}

{# MEMORY #}
{%- if memory_query %}
{%- for memory in query_memory(memory_query, as_question_answer=False, max_tokens=max_tokens-500-count_tokens(self.rendered_context_static()), iterate=10) -%}
{{ memory|condensed }}

{% endfor -%}
{% endif -%}
{# END MEMORY #}
@@ -0,0 +1,29 @@
{{ query_scene("What is "+character.name+"'s age, race, and physical appearance?", full_context) }}

{{ query_scene("What clothes is "+character.name+" currently wearing? Provide a detailed description.", full_context) }}

{{ query_scene("What is "+character.name+"'s current scene description?", full_context) }}

{{ query_scene("Where is "+character.name+" currently at? Briefly describe the environment and provide genre context.", full_context) }}
{% set emotion = scene.world_state.character_emotion(character.name) %}
{% if emotion %}{{ character.name }}'s current emotion: {{ emotion }}{% endif %}
<|SECTION:TASK|>
{% if instructions %}Requested Image: {{ instructions }}{% endif %}

Describe the scene to the painter to ensure he will capture all the important details when drawing a dynamic and truthful image of {{ character.name }}.

Include details about {{ character.name }}'s appearance exactly as they are, and {{ character.name }}'s current pose.
Include a description of the environment.

THE IMAGE MUST ONLY INCLUDE {{ character.name }}. EXCLUDE ALL OTHER CHARACTERS.
YOU MUST ONLY DESCRIBE WHAT IS CURRENTLY VISIBLE IN THE SCENE.

Required information: name, age, race, gender, physique, expression, pose, clothes/equipment, hair style, hair color, skin color, eyes, scars, tattoos, piercings, a fitting color scheme and any other relevant details.

You must provide your answer as a comma-delimited list of keywords.
Keywords should be ordered: physical appearance, emotion, action, environment, color scheme.
You must provide many keywords to describe the character and the environment in great detail.
Your answer must be suitable as a stable-diffusion image generation prompt.
You must avoid negating keywords; omit things that aren't there entirely. For example, instead of saying "no scars", just don't include the keyword scars at all.
<|CLOSE_SECTION|>
{{ set_prepared_response(character.name+",")}}