WIP: Prep 0.21.0 (#83)
* cleanup * refactor clean_dialogue * prompt fixes * prompt fixes * conversation format types - movie script and chat (legacy) * stopping strings updated * mistral.ai client * prompt tweaks * mistral client return token counts * anthropic client * archive history emits whole object so we can inspectr time stamps * show timestamp in history dialog * openai compat fixes to stop trying to coerce openai url path schema and to never attempt to retrieve the model name automatically, hopefully improving compatibility with the various openai api implementations across the board * openai compat client let api control prompt template via config option * fix custom client configs and implement max backscroll * fix backscroll limit * remove debug message * prep 0.21.0 * include model name in prompt template selection label * use tabs for side nav in app config modal * readme / docs * fix issue where "No API key set" could be persisted as the selected model name to the config * deepinfra example * linting
92
README.md
@@ -7,16 +7,21 @@ Roleplay with AI with a focus on strong narration and consistent world and game
|
||||
|||
|
||||
|||
|
||||
|
||||
> :warning: **It does not run any large language models itself but relies on existing APIs. Currently supports OpenAI, text-generation-webui and LMStudio. 0.18.0 also adds support for generic OpenAI api implementations, but generation quality on that will vary.**
|
||||
> :warning: **It does not run any large language models itself but relies on existing APIs. Currently supports OpenAI, Anthropic, mistral.ai, self-hosted text-generation-webui and LMStudio. 0.18.0 also adds support for generic OpenAI api implementations, but generation quality on that will vary.**
|
||||
|
||||
This means you need to either have:
|
||||
- an [OpenAI](https://platform.openai.com/overview) api key
|
||||
- setup local (or remote via runpod) LLM inference via:
|
||||
- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui)
|
||||
- [LMStudio](https://lmstudio.ai/)
|
||||
- Any other OpenAI api implementation that implements the v1/completions endpoint
|
||||
- tested llamacpp with the `api_like_OAI.py` wrapper
|
||||
- let me know if you have tested any other implementations and they failed / worked or landed somewhere in between
|
||||
Officially supported APIs:
|
||||
- [OpenAI](https://platform.openai.com/overview)
|
||||
- [Anthropic](https://www.anthropic.com/)
|
||||
- [mistral.ai](https://mistral.ai/)
|
||||
|
||||
Officially supported self-hosted APIs:
|
||||
- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (local or with runpod support)
|
||||
- [LMStudio](https://lmstudio.ai/)
|
||||
|
||||
Generic OpenAI api implementations (tested and confirmed working):
|
||||
- [DeepInfra](https://deepinfra.com/) - see [instructions](https://github.com/vegu-ai/talemate/issues/78#issuecomment-1986884304)
|
||||
- [llamacpp](https://github.com/ggerganov/llama.cpp) with the `api_like_OAI.py` wrapper
|
||||
- let me know if you have tested any other implementations and they failed / worked or landed somewhere in between
|
||||
|
||||
## Current features
|
||||
|
||||
@@ -78,8 +83,9 @@ Please read the documents in the `docs` folder for more advanced configuration a
|
||||
- [Installation](#installation)
|
||||
- [Connecting to an LLM](#connecting-to-an-llm)
|
||||
- [Text-generation-webui](#text-generation-webui)
|
||||
- [Recommended Models](#recommended-models)
|
||||
- [OpenAI](#openai)
|
||||
- [Recommended Models](#recommended-models)
|
||||
- [OpenAI / mistral.ai / Anthropic](#openai)
|
||||
- [DeepInfra via OpenAI Compatible client](#deepinfra-via-openai-compatible-client)
|
||||
- [Ready to go](#ready-to-go)
|
||||
- [Load the introductory scenario "Infinity Quest"](#load-the-introductory-scenario-infinity-quest)
|
||||
- [Loading character cards](#loading-character-cards)
|
||||
@@ -118,43 +124,67 @@ There is also a [troubleshooting guide](docs/troubleshoot.md) that might help.
|
||||
1. Start the backend: `python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050`.
|
||||
1. Open a new terminal, navigate to the `talemate_frontend` directory, and start the frontend server by running `npm run serve`.
|
||||
|
||||
## Connecting to an LLM
|
||||
# Connecting to an LLM
|
||||
|
||||
On the right hand side click the "Add Client" button. If there is no button, you may need to toggle the client options by clicking this button:
|
||||
|
||||

|
||||
|
||||
### Text-generation-webui
|
||||

|
||||
|
||||
## Text-generation-webui
|
||||
|
||||
> :warning: As of version 0.13.0 the legacy text-generator-webui API `--extension api` is no longer supported, please use their new `--extension openai` api implementation instead.
|
||||
|
||||
In the modal if you're planning to connect to text-generation-webui, you can likely leave everything as is and just click Save.
|
||||
|
||||

|
||||

|
||||
|
||||
### Specifying the correct prompt template
|
||||
|
||||
#### Recommended Models
|
||||
For good results it is **vital** that the correct prompt template is specified for whichever model you have loaded.
|
||||
|
||||
As of 2024.02.06 my personal regular drivers (the ones i test with) are:
|
||||
Talemate does come with a set of pre-defined templates for some popular models, but going forward, due to the sheet number of models released every day, understanding and specifying the correct prompt template is something you should familiarize yourself with.
|
||||
|
||||
If the text-gen-webui client shows a yellow triangle next to it, it means that the prompt template is not set, and it is currently using the default `VICUNA` style prompt template.
|
||||
|
||||

|
||||
|
||||
Click the two cogwheels to the right of the triangle to open the client settings.
|
||||
|
||||

|
||||
|
||||
You can first try by clicking the `DETERMINE VIA HUGGINGFACE` button, depending on the model's README file, it may be able to determine the correct prompt template for you. (basically the readme needs to contain an example of the template)
|
||||
|
||||
If that doesn't work, you can manually select the prompt template from the dropdown.
|
||||
|
||||
In the case for `bartowski_Nous-Hermes-2-Mistral-7B-DPO-exl2_8_0` that is `ChatML` - select it from the dropdown and click `Save`.
|
||||
|
||||

|
||||
|
||||
### Recommended Models
|
||||
|
||||
As of 2024.03.07 my personal regular drivers (the ones i test with) are:
|
||||
|
||||
- Kunoichi-7B
|
||||
- sparsetral-16x7B
|
||||
- Nous-Hermes-2-SOLAR-10.7B
|
||||
- Nous-Hermes-2-Mistral-7B-DPO
|
||||
- brucethemoose_Yi-34B-200K-RPMerge
|
||||
- dolphin-2.7-mixtral-8x7b
|
||||
- rAIfle_Verdict-8x7B
|
||||
- Mixtral-8x7B-instruct
|
||||
- GPT-3.5-turbo 0125
|
||||
- GPT-4-turbo 0116
|
||||
|
||||
That said, any of the top models in any of the size classes here should work well (i wouldn't recommend going lower than 7B):
|
||||
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/18yp9u4/llm_comparisontest_api_edition_gpt4_vs_gemini_vs/
|
||||
|
||||
### OpenAI
|
||||
## OpenAI / mistral.ai / Anthropic
|
||||
|
||||
The setup is the same for all three, the example below is for OpenAI.
|
||||
|
||||
If you want to add an OpenAI client, just change the client type and select the apropriate model.
|
||||
|
||||

|
||||

|
||||
|
||||
If you are setting this up for the first time, you should now see the client, but it will have a red dot next to it, stating that it requires an API key.
|
||||
|
||||
@@ -162,17 +192,33 @@ If you are setting this up for the first time, you should now see the client, bu
|
||||
|
||||
Click the `SET API KEY` button. This will open a modal where you can enter your API key.
|
||||
|
||||

|
||||

|
||||
|
||||
Click `Save` and after a moment the client should have a green dot next to it, indicating that it is ready to go.
|
||||
|
||||

|
||||
|
||||
## DeepInfra via OpenAI Compatible client
|
||||
|
||||
You can use the OpenAI compatible client to connect to [DeepInfra](https://deepinfra.com/).
|
||||
|
||||

|
||||
|
||||
```
|
||||
API URL: https://api.deepinfra.com/v1/openai
|
||||
```
|
||||
|
||||
Models on DeepInfra that work well with Talemate:
|
||||
|
||||
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://deepinfra.com/mistralai/Mixtral-8x7B-Instruct-v0.1) (max context 32k, 8k recommended)
|
||||
- [cognitivecomputations/dolphin-2.6-mixtral-8x7b](https://deepinfra.com/cognitivecomputations/dolphin-2.6-mixtral-8x7b) (max context 32k, 8k recommended)
|
||||
- [lizpreciatior/lzlv_70b_fp16_hf](https://deepinfra.com/lizpreciatior/lzlv_70b_fp16_hf) (max context 4k)
|
||||
|
||||
## Ready to go
|
||||
|
||||
You will know you are good to go when the client and all the agents have a green dot next to them.
|
||||
|
||||

|
||||

|
||||
|
||||
## Load the introductory scenario "Infinity Quest"
|
||||
|
||||
|
||||
BIN
docs/img/0.21.0/deepinfra-setup.png
Normal file
|
After Width: | Height: | Size: 56 KiB |
BIN
docs/img/0.21.0/no-clients.png
Normal file
|
After Width: | Height: | Size: 7.1 KiB |
BIN
docs/img/0.21.0/openai-add-api-key.png
Normal file
|
After Width: | Height: | Size: 35 KiB |
BIN
docs/img/0.21.0/openai-setup.png
Normal file
|
After Width: | Height: | Size: 20 KiB |
BIN
docs/img/0.21.0/prompt-template-default.png
Normal file
|
After Width: | Height: | Size: 17 KiB |
BIN
docs/img/0.21.0/ready-to-go.png
Normal file
|
After Width: | Height: | Size: 43 KiB |
BIN
docs/img/0.21.0/select-prompt-template.png
Normal file
|
After Width: | Height: | Size: 47 KiB |
BIN
docs/img/0.21.0/selected-prompt-template.png
Normal file
|
After Width: | Height: | Size: 49 KiB |
BIN
docs/img/0.21.0/text-gen-webui-setup.png
Normal file
|
After Width: | Height: | Size: 26 KiB |
1153
poetry.lock
generated
@@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"
|
||||
|
||||
[tool.poetry]
|
||||
name = "talemate"
|
||||
version = "0.20.0"
|
||||
version = "0.21.0"
|
||||
description = "AI-backed roleplay and narrative tools"
|
||||
authors = ["FinalWombat"]
|
||||
license = "GNU Affero General Public License v3.0"
|
||||
@@ -39,6 +39,7 @@ thefuzz = ">=0.20.0"
|
||||
tiktoken = ">=0.5.1"
|
||||
nltk = ">=3.8.1"
|
||||
huggingface-hub = ">=0.20.2"
|
||||
anthropic = ">=0.19.1"
|
||||
|
||||
# ChromaDB
|
||||
chromadb = ">=0.4.17,<1"
|
||||
|
||||
@@ -20,6 +20,8 @@ You must at least call one of the following functions:
|
||||
- end_simulation
|
||||
- answer_question
|
||||
|
||||
`add_ai_character` and `change_ai_character` are exclusive if they are targeting the same character.
|
||||
|
||||
Set the player persona at the beginning of a new simulation or if the player requests a change.
|
||||
|
||||
Only end the simulation if the player requests it explicitly.
|
||||
|
||||
@@ -126,7 +126,7 @@
|
||||
{% set _ = game_state.set_var("instr.has_issued_instructions", "yes", commit=False) %}
|
||||
{% set _ = emit_status("busy", "Simulation suite altering environment.", as_scene_message=True) %}
|
||||
{% set update_world_state = True %}
|
||||
{% set _ = agent_action("narrator", "action_to_narration", action_name="progress_story", narrative_direction="The computer calls the following functions:\n"+processed.join("\n")+"\nand the simulation adjusts the environment according to the user's wishes. Write the narrative that describes the changes.", emit_message=True) %}
|
||||
{% set _ = agent_action("narrator", "action_to_narration", action_name="progress_story", narrative_direction="The computer calls the following functions:\n"+processed.join("\n")+"\nand the simulation adjusts the environment according to the user's wishes.\n\nWrite the narrative that describes the changes to the player in the context of the simulation starting up.", emit_message=True) %}
|
||||
{% endif %}
|
||||
|
||||
{% elif not game_state.has_var("instr.simulation_started") %}
|
||||
|
||||
@@ -2,4 +2,4 @@ from .agents import Agent
|
||||
from .client import TextGeneratorWebuiClient
|
||||
from .tale_mate import *
|
||||
|
||||
VERSION = "0.20.0"
|
||||
VERSION = "0.21.0"
|
||||
|
||||
@@ -78,9 +78,18 @@ class ConversationAgent(Agent):
|
||||
self.actions = {
|
||||
"generation_override": AgentAction(
|
||||
enabled=True,
|
||||
label="Generation Override",
|
||||
description="Override generation parameters",
|
||||
label="Generation Settings",
|
||||
config={
|
||||
"format": AgentActionConfig(
|
||||
type="text",
|
||||
label="Format",
|
||||
description="The format of the dialogue, as seen by the AI.",
|
||||
choices=[
|
||||
{"label": "Movie Script", "value": "movie_script"},
|
||||
{"label": "Chat (legacy)", "value": "chat"},
|
||||
],
|
||||
value="chat",
|
||||
),
|
||||
"length": AgentActionConfig(
|
||||
type="number",
|
||||
label="Generation Length (tokens)",
|
||||
@@ -166,6 +175,12 @@ class ConversationAgent(Agent):
|
||||
),
|
||||
}
|
||||
|
||||
@property
|
||||
def conversation_format(self):
|
||||
if self.actions["generation_override"].enabled:
|
||||
return self.actions["generation_override"].config["format"].value
|
||||
return "movie_script"
|
||||
|
||||
def connect(self, scene):
|
||||
super().connect(scene)
|
||||
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
|
||||
@@ -605,14 +620,20 @@ class ConversationAgent(Agent):
|
||||
|
||||
result = result.replace(" :", ":")
|
||||
|
||||
total_result = total_result.split("#")[0]
|
||||
total_result = total_result.split("#")[0].strip()
|
||||
|
||||
# movie script format
|
||||
# {uppercase character name}
|
||||
# {dialogue}
|
||||
total_result = total_result.replace(f"{character.name.upper()}\n", f"")
|
||||
|
||||
# chat format
|
||||
# {character name}: {dialogue}
|
||||
total_result = total_result.replace(f"{character.name}:", "")
|
||||
|
||||
# Removes partial sentence at the end
|
||||
total_result = util.clean_dialogue(total_result, main_name=character.name)
|
||||
|
||||
# Remove "{character.name}:" - all occurences
|
||||
total_result = total_result.replace(f"{character.name}:", "")
|
||||
|
||||
# Check if total_result starts with character name, if not, prepend it
|
||||
if not total_result.startswith(character.name):
|
||||
total_result = f"{character.name}: {total_result}"
|
||||
|
||||
@@ -3,6 +3,8 @@ import os
|
||||
import talemate.client.runpod
|
||||
from talemate.client.lmstudio import LMStudioClient
|
||||
from talemate.client.openai import OpenAIClient
|
||||
from talemate.client.mistral import MistralAIClient
|
||||
from talemate.client.anthropic import AnthropicClient
|
||||
from talemate.client.openai_compat import OpenAICompatibleClient
|
||||
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
|
||||
from talemate.client.textgenwebui import TextGeneratorWebuiClient
|
||||
|
||||
224
src/talemate/client/anthropic.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
from anthropic import AsyncAnthropic, PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
|
||||
__all__ = [
|
||||
"AnthropicClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
]
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "claude-3-sonnet-20240229"
|
||||
|
||||
|
||||
@register()
|
||||
class AnthropicClient(ClientBase):
|
||||
"""
|
||||
Anthropic client for generating text.
|
||||
"""
|
||||
|
||||
client_type = "anthropic"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = False
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "Anthropic"
|
||||
title: str = "Anthropic"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="claude-3-sonnet-20240229", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def anthropic_api_key(self):
|
||||
return self.config.get("anthropic", {}).get("api_key")
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.anthropic_api_key:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "No API key set"
|
||||
error_action = ErrorAction(
|
||||
title="Set API Key",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"anthropic_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.anthropic_api_key:
|
||||
self.client = AsyncAnthropic(api_key="sk-1111")
|
||||
log.error("No anthropic API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "claude-3-opus-20240229"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncAnthropic(api_key=self.anthropic_api_key)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
if self.api_key_status is False:
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.api_key_status = True
|
||||
|
||||
log.info(
|
||||
"anthropic set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return response.usage.output_tokens
|
||||
|
||||
def prompt_tokens(self, response: str):
|
||||
return response.usage.input_tokens
|
||||
|
||||
async def status(self):
|
||||
self.emit_status()
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
keys = list(parameters.keys())
|
||||
valid_keys = ["temperature", "top_p", "max_tokens"]
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.anthropic_api_key:
|
||||
raise Exception("No anthropic API key set")
|
||||
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
system_message = self.get_system_message(kind)
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.messages.create(
|
||||
model=self.model_name,
|
||||
system=system_message,
|
||||
messages=[human_message],
|
||||
**parameters,
|
||||
)
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(response)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
log.debug("generated response", response=response.content)
|
||||
|
||||
response = response.content[0].text
|
||||
|
||||
if expected_response and expected_response.startswith("{"):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
except PermissionDeniedError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="anthropic API: Permission Denied", status="error")
|
||||
return ""
|
||||
except Exception as e:
|
||||
raise
|
||||
@@ -363,6 +363,11 @@ class ClientBase:
|
||||
f"{character}:" for character in conversation_context["other_characters"]
|
||||
]
|
||||
|
||||
dialog_stopping_strings += [
|
||||
f"{character.upper()}\n"
|
||||
for character in conversation_context["other_characters"]
|
||||
]
|
||||
|
||||
if "extra_stopping_strings" in parameters:
|
||||
parameters["extra_stopping_strings"] += dialog_stopping_strings
|
||||
else:
|
||||
@@ -405,6 +410,9 @@ class ClientBase:
|
||||
"""
|
||||
|
||||
try:
|
||||
self._returned_prompt_tokens = None
|
||||
self._returned_response_tokens = None
|
||||
|
||||
self.emit_status(processing=True)
|
||||
await self.status()
|
||||
|
||||
@@ -452,8 +460,9 @@ class ClientBase:
|
||||
kind=kind,
|
||||
prompt=finalized_prompt,
|
||||
response=response,
|
||||
prompt_tokens=token_length,
|
||||
response_tokens=self.count_tokens(response),
|
||||
prompt_tokens=self._returned_prompt_tokens or token_length,
|
||||
response_tokens=self._returned_response_tokens
|
||||
or self.count_tokens(response),
|
||||
agent_stack=agent_context.agent_stack if agent_context else [],
|
||||
client_name=self.name,
|
||||
client_type=self.client_type,
|
||||
@@ -465,6 +474,8 @@ class ClientBase:
|
||||
return response
|
||||
finally:
|
||||
self.emit_status(processing=False)
|
||||
self._returned_prompt_tokens = None
|
||||
self._returned_response_tokens = None
|
||||
|
||||
async def auto_break_repetition(
|
||||
self,
|
||||
|
||||
232
src/talemate/client/mistral.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import json
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
import tiktoken
|
||||
from openai import AsyncOpenAI, PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase, ErrorAction
|
||||
from talemate.client.registry import register
|
||||
from talemate.config import load_config
|
||||
from talemate.emit import emit
|
||||
from talemate.emit.signals import handlers
|
||||
|
||||
__all__ = [
|
||||
"MistralAIClient",
|
||||
]
|
||||
log = structlog.get_logger("talemate")
|
||||
|
||||
# Edit this to add new models / remove old models
|
||||
SUPPORTED_MODELS = [
|
||||
"open-mistral-7b",
|
||||
"open-mixtral-8x7b",
|
||||
"mistral-small-latest",
|
||||
"mistral-medium-latest",
|
||||
"mistral-large-latest",
|
||||
]
|
||||
|
||||
|
||||
class Defaults(pydantic.BaseModel):
|
||||
max_token_length: int = 16384
|
||||
model: str = "open-mixtral-8x7b"
|
||||
|
||||
|
||||
@register()
|
||||
class MistralAIClient(ClientBase):
|
||||
"""
|
||||
OpenAI client for generating text.
|
||||
"""
|
||||
|
||||
client_type = "mistral"
|
||||
conversation_retries = 0
|
||||
auto_break_repetition_enabled = False
|
||||
# TODO: make this configurable?
|
||||
decensor_enabled = False
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
name_prefix: str = "MistralAI"
|
||||
title: str = "MistralAI"
|
||||
manual_model: bool = True
|
||||
manual_model_choices: list[str] = SUPPORTED_MODELS
|
||||
requires_prompt_template: bool = False
|
||||
defaults: Defaults = Defaults()
|
||||
|
||||
def __init__(self, model="open-mixtral-8x7b", **kwargs):
|
||||
self.model_name = model
|
||||
self.api_key_status = None
|
||||
self.config = load_config()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
handlers["config_saved"].connect(self.on_config_saved)
|
||||
|
||||
@property
|
||||
def mistralai_api_key(self):
|
||||
return self.config.get("mistralai", {}).get("api_key")
|
||||
|
||||
def emit_status(self, processing: bool = None):
|
||||
error_action = None
|
||||
if processing is not None:
|
||||
self.processing = processing
|
||||
|
||||
if self.mistralai_api_key:
|
||||
status = "busy" if self.processing else "idle"
|
||||
model_name = self.model_name
|
||||
else:
|
||||
status = "error"
|
||||
model_name = "No API key set"
|
||||
error_action = ErrorAction(
|
||||
title="Set API Key",
|
||||
action_name="openAppConfig",
|
||||
icon="mdi-key-variant",
|
||||
arguments=[
|
||||
"application",
|
||||
"mistralai_api",
|
||||
],
|
||||
)
|
||||
|
||||
if not self.model_name:
|
||||
status = "error"
|
||||
model_name = "No model loaded"
|
||||
|
||||
self.current_status = status
|
||||
|
||||
emit(
|
||||
"client_status",
|
||||
message=self.client_type,
|
||||
id=self.name,
|
||||
details=model_name,
|
||||
status=status,
|
||||
data={
|
||||
"error_action": error_action.model_dump() if error_action else None,
|
||||
"meta": self.Meta().model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
def set_client(self, max_token_length: int = None):
|
||||
if not self.mistralai_api_key:
|
||||
self.client = AsyncOpenAI(api_key="sk-1111")
|
||||
log.error("No mistral.ai API key set")
|
||||
if self.api_key_status:
|
||||
self.api_key_status = False
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
return
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = "open-mixtral-8x7b"
|
||||
|
||||
if max_token_length and not isinstance(max_token_length, int):
|
||||
max_token_length = int(max_token_length)
|
||||
|
||||
model = self.model_name
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.mistralai_api_key, base_url="https://api.mistral.ai/v1/"
|
||||
)
|
||||
self.max_token_length = max_token_length or 16384
|
||||
|
||||
if not self.api_key_status:
|
||||
if self.api_key_status is False:
|
||||
emit("request_client_status")
|
||||
emit("request_agent_status")
|
||||
self.api_key_status = True
|
||||
|
||||
log.info(
|
||||
"mistral.ai set client",
|
||||
max_token_length=self.max_token_length,
|
||||
provided_max_token_length=max_token_length,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
if kwargs.get("model"):
|
||||
self.model_name = kwargs["model"]
|
||||
self.set_client(kwargs.get("max_token_length"))
|
||||
|
||||
def on_config_saved(self, event):
|
||||
config = event.data
|
||||
self.config = config
|
||||
self.set_client(max_token_length=self.max_token_length)
|
||||
|
||||
def response_tokens(self, response: str):
|
||||
return response.usage.completion_tokens
|
||||
|
||||
def prompt_tokens(self, response: str):
|
||||
return response.usage.prompt_tokens
|
||||
|
||||
async def status(self):
|
||||
self.emit_status()
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||
super().tune_prompt_parameters(parameters, kind)
|
||||
keys = list(parameters.keys())
|
||||
valid_keys = ["temperature", "top_p", "max_tokens"]
|
||||
for key in keys:
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Generates text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
if not self.mistralai_api_key:
|
||||
raise Exception("No mistral.ai API key set")
|
||||
|
||||
right = None
|
||||
expected_response = None
|
||||
try:
|
||||
_, right = prompt.split("\nStart your response with: ")
|
||||
expected_response = right.strip()
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
human_message = {"role": "user", "content": prompt.strip()}
|
||||
system_message = {"role": "system", "content": self.get_system_message(kind)}
|
||||
|
||||
self.log.debug(
|
||||
"generate",
|
||||
prompt=prompt[:128] + " ...",
|
||||
parameters=parameters,
|
||||
system_message=system_message,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[system_message, human_message],
|
||||
**parameters,
|
||||
)
|
||||
|
||||
self._returned_prompt_tokens = self.prompt_tokens(response)
|
||||
self._returned_response_tokens = self.response_tokens(response)
|
||||
|
||||
response = response.choices[0].message.content
|
||||
|
||||
# older models don't support json_object response coersion
|
||||
# and often like to return the response wrapped in ```json
|
||||
# so we strip that out if the expected response is a json object
|
||||
if expected_response and expected_response.startswith("{"):
|
||||
if response.startswith("```json") and response.endswith("```"):
|
||||
response = response[7:-3].strip()
|
||||
|
||||
if right and response.startswith(right):
|
||||
response = response[len(right) :].strip()
|
||||
|
||||
return response
|
||||
except PermissionDeniedError as e:
|
||||
self.log.error("generate error", e=e)
|
||||
emit("status", message="mistral.ai API: Permission Denied", status="error")
|
||||
return ""
|
||||
except Exception as e:
|
||||
raise
|
||||
@@ -1,10 +1,12 @@
|
||||
import pydantic
|
||||
import structlog
|
||||
import urllib
|
||||
from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError
|
||||
|
||||
from talemate.client.base import ClientBase
|
||||
from talemate.client.base import ClientBase, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.emit import emit
|
||||
from talemate.config import Client as BaseClientConfig
|
||||
|
||||
log = structlog.get_logger("talemate.client.openai_compat")
|
||||
|
||||
@@ -16,12 +18,18 @@ class Defaults(pydantic.BaseModel):
|
||||
api_key: str = ""
|
||||
max_token_length: int = 4096
|
||||
model: str = ""
|
||||
api_handles_prompt_template: bool = False
|
||||
|
||||
|
||||
class ClientConfig(BaseClientConfig):
|
||||
api_handles_prompt_template: bool = False
|
||||
|
||||
|
||||
@register()
|
||||
class OpenAICompatibleClient(ClientBase):
|
||||
client_type = "openai_compat"
|
||||
conversation_retries = 5
|
||||
config_cls = ClientConfig
|
||||
|
||||
class Meta(ClientBase.Meta):
|
||||
title: str = "OpenAI Compatible API"
|
||||
@@ -30,10 +38,22 @@ class OpenAICompatibleClient(ClientBase):
|
||||
enable_api_auth: bool = True
|
||||
manual_model: bool = True
|
||||
defaults: Defaults = Defaults()
|
||||
extra_fields: dict[str, ExtraField] = {
|
||||
"api_handles_prompt_template": ExtraField(
|
||||
name="api_handles_prompt_template",
|
||||
type="bool",
|
||||
label="API Handles Prompt Template",
|
||||
required=False,
|
||||
description="The API handles the prompt template, meaning your choice in the UI for the prompt template below will be ignored.",
|
||||
)
|
||||
}
|
||||
|
||||
def __init__(self, model=None, api_key=None, **kwargs):
|
||||
def __init__(
|
||||
self, model=None, api_key=None, api_handles_prompt_template=False, **kwargs
|
||||
):
|
||||
self.model_name = model
|
||||
self.api_key = api_key
|
||||
self.api_handles_prompt_template = api_handles_prompt_template
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
@@ -42,11 +62,10 @@ class OpenAICompatibleClient(ClientBase):
|
||||
|
||||
def set_client(self, **kwargs):
|
||||
self.api_key = kwargs.get("api_key", self.api_key)
|
||||
|
||||
self.api_handles_prompt_template = kwargs.get(
|
||||
"api_handles_prompt_template", self.api_handles_prompt_template
|
||||
)
|
||||
url = self.api_url
|
||||
if not url.endswith("/v1"):
|
||||
url = url + "/v1"
|
||||
|
||||
self.client = AsyncOpenAI(base_url=url, api_key=self.api_key)
|
||||
self.model_name = (
|
||||
kwargs.get("model") or kwargs.get("model_name") or self.model_name
|
||||
@@ -63,26 +82,27 @@ class OpenAICompatibleClient(ClientBase):
|
||||
if key not in valid_keys:
|
||||
del parameters[key]
|
||||
|
||||
def prompt_template(self, system_message: str, prompt: str):
|
||||
|
||||
log.debug(
|
||||
"IS API HANDLING PROMPT TEMPLATE",
|
||||
api_handles_prompt_template=self.api_handles_prompt_template,
|
||||
)
|
||||
|
||||
if not self.api_handles_prompt_template:
|
||||
return super().prompt_template(system_message, prompt)
|
||||
|
||||
if "<|BOT|>" in prompt:
|
||||
_, right = prompt.split("<|BOT|>", 1)
|
||||
if right:
|
||||
prompt = prompt.replace("<|BOT|>", "\nStart your response with: ")
|
||||
else:
|
||||
prompt = prompt.replace("<|BOT|>", "")
|
||||
|
||||
return prompt
|
||||
|
||||
async def get_model_name(self):
|
||||
try:
|
||||
model_name = await super().get_model_name()
|
||||
except NotFoundError as e:
|
||||
# api does not implement model listing
|
||||
return self.model_name
|
||||
except Exception as e:
|
||||
self.log.error("get_model_name error", e=e)
|
||||
return self.model_name
|
||||
|
||||
# model name may be a file path, so we need to extract the model name
|
||||
# the path could be windows or linux so it needs to handle both backslash and forward slash
|
||||
|
||||
is_filepath = "/" in model_name
|
||||
is_filepath_windows = "\\" in model_name
|
||||
|
||||
if is_filepath or is_filepath_windows:
|
||||
model_name = model_name.replace("\\", "/").split("/")[-1]
|
||||
|
||||
return model_name
|
||||
return self.model_name
|
||||
|
||||
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
@@ -120,6 +140,8 @@ class OpenAICompatibleClient(ClientBase):
|
||||
)
|
||||
if "api_key" in kwargs:
|
||||
self.api_auth = kwargs["api_key"]
|
||||
if "api_handles_prompt_template" in kwargs:
|
||||
self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]
|
||||
|
||||
log.warning("reconfigure", kwargs=kwargs)
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import datetime
|
||||
import os
|
||||
from typing import TYPE_CHECKING, ClassVar, Dict, Optional, TypeVar, Union
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, ClassVar, Dict, Optional, TypeVar, Union, Any
|
||||
from typing_extensions import Annotated
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
@@ -81,6 +83,7 @@ class GamePlayerCharacter(BaseModel):
|
||||
class General(BaseModel):
|
||||
auto_save: bool = True
|
||||
auto_progress: bool = True
|
||||
max_backscroll: int = 512
|
||||
|
||||
|
||||
class StateReinforcementTemplate(BaseModel):
|
||||
@@ -129,6 +132,14 @@ class OpenAIConfig(BaseModel):
|
||||
api_key: Union[str, None] = None
|
||||
|
||||
|
||||
class MistralAIConfig(BaseModel):
|
||||
api_key: Union[str, None] = None
|
||||
|
||||
|
||||
class AnthropicConfig(BaseModel):
|
||||
api_key: Union[str, None] = None
|
||||
|
||||
|
||||
class RunPodConfig(BaseModel):
|
||||
api_key: Union[str, None] = None
|
||||
|
||||
@@ -261,8 +272,43 @@ class RecentScenes(BaseModel):
|
||||
self.scenes = [s for s in self.scenes if os.path.exists(s.path)]
|
||||
|
||||
|
||||
def validate_client_type(
|
||||
v: Any,
|
||||
handler: pydantic.ValidatorFunctionWrapHandler,
|
||||
info: pydantic.ValidationInfo,
|
||||
):
|
||||
# clients can specify a custom config model in
|
||||
# client_cls.config_cls so we need to convert the
|
||||
# client config to the correct model
|
||||
|
||||
# v is dict
|
||||
if isinstance(v, dict):
|
||||
client_cls = get_client_class(v.get("type"))
|
||||
if client_cls:
|
||||
config_cls = getattr(client_cls, "config_cls", None)
|
||||
if config_cls:
|
||||
return config_cls(**v)
|
||||
else:
|
||||
return handler(v)
|
||||
# v is Client instance
|
||||
elif isinstance(v, Client):
|
||||
client_cls = get_client_class(v.type)
|
||||
if client_cls:
|
||||
config_cls = getattr(client_cls, "config_cls", None)
|
||||
if config_cls:
|
||||
return config_cls(**v.model_dump())
|
||||
else:
|
||||
return handler(v)
|
||||
|
||||
|
||||
AnnotatedClient = Annotated[
|
||||
ClientType,
|
||||
pydantic.WrapValidator(validate_client_type),
|
||||
]
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
clients: Dict[str, ClientType] = {}
|
||||
clients: Dict[str, AnnotatedClient] = {}
|
||||
|
||||
game: Game
|
||||
|
||||
@@ -272,6 +318,10 @@ class Config(BaseModel):
|
||||
|
||||
openai: OpenAIConfig = OpenAIConfig()
|
||||
|
||||
mistralai: MistralAIConfig = MistralAIConfig()
|
||||
|
||||
anthropic: AnthropicConfig = AnthropicConfig()
|
||||
|
||||
runpod: RunPodConfig = RunPodConfig()
|
||||
|
||||
chromadb: ChromaDB = ChromaDB()
|
||||
@@ -301,19 +351,6 @@ class SceneAssetUpload(BaseModel):
|
||||
content: str = None
|
||||
|
||||
|
||||
def prepare_client_config(clients: dict) -> dict:
|
||||
# client's can specify a custom config model in
|
||||
# client_cls.config_cls so we need to convert the
|
||||
# client config to the correct model
|
||||
|
||||
for client_name, client_config in clients.items():
|
||||
client_cls = get_client_class(client_config.get("type"))
|
||||
if client_cls:
|
||||
config_cls = getattr(client_cls, "config_cls", None)
|
||||
if config_cls:
|
||||
clients[client_name] = config_cls(**client_config)
|
||||
|
||||
|
||||
def load_config(
|
||||
file_path: str = "./config.yaml", as_model: bool = False
|
||||
) -> Union[dict, Config]:
|
||||
@@ -323,12 +360,10 @@ def load_config(
|
||||
Should cache the config and only reload if the file modification time
|
||||
has changed since the last load
|
||||
"""
|
||||
|
||||
with open(file_path, "r") as file:
|
||||
config_data = yaml.safe_load(file)
|
||||
|
||||
try:
|
||||
prepare_client_config(config_data.get("clients", {}))
|
||||
config = Config(**config_data)
|
||||
config.recent_scenes.clean()
|
||||
except pydantic.ValidationError as e:
|
||||
@@ -354,7 +389,6 @@ def save_config(config, file_path: str = "./config.yaml"):
|
||||
elif isinstance(config, dict):
|
||||
# validate
|
||||
try:
|
||||
prepare_client_config(config.get("clients", {}))
|
||||
config = Config(**config).model_dump(exclude_none=True)
|
||||
except pydantic.ValidationError as e:
|
||||
log.error("config validation", error=e)
|
||||
|
||||
@@ -37,9 +37,20 @@ Based on {{ talking_character.name}}'s example dialogue style, create a continua
|
||||
|
||||
You may chose to have {{ talking_character.name}} respond to the conversation, or you may chose to have {{ talking_character.name}} perform a new action that is in line with {{ talking_character.name}}'s character.
|
||||
|
||||
{% if scene.conversation_format == "movie_script" -%}
|
||||
The format is a movie script, so you should write the character's name in all caps followed by a line break and then the character's dialogue. For example:
|
||||
|
||||
CHARACTER NAME
|
||||
I'm so glad you're here.
|
||||
|
||||
Emotions and actions should be written in italics. For example:
|
||||
|
||||
CHARACTER NAME
|
||||
*smiles* I'm so glad you're here.
|
||||
{% else -%}
|
||||
Always contain actions in asterisks. For example, *{{ talking_character.name}} smiles*.
|
||||
Always contain dialogue in quotation marks. For example, {{ talking_character.name}}: "Hello!"
|
||||
|
||||
{% endif -%}
|
||||
{{ extra_instructions }}
|
||||
|
||||
{% if scene.count_character_messages(talking_character) >= 5 %}Use an informal and colloquial register with a conversational tone. Overall, {{ talking_character.name }}'s dialog is Informal, conversational, natural, and spontaneous, with a sense of immediacy.
|
||||
@@ -93,7 +104,7 @@ Always contain dialogue in quotation marks. For example, {{ talking_character.na
|
||||
{% endfor %}
|
||||
{% endblock -%}
|
||||
<|CLOSE_SECTION|>
|
||||
{% if scene.count_character_messages(talking_character) < 5 %}Use an informal and colloquial register with a conversational tone. Overall, {{ talking_character.name }}'s dialog is Informal, conversational, natural, and spontaneous, with a sense of immediacy. Flesh out additional details by describing {{ talking_character.name }}'s actions and mannerisms within asterisks, e.g. *{{ talking_character.name }} smiles*.
|
||||
{% if scene.count_character_messages(talking_character) < 5 %}(Use an informal and colloquial register with a conversational tone. Overall, {{ talking_character.name }}'s dialog is Informal, conversational, natural, and spontaneous, with a sense of immediacy.)
|
||||
{% endif -%}
|
||||
{% if rerun_context and rerun_context.direction -%}
|
||||
{% if rerun_context.method == 'replace' -%}
|
||||
@@ -104,4 +115,10 @@ Always contain dialogue in quotation marks. For example, {{ talking_character.na
|
||||
# Requested changes: {{ rerun_context.direction }}
|
||||
{% endif -%}
|
||||
{% endif -%}
|
||||
{{ bot_token}}{{ talking_character.name }}:{{ partial_message }}
|
||||
{% if scene.conversation_format == 'movie_script' -%}
|
||||
{{ bot_token }}{{ talking_character.name.upper() }}{% if partial_message %}
|
||||
{{ partial_message }}
|
||||
{% endif %}
|
||||
{% else -%}
|
||||
{{ bot_token }}{{ talking_character.name }}:{{ partial_message }}
|
||||
{% endif -%}
|
||||
@@ -24,5 +24,6 @@ You must provide your answer as a comma delimited list of keywords.
|
||||
Keywords should be ordered: physical appearance, emotion, action, environment, color scheme.
|
||||
You must provide many keywords to describe the character and the environment in great detail.
|
||||
Your answer must be suitable as a stable-diffusion image generation prompt.
|
||||
You must avoid negating of keywords, omit things entirely that aren't there. For example instead of saying "no scars", just dont include the keyword scars at all.
|
||||
<|CLOSE_SECTION|>
|
||||
{{ set_prepared_response(character.name+",")}}
|
||||
@@ -84,6 +84,9 @@ class SceneMessage:
|
||||
def unhide(self):
|
||||
self.hidden = False
|
||||
|
||||
def as_format(self, format: str) -> str:
|
||||
return self.message
|
||||
|
||||
|
||||
@dataclass
|
||||
class CharacterMessage(SceneMessage):
|
||||
@@ -105,6 +108,25 @@ class CharacterMessage(SceneMessage):
|
||||
def raw(self):
|
||||
return self.message.split(":", 1)[1].replace('"', "").replace("*", "").strip()
|
||||
|
||||
@property
|
||||
def as_movie_script(self):
|
||||
"""
|
||||
Returns the dialogue line as a script dialogue line.
|
||||
|
||||
Example:
|
||||
{CHARACTER_NAME}
|
||||
{dialogue}
|
||||
"""
|
||||
|
||||
message = self.message.split(":", 1)[1].replace('"', "").strip()
|
||||
|
||||
return f"\n{self.character_name.upper()}\n{message}\n"
|
||||
|
||||
def as_format(self, format: str) -> str:
|
||||
if format == "movie_script":
|
||||
return self.as_movie_script
|
||||
return self.message
|
||||
|
||||
|
||||
@dataclass
|
||||
class NarratorMessage(SceneMessage):
|
||||
@@ -127,6 +149,12 @@ class DirectorMessage(SceneMessage):
|
||||
|
||||
return f"# Story progression instructions for {char_name}: {message}"
|
||||
|
||||
def as_format(self, format: str) -> str:
|
||||
if format == "movie_script":
|
||||
message = str(self)[2:]
|
||||
return f"\n({message})\n"
|
||||
return self.message
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimePassageMessage(SceneMessage):
|
||||
@@ -152,6 +180,12 @@ class ReinforcementMessage(SceneMessage):
|
||||
question, _ = self.source.split(":", 1)
|
||||
return f"# Internal notes: {question}: {self.message}"
|
||||
|
||||
def as_format(self, format: str) -> str:
|
||||
if format == "movie_script":
|
||||
message = str(self)[2:]
|
||||
return f"\n({message})\n"
|
||||
return self.message
|
||||
|
||||
|
||||
MESSAGES = {
|
||||
"scene": SceneMessage,
|
||||
|
||||
@@ -219,6 +219,9 @@ class WebsocketHandler(Receiver):
|
||||
client.pop("status", None)
|
||||
client_cls = CLIENT_CLASSES.get(client["type"])
|
||||
|
||||
if client.get("model") == "No API key set":
|
||||
client.pop("model", None)
|
||||
|
||||
if not client_cls:
|
||||
log.error("Client type not found", client=client)
|
||||
continue
|
||||
@@ -301,7 +304,13 @@ class WebsocketHandler(Receiver):
|
||||
}
|
||||
|
||||
agent_instance = instance.get_agent(name, **self.agents[name])
|
||||
agent_instance.client = self.llm_clients[client_name]["client"]
|
||||
|
||||
try:
|
||||
agent_instance.client = self.llm_clients[client_name]["client"]
|
||||
except KeyError:
|
||||
self.llm_clients[client_name]["client"] = agent_instance.client = (
|
||||
instance.get_client(client_name)
|
||||
)
|
||||
|
||||
if agent_instance.has_toggle:
|
||||
self.agents[name]["enabled"] = agent["enabled"]
|
||||
@@ -618,9 +627,7 @@ class WebsocketHandler(Receiver):
|
||||
)
|
||||
|
||||
def request_scene_history(self):
|
||||
history = [
|
||||
archived_history["text"] for archived_history in self.scene.archived_history
|
||||
]
|
||||
history = [archived_history for archived_history in self.scene.archived_history]
|
||||
|
||||
self.queue_put(
|
||||
{
|
||||
|
||||
@@ -883,6 +883,10 @@ class Scene(Emitter):
|
||||
def world_state_manager(self):
|
||||
return WorldStateManager(self)
|
||||
|
||||
@property
|
||||
def conversation_format(self):
|
||||
return self.get_helper("conversation").agent.conversation_format
|
||||
|
||||
def set_description(self, description: str):
|
||||
self.description = description
|
||||
|
||||
@@ -1111,8 +1115,7 @@ class Scene(Emitter):
|
||||
"archived_history",
|
||||
data={
|
||||
"history": [
|
||||
archived_history["text"]
|
||||
for archived_history in self.archived_history
|
||||
archived_history for archived_history in self.archived_history
|
||||
]
|
||||
},
|
||||
)
|
||||
@@ -1337,6 +1340,8 @@ class Scene(Emitter):
|
||||
budget_context = int(0.5 * budget)
|
||||
budget_dialogue = int(0.5 * budget)
|
||||
|
||||
conversation_format = self.conversation_format
|
||||
|
||||
# collect dialogue
|
||||
|
||||
count = 0
|
||||
@@ -1358,7 +1363,7 @@ class Scene(Emitter):
|
||||
if count_tokens(parts_dialogue) + count_tokens(message) > budget_dialogue:
|
||||
break
|
||||
|
||||
parts_dialogue.insert(0, message)
|
||||
parts_dialogue.insert(0, message.as_format(conversation_format))
|
||||
|
||||
# collect context, ignore where end > len(history) - count
|
||||
|
||||
@@ -1767,10 +1772,14 @@ class Scene(Emitter):
|
||||
continue_scene = True
|
||||
self.commands = command = commands.Manager(self)
|
||||
|
||||
max_backscroll = (
|
||||
self.config.get("game", {}).get("general", {}).get("max_backscroll", 512)
|
||||
)
|
||||
|
||||
if init and self.history:
|
||||
# history is not empty, so we are continuing a scene
|
||||
# need to emit current messages
|
||||
for item in self.history:
|
||||
for item in self.history[-max_backscroll:]:
|
||||
char_name = item.split(":")[0]
|
||||
try:
|
||||
actor = self.get_character(char_name).actor
|
||||
|
||||
@@ -356,13 +356,13 @@ def clean_paragraph(paragraph: str) -> str:
|
||||
|
||||
def clean_message(message: str) -> str:
|
||||
message = message.strip()
|
||||
message = re.sub(r"\s+", " ", message)
|
||||
message = re.sub(r" +", " ", message)
|
||||
message = message.replace("(", "*").replace(")", "*")
|
||||
message = message.replace("[", "*").replace("]", "*")
|
||||
return message
|
||||
|
||||
|
||||
def clean_dialogue(dialogue: str, main_name: str) -> str:
|
||||
def clean_dialogue_old(dialogue: str, main_name: str) -> str:
|
||||
# re split by \n{not main_name}: with a max count of 1
|
||||
pattern = r"\n(?!{}:).*".format(re.escape(main_name))
|
||||
|
||||
@@ -374,6 +374,36 @@ def clean_dialogue(dialogue: str, main_name: str) -> str:
|
||||
return clean_message(strip_partial_sentences(dialogue))
|
||||
|
||||
|
||||
def clean_dialogue(dialogue: str, main_name: str) -> str:
|
||||
|
||||
cleaned = []
|
||||
|
||||
if not dialogue.startswith(main_name):
|
||||
dialogue = f"{main_name}: {dialogue}"
|
||||
|
||||
for line in dialogue.split("\n"):
|
||||
|
||||
if not cleaned:
|
||||
cleaned.append(line)
|
||||
continue
|
||||
|
||||
if line.startswith(f"{main_name}: "):
|
||||
cleaned.append(line[len(main_name) + 2 :])
|
||||
continue
|
||||
|
||||
# if line is all capitalized
|
||||
# this is likely a new speaker in movie script format, and we
|
||||
# bail
|
||||
if line.strip().isupper():
|
||||
break
|
||||
|
||||
if ":" not in line:
|
||||
cleaned.append(line)
|
||||
continue
|
||||
|
||||
return clean_message(strip_partial_sentences("\n".join(cleaned)))
|
||||
|
||||
|
||||
def clean_id(name: str) -> str:
|
||||
"""
|
||||
Cleans up a id name by removing all characters that aren't a-zA-Z0-9_-
|
||||
|
||||
4
talemate_frontend/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "talemate_frontend",
|
||||
"version": "0.19.0",
|
||||
"version": "0.21.0",
|
||||
"lockfileVersion": 2,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "talemate_frontend",
|
||||
"version": "0.19.0",
|
||||
"version": "0.21.0",
|
||||
"dependencies": {
|
||||
"@mdi/font": "7.4.47",
|
||||
"core-js": "^3.8.3",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "talemate_frontend",
|
||||
"version": "0.20.0",
|
||||
"version": "0.21.0",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"serve": "vue-cli-service serve",
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
Creator
|
||||
</v-tab>
|
||||
</v-tabs>
|
||||
<v-divider></v-divider>
|
||||
<v-window v-model="tab">
|
||||
|
||||
<!-- GAME -->
|
||||
@@ -25,11 +26,12 @@
|
||||
<v-card-text>
|
||||
<v-row>
|
||||
<v-col cols="4">
|
||||
<v-list>
|
||||
<v-list-item @click="gamePageSelected=item.value" :prepend-icon="item.icon" v-for="(item, index) in navigation.game" :key="index">
|
||||
<v-list-item-title>{{ item.title }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
<v-tabs v-model="gamePageSelected" color="primary" direction="vertical">
|
||||
<v-tab v-for="(item, index) in navigation.game" :key="index" :value="item.value">
|
||||
<v-icon class="mr-1">{{ item.icon }}</v-icon>
|
||||
{{ item.title }}
|
||||
</v-tab>
|
||||
</v-tabs>
|
||||
</v-col>
|
||||
<v-col cols="8">
|
||||
<div v-if="gamePageSelected === 'general'">
|
||||
@@ -45,6 +47,11 @@
|
||||
<v-checkbox v-model="app_config.game.general.auto_save" label="Auto save" messages="Automatically save after each game-loop"></v-checkbox>
|
||||
<v-checkbox v-model="app_config.game.general.auto_progress" label="Auto progress" messages="AI automatically progresses after player turn."></v-checkbox>
|
||||
</v-col>
|
||||
</v-row>
|
||||
<v-row>
|
||||
<v-col cols="6">
|
||||
<v-text-field v-model="app_config.game.general.max_backscroll" type="number" label="Max backscroll" messages="Maximum number of messages to keep in the scene backscroll"></v-text-field>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</div>
|
||||
<div v-else-if="gamePageSelected === 'character'">
|
||||
@@ -88,9 +95,13 @@
|
||||
<v-col cols="4">
|
||||
<v-list>
|
||||
<v-list-subheader>Third Party APIs</v-list-subheader>
|
||||
<v-list-item @click="applicationPageSelected=item.value" :prepend-icon="item.icon" v-for="(item, index) in navigation.application" :key="index">
|
||||
<v-list-item-title>{{ item.title }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<v-tabs v-model="applicationPageSelected" color="primary" direction="vertical" density="compact">
|
||||
<v-tab v-for="(item, index) in navigation.application" :key="index" :value="item.value">
|
||||
<v-icon class="mr-1">{{ item.icon }}</v-icon>
|
||||
{{ item.title }}
|
||||
</v-tab>
|
||||
</v-tabs>
|
||||
</v-list>
|
||||
</v-col>
|
||||
<v-col cols="8">
|
||||
@@ -112,6 +123,40 @@
|
||||
</v-row>
|
||||
</div>
|
||||
|
||||
<!-- MISTRAL.AI API -->
|
||||
<div v-if="applicationPageSelected === 'mistralai_api'">
|
||||
<v-alert color="white" variant="text" icon="mdi-api" density="compact">
|
||||
<v-alert-title>mistral.ai</v-alert-title>
|
||||
<div class="text-grey">
|
||||
Configure your mistral.ai API key here. You can get one from <a href="https://console.mistral.ai/api-keys/" target="_blank">https://console.mistral.ai/api-keys/</a>
|
||||
</div>
|
||||
</v-alert>
|
||||
<v-divider class="mb-2"></v-divider>
|
||||
<v-row>
|
||||
<v-col cols="12">
|
||||
<v-text-field type="password" v-model="app_config.mistralai.api_key"
|
||||
label="mistral.ai API Key"></v-text-field>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</div>
|
||||
|
||||
<!-- ANTHROPIC API -->
|
||||
<div v-if="applicationPageSelected === 'anthropic_api'">
|
||||
<v-alert color="white" variant="text" icon="mdi-api" density="compact">
|
||||
<v-alert-title>Anthropic</v-alert-title>
|
||||
<div class="text-grey">
|
||||
Configure your Anthropic API key here. You can get one from <a href="https://console.anthropic.com/settings/keys" target="_blank">https://console.anthropic.com/settings/keys</a>
|
||||
</div>
|
||||
</v-alert>
|
||||
<v-divider class="mb-2"></v-divider>
|
||||
<v-row>
|
||||
<v-col cols="12">
|
||||
<v-text-field type="password" v-model="app_config.anthropic.api_key"
|
||||
label="Anthropic API Key"></v-text-field>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</div>
|
||||
|
||||
<!-- ELEVENLABS API -->
|
||||
<div v-if="applicationPageSelected === 'elevenlabs_api'">
|
||||
<v-alert color="white" variant="text" icon="mdi-api" density="compact">
|
||||
@@ -130,23 +175,6 @@
|
||||
</v-row>
|
||||
</div>
|
||||
|
||||
<!-- COQUI API -->
|
||||
<div v-if="applicationPageSelected === 'coqui_api'">
|
||||
<v-alert color="white" variant="text" icon="mdi-api" density="compact">
|
||||
<v-alert-title>Coqui Studio</v-alert-title>
|
||||
<div class="text-grey">
|
||||
<p class="mb-1">Realistic, emotive text-to-speech through generative AI.</p>
|
||||
Configure your Coqui API key here. You can get one from <a href="https://app.coqui.ai/account" target="_blank">https://app.coqui.ai/account</a>
|
||||
</div>
|
||||
</v-alert>
|
||||
<v-divider class="mb-2"></v-divider>
|
||||
<v-row>
|
||||
<v-col cols="12">
|
||||
<v-text-field type="password" v-model="app_config.coqui.api_key"
|
||||
label="Coqui API Key"></v-text-field>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</div>
|
||||
|
||||
<!-- RUNPOD API -->
|
||||
<div v-if="applicationPageSelected === 'runpod_api'">
|
||||
@@ -179,11 +207,12 @@
|
||||
<v-card-text>
|
||||
<v-row>
|
||||
<v-col cols="4">
|
||||
<v-list>
|
||||
<v-list-item @click="creatorPageSelected=item.value" :prepend-icon="item.icon" v-for="(item, index) in navigation.creator" :key="index">
|
||||
<v-list-item-title>{{ item.title }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
<v-tabs v-model="creatorPageSelected" color="primary" direction="vertical">
|
||||
<v-tab v-for="(item, index) in navigation.creator" :key="index" :value="item.value">
|
||||
<v-icon class="mr-1">{{ item.icon }}</v-icon>
|
||||
{{ item.title }}
|
||||
</v-tab>
|
||||
</v-tabs>
|
||||
</v-col>
|
||||
<v-col cols="8">
|
||||
<div v-if="creatorPageSelected === 'content_context'">
|
||||
@@ -248,8 +277,9 @@ export default {
|
||||
],
|
||||
application: [
|
||||
{title: 'OpenAI', icon: 'mdi-api', value: 'openai_api'},
|
||||
{title: 'mistral.ai', icon: 'mdi-api', value: 'mistralai_api'},
|
||||
{title: 'Anthropic', icon: 'mdi-api', value: 'anthropic_api'},
|
||||
{title: 'ElevenLabs', icon: 'mdi-api', value: 'elevenlabs_api'},
|
||||
{title: 'Coqui Studio', icon: 'mdi-api', value: 'coqui_api'},
|
||||
{title: 'RunPod', icon: 'mdi-api', value: 'runpod_api'},
|
||||
],
|
||||
creator: [
|
||||
|
||||
@@ -38,14 +38,15 @@
|
||||
<v-row v-for="field in clientMeta().extra_fields" :key="field.name">
|
||||
<v-col cols="12">
|
||||
<v-text-field v-model="client.data[field.name]" v-if="field.type==='text'" :label="field.label" :rules="[rules.required]" :hint="field.description"></v-text-field>
|
||||
<v-checkbox v-else-if="field.type === 'bool'" v-model="client.data[field.name]" :label="field.label" :hint="field.description" density="compact"></v-checkbox>
|
||||
</v-col>
|
||||
</v-row>
|
||||
<v-row>
|
||||
<v-col cols="4">
|
||||
<v-text-field v-model="client.max_token_length" v-if="requiresAPIUrl(client)" type="number" label="Context Length" :rules="[rules.required]"></v-text-field>
|
||||
</v-col>
|
||||
<v-col cols="8" v-if="!typeEditable() && client.data && client.data.prompt_template_example !== null && client.model_name && clientMeta().requires_prompt_template">
|
||||
<v-combobox ref="promptTemplateComboBox" label="Prompt Template" v-model="client.data.template_file" @update:model-value="setPromptTemplate" :items="promptTemplates"></v-combobox>
|
||||
<v-col cols="8" v-if="!typeEditable() && client.data && client.data.prompt_template_example !== null && client.model_name && clientMeta().requires_prompt_template && !client.data.api_handles_prompt_template">
|
||||
<v-combobox ref="promptTemplateComboBox" :label="'Prompt Template for '+client.model_name" v-model="client.data.template_file" @update:model-value="setPromptTemplate" :items="promptTemplates"></v-combobox>
|
||||
<v-card elevation="3" :color="(client.data.has_prompt_template ? 'primary' : 'warning')" variant="tonal">
|
||||
|
||||
<v-card-text>
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
History
|
||||
</v-card-title>
|
||||
<v-card-text style="max-height:600px; overflow-y:scroll;">
|
||||
<v-list-item v-for="(text, index) in history" :key="index" class="text-body-2">
|
||||
{{ text }}
|
||||
<v-list-item v-for="(entry, index) in history" :key="index" class="text-body-2">
|
||||
{{ entry.ts }} {{ entry.text }}
|
||||
<v-divider class="mt-1"></v-divider>
|
||||
</v-list-item>
|
||||
</v-card-text>
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"text": "a puppy",
|
||||
"text": "",
|
||||
"clip": [
|
||||
"1",
|
||||
1
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"text": "a puppy",
|
||||
"text": "",
|
||||
"clip": [
|
||||
"1",
|
||||
1
|
||||
|
||||
@@ -26,11 +26,14 @@ def test_dialogue_cleanup(input, expected):
|
||||
|
||||
@pytest.mark.parametrize("input, expected, main_name", [
|
||||
("bob: says a sentence", "bob: says a sentence", "bob"),
|
||||
("bob: says a sentence\nbob: says another sentence", "bob: says a sentence says another sentence", "bob"),
|
||||
("bob: says a sentence\nbob: says another sentence", "bob: says a sentence\nsays another sentence", "bob"),
|
||||
("bob: says a sentence with a colon: to explain something", "bob: says a sentence with a colon: to explain something", "bob"),
|
||||
("bob: i have a riddle for you, alice: the riddle", "bob: i have a riddle for you, alice: the riddle", "bob"),
|
||||
("bob: says something\nalice: says something else", "bob: says something", "bob"),
|
||||
("bob: says a sentence. then a", "bob: says a sentence.", "bob"),
|
||||
("bob: first paragraph\n\nsecond paragraph", "bob: first paragraph\n\nsecond paragraph", "bob"),
|
||||
# movie script new speaker cutoff
|
||||
("bob: says a sentence\n\nALICE\nsays something else", "bob: says a sentence", "bob"),
|
||||
])
|
||||
def test_clean_dialogue(input, expected, main_name):
|
||||
others = ["alice", "charlie"]
|
||||
|
||||