Compare commits


10 Commits

Author SHA1 Message Date
fiwo
496eb469db Prep 0.14.0 (#34)
* tts agent first progress

* coqui support
voice lists

* orca-2

* tts tweaks

* switch to ux for audio gen

* some tweaks for the new audio queue

* fix error handling if llm fails to create a good world state on initial scene load

* loading creative mode for a new scene will now ask for confirmation if the current scene has unsaved progress

* local tts support

* fix voice list reloading when switching tts api
fix agent config ux to auto save on change, remove save / close buttons

* only do a delayed save on agent config on text input changes

* OrionStar

* don't allow scene loading when llm agents aren't correctly configured

* wire summarization to game loop, summarizer agent configs

* fix issues with time passage

* editor fix narrator messages

* 0.14.0

* poetry lock

* requires_llm_client moved to cls property

* add additional config stubs

* tts still loads voices even if the agent is disabled

* fix bug that would keep losing voice selection for tts agent after backend restart

* update tts install requirements

* remove debug output
2023-11-24 22:08:13 +02:00
FInalWombat
b78fec3bac Update README.md 2023-11-20 00:13:08 +02:00
FInalWombat
d250df8950 Prep 0.13.2 (#33)
* fix issue with client removal

* client type not editable after creation (keeps things simple)

* fixes issue with openai client bugging out (api_url not set)

* fix issues with edit client not reflecting changes to UX

* 0.13.2
2023-11-19 20:43:15 +02:00
FInalWombat
816f950afe Prep 0.13.1 (#29)
* narrate after dialog constrained a bit more so it doesn't create something unrelated

* fix issue where textgenwebui client would come back as status ok even though no model was loaded

* 0.13.1
2023-11-19 18:58:40 +02:00
FInalWombat
8fb72fdbe9 Update README.md 2023-11-19 14:05:09 +02:00
FInalWombat
54297a4768 Update README.md 2023-11-18 12:20:02 +02:00
FInalWombat
d7e72d27c5 Prep 0.13.0 (#28)
* requirements.txt file

* windows installs from requirements.txt because of silly permission issues

* relock

* narrator - narrate on dialogue agent actions

* add support for new textgenwebui api

* world state auto regen trigger off of gameloop

* function !rename command

* ensure_dialog_format error handling

* Cat, Nous-Capybara, dolphin-2.2.1

* narrate after dialog rerun fixes, template fixes

* LMStudio client (experimental)

* dolphin yi

* refactor client base

* cruft

* openai client to new base

* more client refactor fixes

* tweak context retrieval prompts

* adjust nous capybara template

* add Tess-Medium

* 0.13.0

* switch back to poetry for windows as well

* error on legacy textgenwebui api

* runpod text gen api url fixed

* fix windows install script

* add fllow instruction template

* Psyfighter2
2023-11-18 12:16:29 +02:00
FInalWombat
f9b23f8705 Update README.md 2023-11-14 11:06:52 +02:00
FInalWombat
37a5873330 Update README.md 2023-11-12 15:43:02 +02:00
FInalWombat
bc3f5d63c8 Add files via upload 2023-11-12 15:42:07 +02:00
64 changed files with 3531 additions and 1578 deletions

README.md

@@ -4,22 +4,26 @@ Allows you to play roleplay scenarios with large language models.
 It does not run any large language models itself but relies on existing APIs. Currently supports **text-generation-webui** and **openai**.
-This means you need to either have an openai api key or know how to setup [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (locally or remotely via gpu renting. `--api` flag needs to be set)
+This means you need to either have an openai api key or know how to setup [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (locally or remotely via gpu renting. `--extension openai` flag needs to be set)
-![Screenshot 1](docs/img/Screenshot_8.png)
+As of version 0.13.0 the legacy text-generator-webui API `--extension api` is no longer supported, please use their new `--extension openai` api implementation instead.
+![Screenshot 1](docs/img/Screenshot_9.png)
 ![Screenshot 2](docs/img/Screenshot_2.png)
 ## Current features
 - responive modern ui
 - agents
-  - conversation
+  - conversation: handles character dialogue
-  - narration
+  - narration: handles narrative exposition
-  - summarization
+  - summarization: handles summarization to compress context while maintain history
-  - director
+  - director: can be used to direct the story / characters
-  - creative
+  - editor: improves AI responses (very hit and miss at the moment)
+  - world state: generates world snapshot and handles passage of time (objects and characters)
+  - creator: character / scenario creator
 - multi-client (agents can be connected to separate APIs)
-- long term memory (experimental)
+- long term memory
   - chromadb integration
 - passage of time
 - narrative world state
@@ -36,6 +40,7 @@ Kinda making it up as i go along, but i want to lean more into gameplay through
 In no particular order:
+- TTS support
 - Extension support
 - modular agents and clients
 - Improved world state
@@ -49,7 +54,7 @@ In no particular order:
   - objectives
   - quests
   - win / lose conditions
-- Automatic1111 client
+- Automatic1111 client for in place visual generation
 # Quickstart
@@ -113,6 +118,8 @@ https://www.reddit.com/r/LocalLLaMA/comments/17fhp9k/huge_llm_comparisontest_39_
 On the right hand side click the "Add Client" button. If there is no button, you may need to toggle the client options by clicking this button:
+As of version 0.13.0 the legacy text-generator-webui API `--extension api` is no longer supported, please use their new `--extension openai` api implementation instead.
 ![Client options](docs/img/client-options-toggle.png)
 ### Text-generation-webui
@@ -155,7 +162,10 @@ Make sure you save the scene after the character is loaded as it can then be loa
 ## Further documentation
+Please read the documents in the `docs` folder for more advanced configuration and usage.
 - Creative mode (docs WIP)
 - Prompt template overrides
+- [Text-to-Speech (TTS)](docs/tts.md)
 - [ChromaDB (long term memory)](docs/chromadb.md)
 - Runpod Integration

config.example.yaml

@@ -14,13 +14,32 @@ game:
     gender: male
     name: Elmer
+## Long-term memory
 #chromadb:
 # embeddings: instructor
 # instructor_device: cuda
 # instructor_model: hkunlp/instructor-xl
+## Remote LLMs
 #openai:
 # api_key: <API_KEY>
 #runpod:
 # api_key: <API_KEY>
+## TTS (Text-to-Speech)
+#elevenlabs:
+# api_key: <API_KEY>
+#coqui:
+# api_key: <API_KEY>
+#tts:
+# device: cuda
+# model: tts_models/multilingual/multi-dataset/xtts_v2
+# voices:
+# - label: <name>
+# value: <path to .wav for voice sample>

docs/img/Screenshot_9.png (new binary file, 551 KiB; not shown)

docs/tts.md (new file, 84 lines)

@@ -0,0 +1,84 @@
# Talemate Text-to-Speech (TTS) Configuration
Talemate supports Text-to-Speech (TTS) functionality, allowing users to convert text into spoken audio. This document outlines the steps required to configure TTS for Talemate using different providers, including ElevenLabs, Coqui, and a local TTS API.
## Configuring ElevenLabs TTS
To use ElevenLabs TTS with Talemate, follow these steps:
1. Visit [ElevenLabs](https://elevenlabs.com) and create an account if you don't already have one.
2. Click on your profile in the upper right corner of the ElevenLabs website to access your API key.
3. In the `config.yaml` file, under the `elevenlabs` section, set the `api_key` field with your ElevenLabs API key.
Example configuration snippet:
```yaml
elevenlabs:
api_key: <YOUR_ELEVENLABS_API_KEY>
```
## Configuring Coqui TTS
To use Coqui TTS with Talemate, follow these steps:
1. Visit [Coqui](https://app.coqui.ai) and sign up for an account.
2. Go to the [account page](https://app.coqui.ai/account) and scroll to the bottom to find your API key.
3. In the `config.yaml` file, under the `coqui` section, set the `api_key` field with your Coqui API key.
Example configuration snippet:
```yaml
coqui:
api_key: <YOUR_COQUI_API_KEY>
```
## Configuring Local TTS API
For running a local TTS API, Talemate requires specific dependencies to be installed.
### Windows Installation
Run `install-local-tts.bat` to install the necessary requirements.
### Linux Installation
Execute the following command:
```bash
pip install TTS
```
### Model and Device Configuration
1. Choose a TTS model from the [Coqui TTS model list](https://github.com/coqui-ai/TTS).
2. Decide whether to use `cuda` or `cpu` for the device setting.
3. The first time you run TTS through the local API, it will download the specified model. Please note that this may take some time, and the download progress will be visible in the Talemate backend output.
Example configuration snippet:
```yaml
tts:
device: cuda # or 'cpu'
model: tts_models/multilingual/multi-dataset/xtts_v2
```
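If you want to sanity-check the model and device settings outside of Talemate, the following is a minimal standalone sketch of what the local TTS path does under the hood (it mirrors `_generate_tts` in `src/talemate/agents/tts.py` from this change set; the sample text and `.wav` path are placeholders):
```python
# Minimal sketch of the local TTS path; assumes `pip install TTS` has been run.
from TTS.api import TTS

model = "tts_models/multilingual/multi-dataset/xtts_v2"
device = "cuda"  # or "cpu"

tts = TTS(model).to(device)  # the first run downloads the model
tts.tts_to_file(
    text="Hello from Talemate.",             # placeholder text
    speaker_wav="path/to/english_male.wav",  # a configured voice sample
    language="en",
    file_path="out.wav",
)
```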
### Voice Samples Configuration
Configure voice samples by setting the `value` field to the path of a .wav file voice sample. Official samples can be downloaded from [Coqui XTTS-v2 samples](https://huggingface.co/coqui/XTTS-v2/tree/main/samples).
Example configuration snippet:
```yaml
tts:
voices:
- label: English Male
value: path/to/english_male.wav
- label: English Female
value: path/to/english_female.wav
```
## Saving the Configuration
After configuring the `config.yaml` file, save your changes. Talemate will use the updated settings the next time it starts.
For more detailed information on configuring Talemate, refer to the `config.py` file in the Talemate source code and the `config.example.yaml` file for a barebones configuration example.

install-local-tts.bat (new file, 4 lines)

@@ -0,0 +1,4 @@
REM activate the virtual environment
call talemate_env\Scripts\activate
call pip install "TTS>=0.21.1"


@@ -7,10 +7,10 @@ REM activate the virtual environment
 call talemate_env\Scripts\activate
 REM install poetry
-python -m pip install poetry "rapidfuzz>=3" -U
+python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U
 REM use poetry to install dependencies
-poetry install
+python -m poetry install
 REM copy config.example.yaml to config.yaml only if config.yaml doesn't exist
 IF NOT EXIST config.yaml copy config.example.yaml config.yaml

poetry.lock (generated; 1982 lines changed, diff suppressed because it is too large)

pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"
 [tool.poetry]
 name = "talemate"
-version = "0.12.0"
+version = "0.14.0"
 description = "AI-backed roleplay and narrative tools"
 authors = ["FinalWombat"]
 license = "GNU Affero General Public License v3.0"
@@ -37,11 +37,12 @@ nest_asyncio = "^1.5.7"
 isodate = ">=0.6.1"
 thefuzz = ">=0.20.0"
 tiktoken = ">=0.5.1"
+nltk = ">=3.8.1"
 # ChromaDB
-chromadb = ">=0.4,<1"
+chromadb = ">=0.4.17,<1"
 InstructorEmbedding = "^1.0.1"
-torch = ">=2.0.0, !=2.0.1"
+torch = ">=2.1.0"
 sentence-transformers="^2.2.2"
 [tool.poetry.dev-dependencies]


@@ -9,7 +9,7 @@ REM activate the virtual environment
 call talemate_env\Scripts\activate
 REM install poetry
-python -m pip install poetry "rapidfuzz>=3" -U
+python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U
 REM use poetry to install dependencies
 python -m poetry install

src/talemate/__init__.py

@@ -2,4 +2,4 @@ from .agents import Agent
 from .client import TextGeneratorWebuiClient
 from .tale_mate import *
-VERSION = "0.12.0"
+VERSION = "0.14.0"

src/talemate/agents/__init__.py

@@ -9,3 +9,4 @@ from .registry import AGENT_CLASSES, get_agent_class, register
 from .summarize import SummarizeAgent
 from .editor import EditorAgent
 from .world_state import WorldStateAgent
+from .tts import TTSAgent

src/talemate/agents/base.py

@@ -23,16 +23,31 @@ __all__ = [
 log = structlog.get_logger("talemate.agents.base")
+class CallableConfigValue:
+    def __init__(self, fn):
+        self.fn = fn
+
+    def __str__(self):
+        return "CallableConfigValue"
+
+    def __repr__(self):
+        return "CallableConfigValue"
 class AgentActionConfig(pydantic.BaseModel):
     type: str
     label: str
     description: str = ""
-    value: Union[int, float, str, bool]
+    value: Union[int, float, str, bool, None]
     default_value: Union[int, float, str, bool] = None
     max: Union[int, float, None] = None
     min: Union[int, float, None] = None
     step: Union[int, float, None] = None
     scope: str = "global"
+    choices: Union[list[dict[str, str]], None] = None
+
+    class Config:
+        arbitrary_types_allowed = True
 class AgentAction(pydantic.BaseModel):
     enabled: bool = True
@@ -40,7 +55,6 @@ class AgentAction(pydantic.BaseModel):
     description: str = ""
     config: Union[dict[str, AgentActionConfig], None] = None
 def set_processing(fn):
     """
     decorator that emits the agent status as processing while the function
@@ -70,6 +84,7 @@ class Agent(ABC):
     agent_type = "agent"
     verbose_name = None
     set_processing = set_processing
+    requires_llm_client = True
     @property
     def agent_details(self):
@@ -89,7 +104,7 @@
         if not getattr(self.client, "enabled", True):
             return False
-        if self.client.current_status in ["error", "warning"]:
+        if self.client and self.client.current_status in ["error", "warning"]:
             return False
         return self.client is not None
@@ -135,6 +150,7 @@
             "enabled": agent.enabled if agent else True,
             "has_toggle": agent.has_toggle if agent else False,
             "experimental": agent.experimental if agent else False,
+            "requires_llm_client": cls.requires_llm_client,
         }
         actions = getattr(agent, "actions", None)

src/talemate/agents/conversation.py

@@ -406,7 +406,7 @@ class ConversationAgent(Agent):
         context = await memory.multi_query(history, max_tokens=500, iterate=5)
-        self.current_memory_context = "\n".join(context)
+        self.current_memory_context = "\n\n".join(context)
         return self.current_memory_context

src/talemate/agents/editor.py

@@ -10,7 +10,7 @@ import talemate.emit.async_signals
 from talemate.prompts import Prompt
 from talemate.scene_message import DirectorMessage, TimePassageMessage
-from .base import Agent, set_processing, AgentAction
+from .base import Agent, set_processing, AgentAction, AgentActionConfig
 from .registry import register
 import structlog
@@ -21,6 +21,7 @@ import re
 if TYPE_CHECKING:
     from talemate.tale_mate import Actor, Character, Scene
     from talemate.agents.conversation import ConversationAgentEmission
+    from talemate.agents.narrator import NarratorAgentEmission
 log = structlog.get_logger("talemate.agents.editor")
@@ -40,7 +41,9 @@ class EditorAgent(Agent):
         self.is_enabled = True
         self.actions = {
             "edit_dialogue": AgentAction(enabled=False, label="Edit dialogue", description="Will attempt to improve the quality of dialogue based on the character and scene. Runs automatically after each AI dialogue."),
-            "fix_exposition": AgentAction(enabled=True, label="Fix exposition", description="Will attempt to fix exposition and emotes, making sure they are displayed in italics. Runs automatically after each AI dialogue."),
+            "fix_exposition": AgentAction(enabled=True, label="Fix exposition", description="Will attempt to fix exposition and emotes, making sure they are displayed in italics. Runs automatically after each AI dialogue.", config={
+                "narrator": AgentActionConfig(type="bool", label="Fix narrator messages", description="Will attempt to fix exposition issues in narrator messages", value=True),
+            }),
             "add_detail": AgentAction(enabled=False, label="Add detail", description="Will attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.")
         }
@@ -59,6 +62,7 @@ class EditorAgent(Agent):
     def connect(self, scene):
         super().connect(scene)
         talemate.emit.async_signals.get("agent.conversation.generated").connect(self.on_conversation_generated)
+        talemate.emit.async_signals.get("agent.narrator.generated").connect(self.on_narrator_generated)
     async def on_conversation_generated(self, emission:ConversationAgentEmission):
         """
@@ -93,6 +97,24 @@
         emission.generation = edited
+    async def on_narrator_generated(self, emission:NarratorAgentEmission):
+        """
+        Called when a narrator message is generated
+        """
+        if not self.enabled:
+            return
+
+        log.info("editing narrator", emission=emission)
+
+        edited = []
+
+        for text in emission.generation:
+            edit = await self.fix_exposition_on_narrator(text)
+            edited.append(edit)
+
+        emission.generation = edited
     @set_processing
     async def edit_conversation(self, content:str, character:Character):
@@ -127,12 +149,13 @@
         if not self.actions["fix_exposition"].enabled:
             return content
-        #response = await Prompt.request("editor.fix-exposition", self.client, "edit_fix_exposition", vars={
-        #    "content": content,
-        #    "character": character,
-        #    "scene": self.scene,
-        #    "max_length": self.client.max_token_length
-        #})
+        if not character.is_player:
+            if '"' not in content and '*' not in content:
+                content = util.strip_partial_sentences(content)
+                character_prefix = f"{character.name}: "
+                message = content.split(character_prefix)[1]
+                content = f"{character_prefix}*{message.strip('*')}*"
+            return content
         content = util.clean_dialogue(content, main_name=character.name)
         content = util.strip_partial_sentences(content)
@@ -140,6 +163,24 @@
         return content
+    @set_processing
+    async def fix_exposition_on_narrator(self, content:str):
+        if not self.actions["fix_exposition"].enabled:
+            return content
+
+        if not self.actions["fix_exposition"].config["narrator"].value:
+            return content
+
+        content = util.strip_partial_sentences(content)
+
+        if '"' not in content:
+            content = f"*{content.strip('*')}*"
+        else:
+            content = util.ensure_dialog_format(content)
+
+        return content
     @set_processing
     async def add_detail(self, content:str, character:Character):
         """

src/talemate/agents/memory.py

@@ -206,6 +206,7 @@ from .registry import register
 @register(condition=lambda: chromadb is not None)
 class ChromaDBMemoryAgent(MemoryAgent):
+    requires_llm_client = False
     @property
     def ready(self):
@@ -328,9 +329,13 @@
                 model_name=instructor_model, device=instructor_device
             )
+            log.info("chromadb", status="embedding function ready")
             self.db = self.db_client.get_or_create_collection(
                 collection_name, embedding_function=ef
             )
+            log.info("chromadb", status="instructor db ready")
         else:
             log.info("chromadb", status="using default embeddings")
             self.db = self.db_client.get_or_create_collection(collection_name)
@@ -405,7 +410,7 @@
             id = uid or f"__narrator__-{self.memory_tracker['__narrator__']}"
             ids = [id]
-        log.debug("chromadb agent add", text=text, meta=meta, id=id)
+        #log.debug("chromadb agent add", text=text, meta=meta, id=id)
         self.db.upsert(documents=[text], metadatas=metadatas, ids=ids)
@@ -461,6 +466,7 @@
         #import json
         #print(json.dumps(_results["ids"], indent=2))
+        #print(json.dumps(_results["distances"], indent=2))
         results = []
@@ -474,9 +480,10 @@
             if distance < 1:
                 try:
+                    log.debug("chromadb agent get", ts=ts, scene_ts=self.scene.ts)
                     date_prefix = util.iso8601_diff_to_human(ts, self.scene.ts)
-                except Exception:
-                    log.error("chromadb agent", error="failed to get date prefix", ts=ts, scene_ts=self.scene.ts)
+                except Exception as e:
+                    log.error("chromadb agent", error="failed to get date prefix", details=e, ts=ts, scene_ts=self.scene.ts)
                     date_prefix = None
                 if date_prefix:

src/talemate/agents/narrator.py

@@ -1,22 +1,60 @@
 from __future__ import annotations
 from typing import TYPE_CHECKING, Callable, List, Optional, Union
+import dataclasses
 import structlog
+import random
 import talemate.util as util
 from talemate.emit import emit
 import talemate.emit.async_signals
 from talemate.prompts import Prompt
-from talemate.agents.base import set_processing, Agent, AgentAction, AgentActionConfig
+from talemate.agents.base import set_processing as _set_processing, Agent, AgentAction, AgentActionConfig, AgentEmission
 from talemate.agents.world_state import TimePassageEmission
 from talemate.scene_message import NarratorMessage
+from talemate.events import GameLoopActorIterEvent
 import talemate.client as client
 from .registry import register
+if TYPE_CHECKING:
+    from talemate.tale_mate import Actor, Player, Character
 log = structlog.get_logger("talemate.agents.narrator")
+@dataclasses.dataclass
+class NarratorAgentEmission(AgentEmission):
+    generation: list[str] = dataclasses.field(default_factory=list)
+
+talemate.emit.async_signals.register(
+    "agent.narrator.generated"
+)
+
+def set_processing(fn):
+    """
+    Custom decorator that emits the agent status as processing while the function
+    is running and then emits the result of the function as a NarratorAgentEmission
+    """
+    @_set_processing
+    async def wrapper(self, *args, **kwargs):
+        response = await fn(self, *args, **kwargs)
+        emission = NarratorAgentEmission(
+            agent=self,
+            generation=[response],
+        )
+        await talemate.emit.async_signals.get("agent.narrator.generated").send(emission)
+        return emission.generation[0]
+    wrapper.__name__ = fn.__name__
+    return wrapper
 @register()
 class NarratorAgent(Agent):
+    """
+    Handles narration of the story
+    """
     agent_type = "narrator"
     verbose_name = "Narrator"
@@ -27,31 +65,78 @@ class NarratorAgent(Agent):
     ):
         self.client = client
+        # agent actions
         self.actions = {
-            "narrate_time_passage": AgentAction(enabled=False, label="Narrate Time Passage", description="Whenever you indicate passage of time, narrate right after"),
+            "narrate_time_passage": AgentAction(enabled=True, label="Narrate Time Passage", description="Whenever you indicate passage of time, narrate right after"),
+            "narrate_dialogue": AgentAction(
+                enabled=True,
+                label="Narrate Dialogue",
+                description="Narrator will get a chance to narrate after every line of dialogue",
+                config = {
+                    "ai_dialog": AgentActionConfig(
+                        type="number",
+                        label="AI Dialogue",
+                        description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
+                        value=0.3,
+                        min=0.0,
+                        max=1.0,
+                        step=0.1,
+                    ),
+                    "player_dialog": AgentActionConfig(
+                        type="number",
+                        label="Player Dialogue",
+                        description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
+                        value=0.3,
+                        min=0.0,
+                        max=1.0,
+                        step=0.1,
+                    ),
+                }
+            ),
         }
     def clean_result(self, result):
+        """
+        Cleans the result of a narration
+        """
         result = result.strip().strip(":").strip()
         if "#" in result:
             result = result.split("#")[0]
+        character_names = [c.name for c in self.scene.get_characters()]
         cleaned = []
         for line in result.split("\n"):
-            if ":" in line.strip():
-                break
+            for character_name in character_names:
+                if line.startswith(f"{character_name}:"):
+                    break
             cleaned.append(line)
-        return "\n".join(cleaned)
+        result = "\n".join(cleaned)
+        #result = util.strip_partial_sentences(result)
+        return result
     def connect(self, scene):
+        """
+        Connect to signals
+        """
         super().connect(scene)
         talemate.emit.async_signals.get("agent.world_state.time").connect(self.on_time_passage)
+        talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.on_dialog)
     async def on_time_passage(self, event:TimePassageEmission):
+        """
+        Handles time passage narration, if enabled
+        """
         if not self.actions["narrate_time_passage"].enabled:
             return
@@ -60,6 +145,31 @@
         emit("narrator", narrator_message)
         self.scene.push_history(narrator_message)
+    async def on_dialog(self, event:GameLoopActorIterEvent):
+        """
+        Handles dialogue narration, if enabled
+        """
+        if not self.actions["narrate_dialogue"].enabled:
+            return
+
+        narrate_on_ai_chance = random.random() < self.actions["narrate_dialogue"].config["ai_dialog"].value
+        narrate_on_player_chance = random.random() < self.actions["narrate_dialogue"].config["player_dialog"].value
+
+        log.debug("narrate on dialog", narrate_on_ai_chance=narrate_on_ai_chance, narrate_on_player_chance=narrate_on_player_chance)
+
+        if event.actor.character.is_player and not narrate_on_player_chance:
+            return
+
+        if not event.actor.character.is_player and not narrate_on_ai_chance:
+            return
+
+        response = await self.narrate_after_dialogue(event.actor.character)
+        narrator_message = NarratorMessage(response, source=f"narrate_dialogue:{event.actor.character.name}")
+        emit("narrator", narrator_message)
+        self.scene.push_history(narrator_message)
     @set_processing
     async def narrate_scene(self):
         """
@@ -155,8 +265,9 @@
                 "as_narrative": as_narrative,
             }
         )
+        log.info("narrate_query", response=response)
         response = self.clean_result(response.strip())
+        log.info("narrate_query (after clean)", response=response)
         if as_narrative:
             response = f"*{response}*"
@@ -266,3 +377,29 @@
         response = f"*{response}*"
         return response
+
+    @set_processing
+    async def narrate_after_dialogue(self, character:Character):
+        """
+        Narrate after a line of dialogue
+        """
+        response = await Prompt.request(
+            "narrator.narrate-after-dialogue",
+            self.client,
+            "narrate",
+            vars = {
+                "scene": self.scene,
+                "max_tokens": self.client.max_token_length,
+                "character": character,
+                "last_line": str(self.scene.history[-1])
+            }
+        )
+        log.info("narrate_after_dialogue", response=response)
+        response = self.clean_result(response.strip().strip("*"))
+        response = f"*{response}*"
+
+        return response
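The `narrate_dialogue` action added above is a simple probability gate; a standalone sketch of its behavior (the 0.3 default comes from the config values in this diff):
```python
import random

# With ai_dialog = 0.3, the narrator fires after roughly 30% of AI dialogue lines.
ai_dialog_chance = 0.3

def should_narrate(chance: float) -> bool:
    return random.random() < chance

hits = sum(should_narrate(ai_dialog_chance) for _ in range(10_000))
print(f"narrated after {hits / 10_000:.0%} of lines")  # ~30%
```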

src/talemate/agents/summarize.py

@@ -5,11 +5,13 @@ import traceback
 from typing import TYPE_CHECKING, Callable, List, Optional, Union
 import talemate.data_objects as data_objects
+import talemate.emit.async_signals
 import talemate.util as util
 from talemate.prompts import Prompt
 from talemate.scene_message import DirectorMessage, TimePassageMessage
+from talemate.events import GameLoopEvent
-from .base import Agent, set_processing
+from .base import Agent, set_processing, AgentAction, AgentActionConfig
 from .registry import register
 import structlog
@@ -35,12 +37,38 @@ class SummarizeAgent(Agent):
     def __init__(self, client, **kwargs):
         self.client = client
-    def on_history_add(self, event):
-        asyncio.ensure_future(self.build_archive(event.scene))
+        self.actions = {
+            "archive": AgentAction(
+                enabled=True,
+                label="Summarize to long-term memory archive",
+                description="Automatically summarize scene dialogue when the number of tokens in the history exceeds a threshold. This helps keep the context history from growing too large.",
+                config={
+                    "threshold": AgentActionConfig(
+                        type="number",
+                        label="Token Threshold",
+                        description="Will summarize when the number of tokens in the history exceeds this threshold",
+                        min=512,
+                        max=8192,
+                        step=256,
+                        value=1536,
+                    )
+                }
+            )
+        }
     def connect(self, scene):
         super().connect(scene)
-        scene.signals["history_add"].connect(self.on_history_add)
+        talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
+
+    async def on_game_loop(self, emission:GameLoopEvent):
+        """
+        Called when a conversation is generated
+        """
+        await self.build_archive(self.scene)
     def clean_result(self, result):
         if "#" in result:
@@ -53,21 +81,31 @@
         return result
     @set_processing
-    async def build_archive(self, scene, token_threshold:int=1500):
+    async def build_archive(self, scene):
         end = None
+        if not self.actions["archive"].enabled:
+            return
         if not scene.archived_history:
             start = 0
             recent_entry = None
         else:
             recent_entry = scene.archived_history[-1]
-            start = recent_entry.get("end", 0) + 1
+            if "end" not in recent_entry:
+                # permanent historical archive entry, not tied to any specific history entry
+                # meaning we are still at the beginning of the scene
+                start = 0
+            else:
+                start = recent_entry.get("end", 0)+1
         tokens = 0
         dialogue_entries = []
         ts = "PT0S"
         time_passage_termination = False
+        token_threshold = self.actions["archive"].config["threshold"].value
         log.debug("build_archive", start=start, recent_entry=recent_entry)
         if recent_entry:
@@ -75,6 +113,9 @@
         for i in range(start, len(scene.history)):
             dialogue = scene.history[i]
+            #log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")
             if isinstance(dialogue, DirectorMessage):
                 if i == start:
                     start += 1
@@ -131,7 +172,7 @@
                 break
             adjusted_dialogue.append(line)
         dialogue_entries = adjusted_dialogue
-        end = start + len(dialogue_entries)
+        end = start + len(dialogue_entries)-1
         if dialogue_entries:
             summarized = await self.summarize(

src/talemate/agents/tts.py (new file, 595 lines)

@@ -0,0 +1,595 @@
from __future__ import annotations
from typing import Union
import asyncio
import httpx
import io
import os
import pydantic
import nltk
import tempfile
import base64
import uuid
import functools
from nltk.tokenize import sent_tokenize
import talemate.config as config
import talemate.emit.async_signals
from talemate.emit import emit
from talemate.events import GameLoopNewMessageEvent
from talemate.scene_message import CharacterMessage, NarratorMessage
from .base import Agent, set_processing, AgentAction, AgentActionConfig
from .registry import register
import structlog
import time
try:
from TTS.api import TTS
except ImportError:
TTS = None
log = structlog.get_logger("talemate.agents.tts")
if not TTS:
# TTS installation is massive and requires a lot of dependencies
# so we don't want to require it unless the user wants to use it
log.info("TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api")
nltk.download("punkt")
def parse_chunks(text):
text = text.replace("...", "__ellipsis__")
chunks = sent_tokenize(text)
cleaned_chunks = []
for chunk in chunks:
chunk = chunk.replace("*","")
if not chunk:
continue
cleaned_chunks.append(chunk)
for i, chunk in enumerate(cleaned_chunks):
chunk = chunk.replace("__ellipsis__", "...")
cleaned_chunks[i] = chunk
return cleaned_chunks
def rejoin_chunks(chunks:list[str], chunk_size:int=250):
"""
Will combine chunks split by punctuation into a single chunk until
max chunk size is reached
"""
joined_chunks = []
current_chunk = ""
for chunk in chunks:
if len(current_chunk) + len(chunk) > chunk_size:
joined_chunks.append(current_chunk)
current_chunk = ""
current_chunk += chunk
if current_chunk:
joined_chunks.append(current_chunk)
return joined_chunks
class Voice(pydantic.BaseModel):
value:str
label:str
class VoiceLibrary(pydantic.BaseModel):
api: str
voices: list[Voice] = pydantic.Field(default_factory=list)
last_synced: float = None
@register()
class TTSAgent(Agent):
"""
Text to speech agent
"""
agent_type = "tts"
verbose_name = "Text to speech"
requires_llm_client = False
@classmethod
def config_options(cls, agent=None):
config_options = super().config_options(agent=agent)
if agent:
config_options["actions"]["_config"]["config"]["voice_id"]["choices"] = [
voice.model_dump() for voice in agent.list_voices_sync()
]
return config_options
def __init__(self, **kwargs):
self.is_enabled = False
self.voices = {
"elevenlabs": VoiceLibrary(api="elevenlabs"),
"coqui": VoiceLibrary(api="coqui"),
"tts": VoiceLibrary(api="tts"),
}
self.config = config.load_config()
self.playback_done_event = asyncio.Event()
self.actions = {
"_config": AgentAction(
enabled=True,
label="Configure",
description="TTS agent configuration",
config={
"api": AgentActionConfig(
type="text",
choices=[
# TODO: add local TTS support
{"value": "tts", "label": "TTS (Local)"},
{"value": "elevenlabs", "label": "Eleven Labs"},
{"value": "coqui", "label": "Coqui Studio"},
],
value="tts",
label="API",
description="Which TTS API to use",
onchange="emit",
),
"voice_id": AgentActionConfig(
type="text",
value="default",
label="Narrator Voice",
description="Voice ID/Name to use for TTS",
choices=[]
),
"generate_for_player": AgentActionConfig(
type="bool",
value=False,
label="Generate for player",
description="Generate audio for player messages",
),
"generate_for_npc": AgentActionConfig(
type="bool",
value=True,
label="Generate for NPCs",
description="Generate audio for NPC messages",
),
"generate_for_narration": AgentActionConfig(
type="bool",
value=True,
label="Generate for narration",
description="Generate audio for narration messages",
),
"generate_chunks": AgentActionConfig(
type="bool",
value=True,
label="Split generation",
description="Generate audio chunks for each sentence - will be much more responsive but may loose context to inform inflection",
)
}
),
}
self.actions["_config"].model_dump()
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return False
@property
def not_ready_reason(self) -> str:
"""
Returns a string explaining why the agent is not ready
"""
if self.ready:
return ""
if self.api == "tts":
if not TTS:
return "TTS not installed"
elif self.requires_token and not self.token:
return "No API token"
elif not self.default_voice_id:
return "No voice selected"
@property
def agent_details(self):
suffix = ""
if not self.ready:
suffix = f" - {self.not_ready_reason}"
else:
suffix = f" - {self.voice_id_to_label(self.default_voice_id)}"
api = self.api
choices = self.actions["_config"].config["api"].choices
api_label = api
for choice in choices:
if choice["value"] == api:
api_label = choice["label"]
break
return f"{api_label}{suffix}"
@property
def api(self):
return self.actions["_config"].config["api"].value
@property
def token(self):
api = self.api
return self.config.get(api,{}).get("api_key")
@property
def default_voice_id(self):
return self.actions["_config"].config["voice_id"].value
@property
def requires_token(self):
return self.api != "tts"
@property
def ready(self):
if self.api == "tts":
if not TTS:
return False
return True
return (not self.requires_token or self.token) and self.default_voice_id
@property
def status(self):
if not self.enabled:
return "disabled"
if self.ready:
return "active" if not getattr(self, "processing", False) else "busy"
if self.requires_token and not self.token:
return "error"
if self.api == "tts":
if not TTS:
return "error"
@property
def max_generation_length(self):
if self.api == "elevenlabs":
return 1024
elif self.api == "coqui":
return 250
return 250
def apply_config(self, *args, **kwargs):
try:
api = kwargs["actions"]["_config"]["config"]["api"]["value"]
except KeyError:
api = self.api
api_changed = api != self.api
log.debug("apply_config", api=api, api_changed=api != self.api, current_api=self.api)
super().apply_config(*args, **kwargs)
if api_changed:
try:
self.actions["_config"].config["voice_id"].value = self.voices[api].voices[0].value
except IndexError:
self.actions["_config"].config["voice_id"].value = ""
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop_new_message").connect(self.on_game_loop_new_message)
async def on_game_loop_new_message(self, emission:GameLoopNewMessageEvent):
"""
Called when a conversation is generated
"""
if not self.enabled:
return
if not isinstance(emission.message, (CharacterMessage, NarratorMessage)):
return
if isinstance(emission.message, NarratorMessage) and not self.actions["_config"].config["generate_for_narration"].value:
return
if isinstance(emission.message, CharacterMessage):
if emission.message.source == "player" and not self.actions["_config"].config["generate_for_player"].value:
return
elif emission.message.source == "ai" and not self.actions["_config"].config["generate_for_npc"].value:
return
if isinstance(emission.message, CharacterMessage):
character_prefix = emission.message.split(":", 1)[0]
else:
character_prefix = ""
log.info("reactive tts", message=emission.message, character_prefix=character_prefix)
await self.generate(str(emission.message).replace(character_prefix+": ", ""))
def voice(self, voice_id:str) -> Union[Voice, None]:
for voice in self.voices[self.api].voices:
if voice.value == voice_id:
return voice
return None
def voice_id_to_label(self, voice_id:str):
for voice in self.voices[self.api].voices:
if voice.value == voice_id:
return voice.label
return None
def list_voices_sync(self):
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.list_voices())
async def list_voices(self):
if self.requires_token and not self.token:
return []
library = self.voices[self.api]
log.info("Listing voices", api=self.api, last_synced=library.last_synced)
# TODO: allow re-syncing voices
if library.last_synced:
return library.voices
list_fn = getattr(self, f"_list_voices_{self.api}")
log.info("Listing voices", api=self.api)
library.voices = await list_fn()
library.last_synced = time.time()
# if the current voice cannot be found, reset it
if not self.voice(self.default_voice_id):
self.actions["_config"].config["voice_id"].value = ""
# set loading to false
return library.voices
@set_processing
async def generate(self, text: str):
if not self.enabled or not self.ready or not text:
return
self.playback_done_event.set()
generate_fn = getattr(self, f"_generate_{self.api}")
if self.actions["_config"].config["generate_chunks"].value:
chunks = parse_chunks(text)
chunks = rejoin_chunks(chunks)
else:
chunks = parse_chunks(text)
chunks = rejoin_chunks(chunks, chunk_size=self.max_generation_length)
# Start generating audio chunks in the background
generation_task = asyncio.create_task(self.generate_chunks(generate_fn, chunks))
# Wait for both tasks to complete
await asyncio.gather(generation_task)
async def generate_chunks(self, generate_fn, chunks):
for chunk in chunks:
chunk = chunk.replace("*","").strip()
log.info("Generating audio", api=self.api, chunk=chunk)
audio_data = await generate_fn(chunk)
self.play_audio(audio_data)
def play_audio(self, audio_data):
# play audio through the python audio player
#play(audio_data)
emit("audio_queue", data={"audio_data": base64.b64encode(audio_data).decode("utf-8")})
self.playback_done_event.set() # Signal that playback is finished
# LOCAL
async def _generate_tts(self, text: str) -> Union[bytes, None]:
if not TTS:
return
tts_config = self.config.get("tts",{})
model = tts_config.get("model")
device = tts_config.get("device", "cpu")
log.debug("tts local", model=model, device=device)
if not hasattr(self, "tts_instance"):
self.tts_instance = TTS(model).to(device)
tts = self.tts_instance
loop = asyncio.get_event_loop()
voice = self.voice(self.default_voice_id)
with tempfile.TemporaryDirectory() as temp_dir:
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
await loop.run_in_executor(None, functools.partial(tts.tts_to_file, text=text, speaker_wav=voice.value, language="en", file_path=file_path))
#tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)
with open(file_path, "rb") as f:
return f.read()
async def _list_voices_tts(self) -> dict[str, str]:
return [Voice(**voice) for voice in self.config.get("tts",{}).get("voices",[])]
# ELEVENLABS
async def _generate_elevenlabs(self, text: str, chunk_size: int = 1024) -> Union[bytes, None]:
api_key = self.token
if not api_key:
return
async with httpx.AsyncClient() as client:
url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.default_voice_id}"
headers = {
"Accept": "audio/mpeg",
"Content-Type": "application/json",
"xi-api-key": api_key,
}
data = {
"text": text,
"model_id": "eleven_monolingual_v1",
"voice_settings": {
"stability": 0.5,
"similarity_boost": 0.5
}
}
response = await client.post(url, json=data, headers=headers, timeout=300)
if response.status_code == 200:
bytes_io = io.BytesIO()
for chunk in response.iter_bytes(chunk_size=chunk_size):
if chunk:
bytes_io.write(chunk)
# Put the audio data in the queue for playback
return bytes_io.getvalue()
else:
log.error(f"Error generating audio: {response.text}")
async def _list_voices_elevenlabs(self) -> dict[str, str]:
url_voices = "https://api.elevenlabs.io/v1/voices"
voices = []
async with httpx.AsyncClient() as client:
headers = {
"Accept": "application/json",
"xi-api-key": self.token,
}
response = await client.get(url_voices, headers=headers, params={"per_page":1000})
speakers = response.json()["voices"]
voices.extend([Voice(value=speaker["voice_id"], label=speaker["name"]) for speaker in speakers])
# sort by name
voices.sort(key=lambda x: x.label)
return voices
# COQUI STUDIO
async def _generate_coqui(self, text: str) -> Union[bytes, None]:
api_key = self.token
if not api_key:
return
async with httpx.AsyncClient() as client:
url = "https://app.coqui.ai/api/v2/samples/xtts/render/"
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
data = {
"voice_id": self.default_voice_id,
"text": text,
"language": "en" # Assuming English language for simplicity; this could be parameterized
}
# Make the POST request to Coqui API
response = await client.post(url, json=data, headers=headers, timeout=300)
if response.status_code in [200, 201]:
# Parse the JSON response to get the audio URL
response_data = response.json()
audio_url = response_data.get('audio_url')
if audio_url:
# Make a GET request to download the audio file
audio_response = await client.get(audio_url)
if audio_response.status_code == 200:
# delete the sample from Coqui Studio
# await self._cleanup_coqui(response_data.get('id'))
return audio_response.content
else:
log.error(f"Error downloading audio: {audio_response.text}")
else:
log.error("No audio URL in response")
else:
log.error(f"Error generating audio: {response.text}")
async def _cleanup_coqui(self, sample_id: str):
api_key = self.token
if not api_key or not sample_id:
return
async with httpx.AsyncClient() as client:
url = f"https://app.coqui.ai/api/v2/samples/xtts/{sample_id}"
headers = {
"Authorization": f"Bearer {api_key}"
}
# Make the DELETE request to Coqui API
response = await client.delete(url, headers=headers)
if response.status_code == 204:
log.info(f"Successfully deleted sample with ID: {sample_id}")
else:
log.error(f"Error deleting sample with ID: {sample_id}: {response.text}")
async def _list_voices_coqui(self) -> dict[str, str]:
url_speakers = "https://app.coqui.ai/api/v2/speakers"
url_custom_voices = "https://app.coqui.ai/api/v2/voices"
voices = []
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"Bearer {self.token}"
}
response = await client.get(url_speakers, headers=headers, params={"per_page":1000})
speakers = response.json()["result"]
voices.extend([Voice(value=speaker["id"], label=speaker["name"]) for speaker in speakers])
response = await client.get(url_custom_voices, headers=headers, params={"per_page":1000})
custom_voices = response.json()["result"]
voices.extend([Voice(value=voice["id"], label=voice["name"]) for voice in custom_voices])
# sort by name
voices.sort(key=lambda x: x.label)
return voices
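To illustrate the chunking helpers at the top of this file: `parse_chunks` sentence-splits the text with NLTK while protecting ellipses and stripping asterisks, and `rejoin_chunks` re-merges sentences up to the per-request size limit. A usage sketch (assumes the functions above are in scope; the sample text is made up):
```python
# Usage sketch for the helpers defined above (requires nltk's "punkt" data,
# which the module downloads on import).
text = "He paused... *The door creaked open.* Then he spoke. It was a long story."

chunks = parse_chunks(text)                    # sentence-split; '*' removed, '...' kept
chunks = rejoin_chunks(chunks, chunk_size=60)  # merge sentences up to ~60 chars each

for chunk in chunks:
    print(repr(chunk))
```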

src/talemate/agents/world_state.py

@@ -8,6 +8,7 @@ import talemate.util as util
 from talemate.prompts import Prompt
 from talemate.scene_message import DirectorMessage, TimePassageMessage
 from talemate.emit import emit
+from talemate.events import GameLoopEvent
 from .base import Agent, set_processing, AgentAction, AgentActionConfig, AgentEmission
 from .registry import register
@@ -16,9 +17,6 @@ import structlog
 import isodate
 import time
-if TYPE_CHECKING:
-    from talemate.agents.conversation import ConversationAgentEmission
 log = structlog.get_logger("talemate.agents.world_state")
@@ -74,7 +72,7 @@ class WorldStateAgent(Agent):
     def connect(self, scene):
         super().connect(scene)
-        talemate.emit.async_signals.get("agent.conversation.generated").connect(self.on_conversation_generated)
+        talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
     async def advance_time(self, duration:str, narrative:str=None):
         """
@@ -96,7 +94,7 @@
         )
-    async def on_conversation_generated(self, emission:ConversationAgentEmission):
+    async def on_game_loop(self, emission:GameLoopEvent):
         """
         Called when a conversation is generated
         """
@@ -104,8 +102,7 @@
         if not self.enabled:
             return
-        for _ in emission.generation:
-            await self.update_world_state()
+        await self.update_world_state()
     async def update_world_state(self):
@@ -230,7 +227,7 @@
     ):
         response = await Prompt.request(
-            "world_state.analyze-and-follow-instruction",
+            "world_state.analyze-text-and-follow-instruction",
             self.client,
             "analyze_freeform",
             vars = {

src/talemate/client/__init__.py

@@ -1,4 +1,6 @@
+import os
 from talemate.client.openai import OpenAIClient
 from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
 from talemate.client.textgenwebui import TextGeneratorWebuiClient
+from talemate.client.lmstudio import LMStudioClient
 import talemate.client.runpod

src/talemate/client/base.py (new file, 349 lines)

@@ -0,0 +1,349 @@
"""
A unified client base, based on the openai API
"""
import copy
import random
import time
from typing import Callable
import structlog
import logging
from openai import AsyncOpenAI
from talemate.emit import emit
import talemate.instance as instance
import talemate.client.presets as presets
import talemate.client.system_prompts as system_prompts
import talemate.util as util
from talemate.client.context import client_context_attribute
from talemate.client.model_prompts import model_prompt
# Set up logging level for httpx to WARNING to suppress debug logs.
logging.getLogger('httpx').setLevel(logging.WARNING)
REMOTE_SERVICES = [
# TODO: runpod.py should add this to the list
".runpod.net"
]
STOPPING_STRINGS = ["<|im_end|>", "</s>"]
class ClientBase:
api_url: str
model_name: str
name:str = None
enabled: bool = True
current_status: str = None
max_token_length: int = 4096
randomizable_inference_parameters: list[str] = ["temperature"]
processing: bool = False
connected: bool = False
conversation_retries: int = 5
client_type = "base"
def __init__(
self,
api_url: str = None,
name = None,
**kwargs,
):
self.api_url = api_url
self.name = name or self.client_type
self.log = structlog.get_logger(f"client.{self.client_type}")
self.set_client()
def __str__(self):
return f"{self.client_type}Client[{self.api_url}][{self.model_name or ''}]"
def set_client(self):
self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")
def prompt_template(self, sys_msg, prompt):
"""
Applies the appropriate prompt template for the model.
"""
if not self.model_name:
self.log.warning("prompt template not applied", reason="no model loaded")
return f"{sys_msg}\n{prompt}"
return model_prompt(self.model_name, sys_msg, prompt)
def reconfigure(self, **kwargs):
"""
Reconfigures the client.
Keyword Arguments:
- api_url: the API URL to use
- max_token_length: the max token length to use
- enabled: whether the client is enabled
"""
if "api_url" in kwargs:
self.api_url = kwargs["api_url"]
if "max_token_length" in kwargs:
self.max_token_length = kwargs["max_token_length"]
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
def toggle_disabled_if_remote(self):
"""
If the client is targeting a remote recognized service, this
will disable the client.
"""
for service in REMOTE_SERVICES:
if service in self.api_url:
if self.enabled:
self.log.warn("remote service unreachable, disabling client", client=self.name)
self.enabled = False
return True
return False
def get_system_message(self, kind: str) -> str:
"""
Returns the appropriate system message for the given kind of generation
Arguments:
- kind: the kind of generation
"""
# TODO: make extensible
if "narrate" in kind:
return system_prompts.NARRATOR
if "story" in kind:
return system_prompts.NARRATOR
if "director" in kind:
return system_prompts.DIRECTOR
if "create" in kind:
return system_prompts.CREATOR
if "roleplay" in kind:
return system_prompts.ROLEPLAY
if "conversation" in kind:
return system_prompts.ROLEPLAY
if "editor" in kind:
return system_prompts.EDITOR
if "world_state" in kind:
return system_prompts.WORLD_STATE
if "analyst" in kind:
return system_prompts.ANALYST
if "analyze" in kind:
return system_prompts.ANALYST
return system_prompts.BASIC
def emit_status(self, processing: bool = None):
"""
Sets and emits the client status.
"""
if processing is not None:
self.processing = processing
if not self.enabled:
status = "disabled"
model_name = "Disabled"
elif not self.connected:
status = "error"
model_name = "Could not connect"
elif self.model_name:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
model_name = "No model loaded"
status = "warning"
status_change = status != self.current_status
self.current_status = status
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status,
)
if status_change:
instance.emit_agent_status_by_client(self)
async def get_model_name(self):
models = await self.client.models.list()
try:
return models.data[0].id
except IndexError:
return None
async def status(self):
"""
Send a request to the API to retrieve the loaded AI model name.
Raises an error if no model name is returned.
:return: None
"""
if self.processing:
return
if not self.enabled:
self.connected = False
self.emit_status()
return
try:
self.model_name = await self.get_model_name()
except Exception as e:
self.log.warning("client status error", e=e, client=self.name)
self.model_name = None
self.connected = False
self.toggle_disabled_if_remote()
self.emit_status()
return
self.connected = True
if not self.model_name or self.model_name == "None":
self.log.warning("client model not loaded", client=self)
self.emit_status()
return
self.emit_status()
def generate_prompt_parameters(self, kind:str):
parameters = {}
self.tune_prompt_parameters(
presets.configure(parameters, kind, self.max_token_length),
kind
)
return parameters
def tune_prompt_parameters(self, parameters:dict, kind:str):
parameters["stream"] = False
if client_context_attribute("nuke_repetition") > 0.0 and self.jiggle_enabled_for(kind):
self.jiggle_randomness(parameters, offset=client_context_attribute("nuke_repetition"))
fn_tune_kind = getattr(self, f"tune_prompt_parameters_{kind}", None)
if fn_tune_kind:
fn_tune_kind(parameters)
def tune_prompt_parameters_conversation(self, parameters:dict):
conversation_context = client_context_attribute("conversation")
parameters["max_tokens"] = conversation_context.get("length", 96)
dialog_stopping_strings = [
f"{character}:" for character in conversation_context["other_characters"]
]
if "extra_stopping_strings" in parameters:
parameters["extra_stopping_strings"] += dialog_stopping_strings
else:
parameters["extra_stopping_strings"] = dialog_stopping_strings
async def generate(self, prompt:str, parameters:dict, kind:str):
"""
Generates text from the given prompt and parameters.
"""
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
try:
            response = await self.client.completions.create(prompt=prompt.strip(), **parameters)
            # the openai 1.x SDK returns a Completion object, not a dict
            return response.choices[0].text or ""
except Exception as e:
self.log.error("generate error", e=e)
return ""
async def send_prompt(
self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x
) -> str:
"""
Send a prompt to the AI and return its response.
:param prompt: The text prompt to send.
:return: The AI's response text.
"""
try:
self.emit_status(processing=True)
await self.status()
prompt_param = self.generate_prompt_parameters(kind)
finalized_prompt = self.prompt_template(self.get_system_message(kind), prompt).strip()
prompt_param = finalize(prompt_param)
token_length = self.count_tokens(finalized_prompt)
time_start = time.time()
extra_stopping_strings = prompt_param.pop("extra_stopping_strings", [])
self.log.debug("send_prompt", token_length=token_length, max_token_length=self.max_token_length, parameters=prompt_param)
response = await self.generate(finalized_prompt, prompt_param, kind)
time_end = time.time()
        # stopping strings sometimes get appended to the end of the response anyway;
        # split on the first stopping string and keep only the part before it
for stopping_string in STOPPING_STRINGS + extra_stopping_strings:
if stopping_string in response:
response = response.split(stopping_string)[0]
break
emit("prompt_sent", data={
"kind": kind,
"prompt": finalized_prompt,
"response": response,
"prompt_tokens": token_length,
"response_tokens": self.count_tokens(response),
"time": time_end - time_start,
})
return response
finally:
self.emit_status(processing=False)
def count_tokens(self, content:str):
return util.count_tokens(content)
def jiggle_randomness(self, prompt_config:dict, offset:float=0.3) -> dict:
"""
        adjusts temperature by a random value, using the base value as a
        center (subclasses may extend this to also jiggle repetition_penalty)
"""
temp = prompt_config["temperature"]
min_offset = offset * 0.3
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
def jiggle_enabled_for(self, kind:str):
if kind in ["conversation", "story"]:
return True
if kind.startswith("narrate"):
return True
return False
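
For orientation, here is a minimal sketch of what a concrete client built on this base class looks like; the ExampleClient below is hypothetical, only ClientBase, register and the attribute names come from this commit:

from openai import AsyncOpenAI

from talemate.client.base import ClientBase
from talemate.client.registry import register

@register()
class ExampleClient(ClientBase):
    # hypothetical client type; used as the registry key
    client_type = "example"

    def set_client(self):
        # point the OpenAI-compatible SDK at the backend; the dummy key
        # mirrors what the bundled local-backend clients do
        self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")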

View File

@@ -0,0 +1,56 @@
from talemate.client.base import ClientBase
from talemate.client.registry import register
from openai import AsyncOpenAI
@register()
class LMStudioClient(ClientBase):
client_type = "lmstudio"
conversation_retries = 5
def set_client(self):
self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key="sk-1111")
def tune_prompt_parameters(self, parameters:dict, kind:str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def get_model_name(self):
model_name = await super().get_model_name()
# model name comes back as a file path, so we need to extract the model name
# the path could be windows or linux so it needs to handle both backslash and forward slash
if model_name:
model_name = model_name.replace("\\", "/").split("/")[-1]
return model_name
async def generate(self, prompt:str, parameters:dict, kind:str):
"""
Generates text from the given prompt and parameters.
"""
human_message = {'role': 'user', 'content': prompt.strip()}
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
try:
response = await self.client.chat.completions.create(
model=self.model_name, messages=[human_message], **parameters
)
return response.choices[0].message.content
except Exception as e:
self.log.error("generate error", e=e)
return ""

View File

@@ -1,10 +1,9 @@
-import asyncio
 import os
-import time
+import json
-from typing import Callable
 from openai import AsyncOpenAI
+from talemate.client.base import ClientBase
 from talemate.client.registry import register
 from talemate.emit import emit
 from talemate.config import load_config
@@ -15,10 +14,9 @@ import tiktoken
 __all__ = [
     "OpenAIClient",
 ]

 log = structlog.get_logger("talemate")

-def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"):
+def num_tokens_from_messages(messages:list[dict], model:str="gpt-3.5-turbo-0613"):
     """Return the number of tokens used by a list of messages."""
     try:
         encoding = tiktoken.encoding_for_model(model)
@@ -70,7 +68,7 @@ def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"):
     return num_tokens

 @register()
-class OpenAIClient:
+class OpenAIClient(ClientBase):
     """
     OpenAI client for generating text.
     """
@@ -79,13 +77,10 @@ class OpenAIClient:
     conversation_retries = 0

     def __init__(self, model="gpt-4-1106-preview", **kwargs):
-        self.name = kwargs.get("name", "openai")
         self.model_name = model
-        self.last_token_length = 0
-        self.max_token_length = 2048
-        self.processing = False
-        self.current_status = "idle"
         self.config = load_config()
+        super().__init__(**kwargs)

         # if os.environ.get("OPENAI_API_KEY") is not set, look in the config file
         # and set it
@@ -94,7 +89,7 @@ class OpenAIClient:
         if self.config.get("openai", {}).get("api_key"):
             os.environ["OPENAI_API_KEY"] = self.config["openai"]["api_key"]

-        self.set_client(model)
+        self.set_client()

     @property
@@ -123,12 +118,14 @@ class OpenAIClient:
             status=status,
         )

-    def set_client(self, model:str, max_token_length:int=None):
+    def set_client(self, max_token_length:int=None):
         if not self.openai_api_key:
             log.error("No OpenAI API key set")
             return

+        model = self.model_name
         self.client = AsyncOpenAI()
         if model == "gpt-3.5-turbo":
             self.max_token_length = min(max_token_length or 4096, 4096)
@@ -144,89 +141,72 @@ class OpenAIClient:
     def reconfigure(self, **kwargs):
         if "model" in kwargs:
             self.model_name = kwargs["model"]
-            self.set_client(self.model_name, kwargs.get("max_token_length"))
+            self.set_client(kwargs.get("max_token_length"))

+    def count_tokens(self, content: str):
+        return num_tokens_from_messages([{"content": content}], model=self.model_name)

     async def status(self):
         self.emit_status()

-    def get_system_message(self, kind: str) -> str:
-        if "narrate" in kind:
-            return system_prompts.NARRATOR
-        if "story" in kind:
-            return system_prompts.NARRATOR
-        if "director" in kind:
-            return system_prompts.DIRECTOR
-        if "create" in kind:
-            return system_prompts.CREATOR
-        if "roleplay" in kind:
-            return system_prompts.ROLEPLAY
-        if "conversation" in kind:
-            return system_prompts.ROLEPLAY
-        if "editor" in kind:
-            return system_prompts.EDITOR
-        if "world_state" in kind:
-            return system_prompts.WORLD_STATE
-        if "analyst" in kind:
-            return system_prompts.ANALYST
-        if "analyze" in kind:
-            return system_prompts.ANALYST
-        return system_prompts.BASIC
-
-    async def send_prompt(
-        self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x
-    ) -> str:
-        right = ""
-        opts = {}
+    def prompt_template(self, system_message:str, prompt:str):
         # only gpt-4-1106-preview supports json_object response coersion
-        supports_json_object = self.model_name in ["gpt-4-1106-preview"]
         if "<|BOT|>" in prompt:
             _, right = prompt.split("<|BOT|>", 1)
             if right:
                 prompt = prompt.replace("<|BOT|>", "\nContinue this response: ")
-                expected_response = prompt.split("\nContinue this response: ")[1].strip()
-                if expected_response.startswith("{") and supports_json_object:
-                    opts["response_format"] = {"type": "json_object"}
             else:
                 prompt = prompt.replace("<|BOT|>", "")
-        self.emit_status(processing=True)
-        await asyncio.sleep(0.1)
-        sys_message = {'role': 'system', 'content': self.get_system_message(kind)}
-        human_message = {'role': 'user', 'content': prompt}
-        log.debug("openai send", kind=kind, sys_message=sys_message, opts=opts)
-        time_start = time.time()
-        response = await self.client.chat.completions.create(model=self.model_name, messages=[sys_message, human_message], **opts)
-        time_end = time.time()
-        response = response.choices[0].message.content
-        if right and response.startswith(right):
-            response = response[len(right):].strip()
-        if kind == "conversation":
-            response = response.replace("\n", " ").strip()
-        log.debug("openai response", response=response)
-        emit("prompt_sent", data={
-            "kind": kind,
-            "prompt": prompt,
-            "response": response,
-            "prompt_tokens": num_tokens_from_messages([sys_message, human_message], model=self.model_name),
-            "response_tokens": num_tokens_from_messages([{"role": "assistant", "content": response}], model=self.model_name),
-            "time": time_end - time_start,
-        })
-        self.emit_status(processing=False)
-        return response
+        return prompt
+
+    def tune_prompt_parameters(self, parameters:dict, kind:str):
+        super().tune_prompt_parameters(parameters, kind)
+        keys = list(parameters.keys())
+        valid_keys = ["temperature", "top_p"]
+        for key in keys:
+            if key not in valid_keys:
+                del parameters[key]
+
+    async def generate(self, prompt:str, parameters:dict, kind:str):
+        """
+        Generates text from the given prompt and parameters.
+        """
+        # only gpt-4-1106-preview supports json_object response coersion
+        supports_json_object = self.model_name in ["gpt-4-1106-preview"]
+        right = None
+        try:
+            _, right = prompt.split("\nContinue this response: ")
+            expected_response = right.strip()
+            if expected_response.startswith("{") and supports_json_object:
+                parameters["response_format"] = {"type": "json_object"}
+        except (IndexError, ValueError):
+            # note: a failed two-way unpack of str.split raises ValueError when
+            # the marker is absent, so ValueError is caught here as well
+            pass
+
+        human_message = {'role': 'user', 'content': prompt.strip()}
+        system_message = {'role': 'system', 'content': self.get_system_message(kind)}
+        self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
+
+        try:
+            response = await self.client.chat.completions.create(
+                model=self.model_name, messages=[system_message, human_message], **parameters
+            )
+            response = response.choices[0].message.content
+            if right and response.startswith(right):
+                response = response[len(right):].strip()
+            return response
+        except Exception as e:
+            self.log.error("generate error", e=e)
+            return ""

View File

@@ -0,0 +1,163 @@
__all__ = [
"configure",
"set_max_tokens",
"set_preset",
"preset_for_kind",
"max_tokens_for_kind",
"PRESET_TALEMATE_CONVERSATION",
"PRESET_TALEMATE_CREATOR",
"PRESET_LLAMA_PRECISE",
"PRESET_DIVINE_INTELLECT",
"PRESET_SIMPLE_1",
]
PRESET_TALEMATE_CONVERSATION = {
"temperature": 0.65,
"top_p": 0.47,
"top_k": 42,
"repetition_penalty": 1.18,
"repetition_penalty_range": 2048,
}
PRESET_TALEMATE_CREATOR = {
"temperature": 0.7,
"top_p": 0.9,
"top_k": 20,
"repetition_penalty": 1.15,
"repetition_penalty_range": 512,
}
PRESET_LLAMA_PRECISE = {
'temperature': 0.7,
'top_p': 0.1,
'top_k': 40,
'repetition_penalty': 1.18,
}
PRESET_DIVINE_INTELLECT = {
'temperature': 1.31,
'top_p': 0.14,
'top_k': 49,
"repetition_penalty_range": 1024,
'repetition_penalty': 1.17,
}
PRESET_SIMPLE_1 = {
"temperature": 0.7,
"top_p": 0.9,
"top_k": 20,
"repetition_penalty": 1.15,
}
def configure(config:dict, kind:str, total_budget:int):
"""
Sets the config based on the kind of text to generate.
"""
set_preset(config, kind)
set_max_tokens(config, kind, total_budget)
return config
def set_max_tokens(config:dict, kind:str, total_budget:int):
"""
Sets the max_tokens in the config based on the kind of text to generate.
"""
config["max_tokens"] = max_tokens_for_kind(kind, total_budget)
return config
def set_preset(config:dict, kind:str):
"""
Sets the preset in the config based on the kind of text to generate.
"""
config.update(preset_for_kind(kind))
def preset_for_kind(kind: str):
if kind == "conversation":
return PRESET_TALEMATE_CONVERSATION
elif kind == "conversation_old":
return PRESET_TALEMATE_CONVERSATION # Assuming old conversation uses the same preset
elif kind == "conversation_long":
return PRESET_TALEMATE_CONVERSATION # Assuming long conversation uses the same preset
elif kind == "conversation_select_talking_actor":
return PRESET_TALEMATE_CONVERSATION # Assuming select talking actor uses the same preset
elif kind == "summarize":
return PRESET_LLAMA_PRECISE
elif kind == "analyze":
return PRESET_SIMPLE_1
elif kind == "analyze_creative":
return PRESET_DIVINE_INTELLECT
elif kind == "analyze_long":
return PRESET_SIMPLE_1 # Assuming long analysis uses the same preset as simple
elif kind == "analyze_freeform":
return PRESET_LLAMA_PRECISE
elif kind == "analyze_freeform_short":
return PRESET_LLAMA_PRECISE # Assuming short freeform analysis uses the same preset as precise
elif kind == "narrate":
return PRESET_LLAMA_PRECISE
elif kind == "story":
return PRESET_DIVINE_INTELLECT
elif kind == "create":
return PRESET_TALEMATE_CREATOR
elif kind == "create_concise":
return PRESET_TALEMATE_CREATOR # Assuming concise creation uses the same preset as creator
elif kind == "create_precise":
return PRESET_LLAMA_PRECISE
elif kind == "director":
return PRESET_SIMPLE_1
elif kind == "director_short":
return PRESET_SIMPLE_1 # Assuming short direction uses the same preset as simple
elif kind == "director_yesno":
return PRESET_SIMPLE_1 # Assuming yes/no direction uses the same preset as simple
elif kind == "edit_dialogue":
return PRESET_DIVINE_INTELLECT
elif kind == "edit_add_detail":
return PRESET_DIVINE_INTELLECT # Assuming adding detail uses the same preset as divine intellect
elif kind == "edit_fix_exposition":
return PRESET_DIVINE_INTELLECT # Assuming fixing exposition uses the same preset as divine intellect
else:
return PRESET_SIMPLE_1 # Default preset if none of the kinds match
def max_tokens_for_kind(kind: str, total_budget: int):
if kind == "conversation":
return 75 # Example value, adjust as needed
elif kind == "conversation_old":
return 75 # Example value, adjust as needed
elif kind == "conversation_long":
return 300 # Example value, adjust as needed
elif kind == "conversation_select_talking_actor":
return 30 # Example value, adjust as needed
elif kind == "summarize":
return 500 # Example value, adjust as needed
elif kind == "analyze":
return 500 # Example value, adjust as needed
elif kind == "analyze_creative":
return 1024 # Example value, adjust as needed
elif kind == "analyze_long":
return 2048 # Example value, adjust as needed
elif kind == "analyze_freeform":
return 500 # Example value, adjust as needed
elif kind == "analyze_freeform_short":
return 10 # Example value, adjust as needed
elif kind == "narrate":
return 500 # Example value, adjust as needed
elif kind == "story":
return 300 # Example value, adjust as needed
elif kind == "create":
return min(1024, int(total_budget * 0.35)) # Example calculation, adjust as needed
elif kind == "create_concise":
return min(400, int(total_budget * 0.25)) # Example calculation, adjust as needed
elif kind == "create_precise":
return min(400, int(total_budget * 0.25)) # Example calculation, adjust as needed
elif kind == "director":
return min(600, int(total_budget * 0.25)) # Example calculation, adjust as needed
elif kind == "director_short":
return 25 # Example value, adjust as needed
elif kind == "director_yesno":
return 2 # Example value, adjust as needed
elif kind == "edit_dialogue":
return 100 # Example value, adjust as needed
elif kind == "edit_add_detail":
return 200 # Example value, adjust as needed
elif kind == "edit_fix_exposition":
return 1024 # Example value, adjust as needed
else:
return 150 # Default value if none of the kinds match
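
A minimal usage sketch for this module (the 8192 token budget is a made-up value):

params = configure({}, "conversation", 8192)
# conversation preset plus the fixed conversation budget
assert params["temperature"] == 0.65
assert params["max_tokens"] == 75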

View File

@@ -67,9 +67,9 @@ def _client_bootstrap(client_type: ClientType, pod):
id = pod["id"] id = pod["id"]
if client_type == ClientType.textgen: if client_type == ClientType.textgen:
api_url = f"https://{id}-5000.proxy.runpod.net/api" api_url = f"https://{id}-5000.proxy.runpod.net"
elif client_type == ClientType.automatic1111: elif client_type == ClientType.automatic1111:
api_url = f"https://{id}-5000.proxy.runpod.net/api" api_url = f"https://{id}-5000.proxy.runpod.net"
return ClientBootstrap( return ClientBootstrap(
client_type=client_type, client_type=client_type,

View File

@@ -1,735 +1,65 @@
-import asyncio
-import random
-import json
-import copy
-import structlog
-import time
-import httpx
-from abc import ABC, abstractmethod
-from typing import Callable, Union
-import logging
-import talemate.util as util
-import talemate.client.system_prompts as system_prompts
-from talemate.emit import Emission, emit
-from talemate.client.context import client_context_attribute
-from talemate.client.model_prompts import model_prompt
-import talemate.instance as instance
+from talemate.client.base import ClientBase, STOPPING_STRINGS
 from talemate.client.registry import register
+from openai import AsyncOpenAI
+import httpx
+import copy
+import random
log = structlog.get_logger(__name__)
__all__ = [
"TaleMateClient",
"RestApiTaleMateClient",
"TextGeneratorWebuiClient",
]
# Set up logging level for httpx to WARNING to suppress debug logs.
logging.getLogger('httpx').setLevel(logging.WARNING)
class DefaultContext(int):
pass
PRESET_TALEMATE_LEGACY = {
"temperature": 0.72,
"top_p": 0.73,
"top_k": 0,
"top_a": 0,
"repetition_penalty": 1.18,
"repetition_penalty_range": 2048,
"encoder_repetition_penalty": 1,
#"encoder_repetition_penalty": 1.2,
#"no_repeat_ngram_size": 2,
"do_sample": True,
"length_penalty": 1,
}
PRESET_TALEMATE_CONVERSATION = {
"temperature": 0.65,
"top_p": 0.47,
"top_k": 42,
"typical_p": 1,
"top_a": 0,
"tfs": 1,
"epsilon_cutoff": 0,
"eta_cutoff": 0,
"repetition_penalty": 1.18,
"repetition_penalty_range": 2048,
"no_repeat_ngram_size": 0,
"penalty_alpha": 0,
"num_beams": 1,
"length_penalty": 1,
"min_length": 0,
"encoder_rep_pen": 1,
"do_sample": True,
"early_stopping": False,
"mirostat_mode": 0,
"mirostat_tau": 5,
"mirostat_eta": 0.1
}
PRESET_TALEMATE_CREATOR = {
"temperature": 0.7,
"top_p": 0.9,
"repetition_penalty": 1.15,
"repetition_penalty_range": 512,
"top_k": 20,
"do_sample": True,
"length_penalty": 1,
}
PRESET_LLAMA_PRECISE = {
'temperature': 0.7,
'top_p': 0.1,
'repetition_penalty': 1.18,
'top_k': 40
}
PRESET_KOBOLD_GODLIKE = {
'temperature': 0.7,
'top_p': 0.5,
'typical_p': 0.19,
'repetition_penalty': 1.1,
"repetition_penalty_range": 1024,
}
PRESET_DIVINE_INTELLECT = {
'temperature': 1.31,
'top_p': 0.14,
"repetition_penalty_range": 1024,
'repetition_penalty': 1.17,
'top_k': 49,
"mirostat_mode": 0,
"mirostat_tau": 5,
"mirostat_eta": 0.1,
"tfs": 1,
}
PRESET_SIMPLE_1 = {
"temperature": 0.7,
"top_p": 0.9,
"repetition_penalty": 1.15,
"top_k": 20,
}
def jiggle_randomness(prompt_config:dict, offset:float=0.3) -> dict:
"""
adjusts temperature and repetition_penalty
by random values using the base value as a center
"""
temp = prompt_config["temperature"]
rep_pen = prompt_config["repetition_penalty"]
copied_config = copy.deepcopy(prompt_config)
min_offset = offset * 0.3
copied_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
copied_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)
return copied_config
class TaleMateClient:
"""
An abstract TaleMate client that can be implemented for different communication methods with the AI.
"""
def __init__(
self,
api_url: str,
max_token_length: Union[int, DefaultContext] = int.__new__(
DefaultContext, 2048
),
):
self.api_url = api_url
self.name = "generic_client"
self.model_name = None
self.last_token_length = 0
self.max_token_length = max_token_length
self.original_max_token_length = max_token_length
self.enabled = True
self.current_status = None
@abstractmethod
def send_message(self, message: dict) -> str:
"""
Sends a message to the AI. Needs to be implemented by the subclass.
:param message: The message to be sent.
:return: The AI's response text.
"""
pass
@abstractmethod
def send_prompt(self, prompt: str) -> str:
"""
Sends a prompt to the AI. Needs to be implemented by the subclass.
:param prompt: The text prompt to send.
:return: The AI's response text.
"""
pass
def reconfigure(self, **kwargs):
if "api_url" in kwargs:
self.api_url = kwargs["api_url"]
if "max_token_length" in kwargs:
self.max_token_length = kwargs["max_token_length"]
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
def remaining_tokens(self, context: Union[str, list]) -> int:
return self.max_token_length - util.count_tokens(context)
def prompt_template(self, sys_msg, prompt):
return model_prompt(self.model_name, sys_msg, prompt)
class RESTTaleMateClient(TaleMateClient, ABC):
"""
A RESTful TaleMate client that connects to the REST API endpoint.
"""
async def send_message(self, message: dict, url: str) -> str:
"""
Sends a message to the REST API and returns the AI's response.
:param message: The message to be sent.
:return: The AI's response text.
"""
try:
async with httpx.AsyncClient() as client:
response = await client.post(url, json=message, timeout=None)
response_data = response.json()
return response_data["results"][0]["text"]
except KeyError:
return response_data["results"][0]["history"]["visible"][0][-1]
 @register()
-class TextGeneratorWebuiClient(RESTTaleMateClient):
-    """
-    Client that connects to the text-generatior-webui api
-    """
+class TextGeneratorWebuiClient(ClientBase):

     client_type = "textgenwebui"
+    conversation_retries = 5

-    def __init__(self, api_url: str, max_token_length: int = 2048, **kwargs):
-        api_url = self.cleanup_api_url(api_url)
-        self.api_url_base = api_url
-        api_url = f"{api_url}/v1/chat"
-        super().__init__(api_url, max_token_length=max_token_length)
-        self.model_name = None
-        self.limited_ram = False
-        self.name = kwargs.get("name", "textgenwebui")
-        self.processing = False
-        self.connected = False
-
-    def __str__(self):
-        return f"TextGeneratorWebuiClient[{self.api_url_base}][{self.model_name or ''}]"
-
-    def cleanup_api_url(self, api_url:str):
-        """
-        Strips trailing / and ensures endpoint is /api
-        """
-        if api_url.endswith("/"):
-            api_url = api_url[:-1]
-        if not api_url.endswith("/api"):
-            api_url = api_url + "/api"
-        return api_url
+    def tune_prompt_parameters(self, parameters:dict, kind:str):
+        super().tune_prompt_parameters(parameters, kind)
+        parameters["stopping_strings"] = STOPPING_STRINGS + parameters.get("extra_stopping_strings", [])
+        # is this needed?
+        parameters["max_new_tokens"] = parameters["max_tokens"]
+
+    def set_client(self):
+        self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key="sk-1111")
+
+    async def get_model_name(self):
+        async with httpx.AsyncClient() as client:
+            response = await client.get(f"{self.api_url}/v1/internal/model/info", timeout=2)
+            if response.status_code == 404:
+                raise Exception("Could not find model info (wrong api version?)")
+            response_data = response.json()
+            model_name = response_data.get("model_name")
+            if model_name == "None":
+                model_name = None
+        return model_name
+
+    async def generate(self, prompt:str, parameters:dict, kind:str):
+        """
+        Generates text from the given prompt and parameters.
+        """
+        headers = {}
+        headers["Content-Type"] = "application/json"
+        parameters["prompt"] = prompt.strip()
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(f"{self.api_url}/v1/completions", json=parameters, timeout=None, headers=headers)
def reconfigure(self, **kwargs):
super().reconfigure(**kwargs)
if "api_url" in kwargs:
log.debug("reconfigure", api_url=kwargs["api_url"])
api_url = kwargs["api_url"]
api_url = self.cleanup_api_url(api_url)
self.api_url_base = api_url
self.api_url = api_url
def toggle_disabled_if_remote(self):
remote_servies = [
".runpod.net"
]
for service in remote_servies:
if service in self.api_url_base:
self.enabled = False
return
def emit_status(self, processing: bool = None):
if processing is not None:
self.processing = processing
if not self.enabled:
status = "disabled"
model_name = "Disabled"
elif not self.connected:
status = "error"
model_name = "Could not connect"
elif self.model_name:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
model_name = "No model loaded"
status = "warning"
status_change = status != self.current_status
self.current_status = status
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status,
)
if status_change:
instance.emit_agent_status_by_client(self)
# Add the 'status' method
async def status(self):
"""
Send a request to the API to retrieve the loaded AI model name.
Raises an error if no model name is returned.
:return: None
"""
if not self.enabled:
self.connected = False
self.emit_status()
return
try:
async with httpx.AsyncClient() as client:
response = await client.get(f"{self.api_url_base}/v1/model", timeout=2)
except (
httpx.TimeoutException,
httpx.NetworkError,
):
self.model_name = None
self.connected = False
self.toggle_disabled_if_remote()
self.emit_status()
return
self.connected = True
         try:
             response_data = response.json()
-            self.enabled = True
-        except json.decoder.JSONDecodeError as e:
-            self.connected = False
-            self.toggle_disabled_if_remote()
-            if not self.enabled:
-                log.warn("remote service unreachable, disabling client", name=self.name)
-            else:
-                log.error("client response error", name=self.name, e=e)
-            self.emit_status()
-            return
+            return response_data["choices"][0]["text"]
+
+    def jiggle_randomness(self, prompt_config:dict, offset:float=0.3) -> dict:
model_name = response_data.get("result")
if not model_name or model_name == "None":
log.warning("client model not loaded", client=self.name)
self.emit_status()
return
model_changed = model_name != self.model_name
self.model_name = model_name
if model_changed:
self.auto_context_length()
log.info(f"{self} [{self.max_token_length} ctx]: ready")
self.emit_status()
-    def auto_context_length(self):
         """
-        Automaticalle sets context length based on LLM
+        adjusts temperature and repetition_penalty
+        by random values using the base value as a center
         """
-        if not isinstance(self.max_token_length, DefaultContext):
-            # context length was specified manually
-            return
-        model_name = self.model_name.lower()
-        if "longchat" in model_name:
-            self.max_token_length = 16000
+        temp = prompt_config["temperature"]
+        rep_pen = prompt_config["repetition_penalty"]
+
+        min_offset = offset * 0.3
+        prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
+        prompt_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)
elif "8k" in model_name:
if not self.limited_ram or "13b" in model_name:
self.max_token_length = 6000
else:
self.max_token_length = 4096
elif "4k" in model_name:
self.max_token_length = 4096
else:
self.max_token_length = self.original_max_token_length
@property
def instruction_template(self):
if "vicuna" in self.model_name.lower():
return "Vicuna-v1.1"
if "camel" in self.model_name.lower():
return "Vicuna-v1.1"
return ""
def prompt_url(self):
return self.api_url_base + "/v1/generate"
def prompt_config_conversation_old(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.BASIC,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": 75,
"truncation_length": self.max_token_length,
}
config.update(PRESET_TALEMATE_CONVERSATION)
return config
def prompt_config_conversation(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.ROLEPLAY,
prompt,
)
stopping_strings = ["<|end_of_turn|>"]
conversation_context = client_context_attribute("conversation")
stopping_strings += [
f"{character}:" for character in conversation_context["other_characters"]
]
max_new_tokens = conversation_context.get("length", 96)
log.debug("prompt_config_conversation", stopping_strings=stopping_strings, conversation_context=conversation_context, max_new_tokens=max_new_tokens)
config = {
"prompt": prompt,
"max_new_tokens": max_new_tokens,
"truncation_length": self.max_token_length,
"stopping_strings": stopping_strings,
}
config.update(PRESET_TALEMATE_CONVERSATION)
jiggle_randomness(config)
return config
def prompt_config_conversation_long(self, prompt: str) -> dict:
config = self.prompt_config_conversation(prompt)
config["max_new_tokens"] = 300
return config
def prompt_config_conversation_select_talking_actor(self, prompt: str) -> dict:
config = self.prompt_config_conversation(prompt)
config["max_new_tokens"] = 30
config["stopping_strings"] += [":"]
return config
def prompt_config_summarize(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.NARRATOR,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": 500,
"truncation_length": self.max_token_length,
}
config.update(PRESET_LLAMA_PRECISE)
return config
def prompt_config_analyze(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.ANALYST,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": 500,
"truncation_length": self.max_token_length,
}
config.update(PRESET_SIMPLE_1)
return config
def prompt_config_analyze_creative(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.ANALYST,
prompt,
)
config = {}
config.update(PRESET_DIVINE_INTELLECT)
config.update({
"prompt": prompt,
"max_new_tokens": 1024,
"repetition_penalty_range": 1024,
"truncation_length": self.max_token_length
})
return config
def prompt_config_analyze_long(self, prompt: str) -> dict:
config = self.prompt_config_analyze(prompt)
config["max_new_tokens"] = 2048
return config
def prompt_config_analyze_freeform(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.ANALYST_FREEFORM,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": 500,
"truncation_length": self.max_token_length,
}
config.update(PRESET_LLAMA_PRECISE)
return config
def prompt_config_analyze_freeform_short(self, prompt: str) -> dict:
config = self.prompt_config_analyze_freeform(prompt)
config["max_new_tokens"] = 10
return config
def prompt_config_narrate(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.NARRATOR,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": 500,
"truncation_length": self.max_token_length,
}
config.update(PRESET_LLAMA_PRECISE)
return config
def prompt_config_story(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.NARRATOR,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": 300,
"seed": random.randint(0, 1000000000),
"truncation_length": self.max_token_length
}
config.update(PRESET_DIVINE_INTELLECT)
config.update({
"repetition_penalty": 1.3,
"repetition_penalty_range": 2048,
})
return config
def prompt_config_create(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.CREATOR,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": min(1024, self.max_token_length * 0.35),
"truncation_length": self.max_token_length,
}
config.update(PRESET_TALEMATE_CREATOR)
return config
def prompt_config_create_concise(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.CREATOR,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": min(400, self.max_token_length * 0.25),
"truncation_length": self.max_token_length,
"stopping_strings": ["<|DONE|>", "\n\n"]
}
config.update(PRESET_TALEMATE_CREATOR)
return config
def prompt_config_create_precise(self, prompt: str) -> dict:
config = self.prompt_config_create_concise(prompt)
config.update(PRESET_LLAMA_PRECISE)
return config
def prompt_config_director(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.DIRECTOR,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": min(600, self.max_token_length * 0.25),
"truncation_length": self.max_token_length,
}
config.update(PRESET_SIMPLE_1)
return config
def prompt_config_director_short(self, prompt: str) -> dict:
config = self.prompt_config_director(prompt)
config.update(max_new_tokens=25)
return config
def prompt_config_director_yesno(self, prompt: str) -> dict:
config = self.prompt_config_director(prompt)
config.update(max_new_tokens=2)
return config
def prompt_config_edit_dialogue(self, prompt:str) -> dict:
prompt = self.prompt_template(
system_prompts.EDITOR,
prompt,
)
conversation_context = client_context_attribute("conversation")
stopping_strings = [
f"{character}:" for character in conversation_context["other_characters"]
]
config = {
"prompt": prompt,
"max_new_tokens": 100,
"truncation_length": self.max_token_length,
"stopping_strings": stopping_strings,
}
config.update(PRESET_DIVINE_INTELLECT)
return config
def prompt_config_edit_add_detail(self, prompt:str) -> dict:
config = self.prompt_config_edit_dialogue(prompt)
config.update(max_new_tokens=200)
return config
def prompt_config_edit_fix_exposition(self, prompt:str) -> dict:
config = self.prompt_config_edit_dialogue(prompt)
config.update(max_new_tokens=1024)
return config
async def send_prompt(
self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x
) -> str:
"""
Send a prompt to the AI and return its response.
:param prompt: The text prompt to send.
:return: The AI's response text.
"""
#prompt = prompt.replace("<|BOT|>", "<|BOT|>Certainly! ")
await self.status()
self.emit_status(processing=True)
await asyncio.sleep(0.01)
fn_prompt_config = getattr(self, f"prompt_config_{kind}")
fn_url = self.prompt_url
message = fn_prompt_config(prompt)
if client_context_attribute("nuke_repetition") > 0.0 and kind in ["conversation", "story"]:
log.info("nuke repetition", offset=client_context_attribute("nuke_repetition"), temperature=message["temperature"], repetition_penalty=message["repetition_penalty"])
message = jiggle_randomness(message, offset=client_context_attribute("nuke_repetition"))
log.info("nuke repetition (applied)", offset=client_context_attribute("nuke_repetition"), temperature=message["temperature"], repetition_penalty=message["repetition_penalty"])
message = finalize(message)
token_length = int(len(message["prompt"]) / 3.6)
self.last_token_length = token_length
log.debug("send_prompt", token_length=token_length, max_token_length=self.max_token_length)
message["prompt"] = message["prompt"].strip()
#print(f"prompt: |{message['prompt']}|")
# add <|im_end|> to stopping strings
if "stopping_strings" in message:
message["stopping_strings"] += ["<|im_end|>", "</s>"]
else:
message["stopping_strings"] = ["<|im_end|>", "</s>"]
#message["seed"] = -1
#for k,v in message.items():
# if k == "prompt":
# continue
# print(f"{k}: {v}")
time_start = time.time()
response = await self.send_message(message, fn_url())
time_end = time.time()
response = response.split("#")[0]
self.emit_status(processing=False)
emit("prompt_sent", data={
"kind": kind,
"prompt": message["prompt"],
"response": response,
"prompt_tokens": token_length,
"response_tokens": int(len(response) / 3.6),
"time": time_end - time_start,
})
return response
class OpenAPIClient(RESTTaleMateClient):
pass
class GPT3Client(OpenAPIClient):
pass
class GPT4Client(OpenAPIClient):
pass

View File

@@ -0,0 +1,32 @@
import copy
import random
def jiggle_randomness(prompt_config:dict, offset:float=0.3) -> dict:
"""
adjusts temperature and repetition_penalty
by random values using the base value as a center
"""
temp = prompt_config["temperature"]
rep_pen = prompt_config["repetition_penalty"]
copied_config = copy.deepcopy(prompt_config)
min_offset = offset * 0.3
copied_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
copied_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)
return copied_config
def jiggle_enabled_for(kind:str):
if kind in ["conversation", "story"]:
return True
if kind.startswith("narrate"):
return True
return False
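
A rough illustration of the jiggle ranges, using the conversation preset's base values:

config = {"temperature": 0.65, "repetition_penalty": 1.18}
jiggled = jiggle_randomness(config, offset=0.3)
# temperature is re-rolled into [0.65 + 0.09, 0.65 + 0.30]
assert 0.74 <= jiggled["temperature"] <= 0.95
# repetition_penalty moves in the narrower range [1.18 + 0.027, 1.18 + 0.09]
assert 1.207 <= jiggled["repetition_penalty"] <= 1.27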

View File

@@ -23,6 +23,7 @@ from .cmd_save_as import CmdSaveAs
 from .cmd_save_characters import CmdSaveCharacters
 from .cmd_setenv import CmdSetEnvironmentToScene, CmdSetEnvironmentToCreative
 from .cmd_time_util import *
+from .cmd_tts import *
 from .cmd_world_state import CmdWorldState
 from .cmd_run_helios_test import CmdHeliosTest
 from .manager import Manager

View File

@@ -32,4 +32,5 @@ class CmdRebuildArchive(TalemateCommand):
             if not more:
                 break

+        self.scene.sync_time()
         await self.scene.commit_to_memory()

View File

@@ -17,7 +17,26 @@ class CmdRename(TalemateCommand):
     aliases = []

     async def run(self):
+        # collect list of characters in the scene
+        if self.args:
+            character_name = self.args[0]
+        else:
+            character_names = self.scene.character_names
+            character_name = await wait_for_input("Which character do you want to rename?", data={
+                "input_type": "select",
+                "choices": character_names,
+            })
+
+        character = self.scene.get_character(character_name)
+
+        if not character:
+            self.system_message(f"Character {character_name} not found")
+            return True
+
         name = await wait_for_input("Enter new name: ")
-        self.scene.main_character.character.rename(name)
+        character.rename(name)
         await asyncio.sleep(0)
+        return True

View File

@@ -0,0 +1,33 @@
import asyncio
import logging
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.prompts.base import set_default_sectioning_handler
from talemate.instance import get_agent
__all__ = [
"CmdTestTTS",
]
@register
class CmdTestTTS(TalemateCommand):
"""
Command class for the 'test_tts' command
"""
name = "test_tts"
description = "Test the TTS agent"
aliases = []
async def run(self):
tts_agent = get_agent("tts")
try:
last_message = str(self.scene.history[-1])
except IndexError:
last_message = "Welcome to talemate!"
await tts_agent.generate(last_message)
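
Presumably the command is invoked like the other Talemate chat commands, i.e. by typing !test_tts in the scene input, which speaks the most recent message in the history (or a greeting when the scene is empty).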

View File

@@ -66,6 +66,21 @@ class OpenAIConfig(BaseModel):
 class RunPodConfig(BaseModel):
     api_key: Union[str,None]=None

+class ElevenLabsConfig(BaseModel):
+    api_key: Union[str,None]=None
+
+class CoquiConfig(BaseModel):
+    api_key: Union[str,None]=None
+
+class TTSVoiceSamples(BaseModel):
+    label:str
+    value:str
+
+class TTSConfig(BaseModel):
+    device:str = "cuda"
+    model:str = "tts_models/multilingual/multi-dataset/xtts_v2"
+    voices: list[TTSVoiceSamples] = pydantic.Field(default_factory=list)
+
 class ChromaDB(BaseModel):
     instructor_device: str="cpu"
     instructor_model: str="default"
@@ -85,6 +100,12 @@ class Config(BaseModel):
     chromadb: ChromaDB = ChromaDB()

+    elevenlabs: ElevenLabsConfig = ElevenLabsConfig()
+
+    coqui: CoquiConfig = CoquiConfig()
+
+    tts: TTSConfig = TTSConfig()
+
     class Config:
         extra = "ignore"

View File

@@ -24,6 +24,8 @@ CommandStatus = signal("command_status")
 WorldState = signal("world_state")
 ArchivedHistory = signal("archived_history")

+AudioQueue = signal("audio_queue")
+
 MessageEdited = signal("message_edited")

 handlers = {
@@ -46,4 +48,5 @@ handlers = {
     "archived_history": ArchivedHistory,
     "message_edited": MessageEdited,
     "prompt_sent": PromptSent,
+    "audio_queue": AudioQueue,
 }

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING

 if TYPE_CHECKING:
-    from talemate.tale_mate import Scene
+    from talemate.tale_mate import Scene, Actor, SceneMessage

 __all__ = [
     "Event",
@@ -43,3 +43,11 @@ class GameLoopEvent(Event):
 @dataclass
 class GameLoopStartEvent(GameLoopEvent):
     pass
+
+@dataclass
+class GameLoopActorIterEvent(GameLoopEvent):
+    actor: Actor
+
+@dataclass
+class GameLoopNewMessageEvent(GameLoopEvent):
+    message: SceneMessage
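
A sketch of what a consumer of the new events could look like; the handler is hypothetical and the module path and connect API are assumed, while the signals themselves are wired up in tale_mate.py further down:

from talemate.events import GameLoopActorIterEvent  # module path assumed

async def on_actor_iter(event: GameLoopActorIterEvent):
    # runs once per actor turn during the game loop
    print("actor finished turn:", event.actor)

# scene.signals["game_loop_actor_iter"].connect(on_actor_iter)  # connect API assumed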

View File

@@ -190,7 +190,10 @@ async def load_scene_from_data(
         await scene.add_actor(actor)

     if scene.environment != "creative":
-        await scene.world_state.request_update(initial_only=True)
+        try:
+            await scene.world_state.request_update(initial_only=True)
+        except Exception as e:
+            log.error("world_state.request_update", error=e)

     # the scene has been saved before (since we just loaded it), so we set the saved flag to True
     # as long as the scene has a memory_id.

View File

@@ -290,6 +290,7 @@ class Prompt:
env.globals["query_scene"] = self.query_scene env.globals["query_scene"] = self.query_scene
env.globals["query_memory"] = self.query_memory env.globals["query_memory"] = self.query_memory
env.globals["query_text"] = self.query_text env.globals["query_text"] = self.query_text
env.globals["instruct_text"] = self.instruct_text
env.globals["retrieve_memories"] = self.retrieve_memories env.globals["retrieve_memories"] = self.retrieve_memories
env.globals["uuidgen"] = lambda: str(uuid.uuid4()) env.globals["uuidgen"] = lambda: str(uuid.uuid4())
env.globals["to_int"] = lambda x: int(x) env.globals["to_int"] = lambda x: int(x)
@@ -394,9 +395,14 @@ class Prompt:
f"Answer: " + loop.run_until_complete(memory.query(query, **kwargs)), f"Answer: " + loop.run_until_complete(memory.query(query, **kwargs)),
]) ])
else: else:
return loop.run_until_complete(memory.multi_query([query], **kwargs)) return loop.run_until_complete(memory.multi_query(query.split("\n"), **kwargs))
def instruct_text(self, instruction:str, text:str):
loop = asyncio.get_event_loop()
world_state = instance.get_agent("world_state")
instruction = instruction.format(**self.vars)
return loop.run_until_complete(world_state.analyze_and_follow_instruction(text, instruction))
def retrieve_memories(self, lines:list[str], goal:str=None): def retrieve_memories(self, lines:list[str], goal:str=None):
@@ -467,8 +473,6 @@ class Prompt:
# remove all duplicate whitespace # remove all duplicate whitespace
cleaned = re.sub(r"\s+", " ", cleaned) cleaned = re.sub(r"\s+", " ", cleaned)
print("set_json_response", cleaned)
return self.set_prepared_response(cleaned) return self.set_prepared_response(cleaned)

View File

@@ -0,0 +1,19 @@
{% block rendered_context -%}
<|SECTION:CONTEXT|>
Content Context: This is a specific scene from {{ scene.context }}
Scenario Premise: {{ scene.description }}
{% for memory in query_memory(last_line, as_question_answer=False, iterate=10) -%}
{{ memory }}
{% endfor %}
{% endblock -%}
<|CLOSE_SECTION|>
{% for scene_context in scene.context_history(budget=max_tokens-200-count_tokens(self.rendered_context())) -%}
{{ scene_context }}
{% endfor %}
<|SECTION:TASK|>
Based on the previous line '{{ last_line }}', create the next line of narration. This line should focus solely on describing sensory details (like sounds, sights, smells, tactile sensations) or external actions that move the story forward. Avoid including any character's internal thoughts, feelings, or dialogue. Your narration should directly respond to '{{ last_line }}', either by elaborating on the immediate scene or by subtly advancing the plot. Generate exactly one sentence of new narration. If the character is trying to determine some state, truth or situation, try to answer as part of the narration.
Be creative and generate something new and interesting, but stay true to the setting and context of the story so far.
<|CLOSE_SECTION|>
{{ set_prepared_response('*') }}

View File

@@ -8,13 +8,13 @@
 {% if query.endswith("?") -%}
 Question: {{ query }}
 Extra context: {{ query_memory(query, as_question_answer=False) }}
-Instruction: Analyze Context, History and Dialogue. Be factual and truthful. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context. Respect the scene progression and answer in the context of the end of the dialogue.
+Instruction: Analyze Context, History and Dialogue. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context. Respect the scene progression and answer in the context of the end of the dialogue.
 {% else -%}
 Instruction: {{ query }}
 Extra context: {{ query_memory(query, as_question_answer=False) }}
-Answer based on Context, History and Dialogue. Be factual and truthful. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context.
+Answer based on Context, History and Dialogue. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context.
 {% endif -%}
 Content Context: This is a specific scene from {{ scene.context }}
-Narration style: point and click adventure game from the 90s
+Your answer should be in the style of short narration that fits the context of the scene.
 <|CLOSE_SECTION|>
 Narrator answers: {% if at_the_end %}{{ bot_token }}At the end of the dialogue, {% endif %}

View File

@@ -8,9 +8,10 @@
 <|SECTION:TASK|>
 Answer the following questions:

-{{ query_text("What are 1 to 3 questions to ask the narrator of the story to gather more context from the past for the continuation of this conversation? If a character is asking about a status, location or information about an item or another character, make sure to include question(s) that help gather context for this. Don't explain your reasoning. Don't ask the actors directly.", text, as_question_answer=False) }}
+{{ instruct_text("Ask the narrator three (3) questions to gather more context from the past for the continuation of this conversation. If a character is asking about a state, location or information about an item or another character, make sure to include question(s) that help gather context for this.", text) }}

-You answers should be precise, truthful and short.
+You answers should be precise, truthful and short. Pay close attention to timestamps when retrieving information from the context.
 <|CLOSE_SECTION|>
 <|SECTION:RELEVANT CONTEXT|>
+{{ bot_token }}Answers:

View File

@@ -0,0 +1,5 @@
{{ text }}
<|SECTION:TASK|>
{{ instruction }}

View File

@@ -34,7 +34,7 @@ No dialogue so far
 {% endif -%}
 <|CLOSE_SECTION|>
 <|SECTION:SCENE PROGRESS|>
-{% for scene_context in scene.context_history(budget=300, min_dialogue=5, add_archieved_history=False, max_dialogue=5) -%}
+{% for scene_context in scene.context_history(budget=500, min_dialogue=5, add_archieved_history=False, max_dialogue=5) -%}
 {{ scene_context }}
 {% endfor -%}
 <|CLOSE_SECTION|>

View File

@@ -110,7 +110,6 @@ async def websocket_endpoint(websocket, path):
         elif action_type == "request_scenes_list":
             query = data.get("query", "")
             handler.request_scenes_list(query)
-
         elif action_type == "configure_clients":
             handler.configure_clients(data.get("clients"))
         elif action_type == "configure_agents":

View File

@@ -1,3 +1,5 @@
+import os
+
 import argparse
 import asyncio
 import sys

View File

@@ -0,0 +1,26 @@
import structlog
import talemate.instance as instance
log = structlog.get_logger("talemate.server.tts")
class TTSPlugin:
router = "tts"
def __init__(self, websocket_handler):
self.websocket_handler = websocket_handler
self.tts = None
async def handle(self, data:dict):
action = data.get("action")
if action == "test":
return await self.handle_test(data)
async def handle_test(self, data:dict):
tts_agent = instance.get_agent("tts")
await tts_agent.generate("Welcome to talemate!")

View File

@@ -91,7 +91,7 @@ class WebsocketHandler(Receiver):
         for agent_typ, agent_config in self.agents.items():
             try:
                 client = self.llm_clients.get(agent_config.get("client"))["client"]
-            except TypeError:
+            except TypeError as e:
                 client = None

             if not client:
@@ -167,19 +167,28 @@ class WebsocketHandler(Receiver):
         log.info("Configuring clients", clients=clients)

         for client in clients:
-            if client["type"] == "textgenwebui":
+            client.pop("status", None)
+
+            if client["type"] in ["textgenwebui", "lmstudio"]:
                 try:
                     max_token_length = int(client.get("max_token_length", 2048))
                 except ValueError:
                     continue
+
+                client.pop("model", None)

                 self.llm_clients[client["name"]] = {
-                    "type": "textgenwebui",
+                    "type": client["type"],
                     "api_url": client["apiUrl"],
                     "name": client["name"],
                     "max_token_length": max_token_length,
                 }
             elif client["type"] == "openai":
+                client.pop("model_name", None)
+                client.pop("apiUrl", None)
+
                 self.llm_clients[client["name"]] = {
                     "type": "openai",
                     "name": client["name"],
@@ -213,16 +222,25 @@ class WebsocketHandler(Receiver):
     def configure_agents(self, agents):
         self.agents = {typ: {} for typ in instance.agent_types()}

-        log.debug("Configuring agents", agents=agents)
+        log.debug("Configuring agents")

         for agent in agents:
             name = agent["name"]

             # special case for memory agent
-            if name == "memory":
+            if name == "memory" or name == "tts":
                 self.agents[name] = {
                     "name": name,
                 }
+                agent_instance = instance.get_agent(name, **self.agents[name])
+                if agent_instance.has_toggle:
+                    self.agents[name]["enabled"] = agent["enabled"]
+                if getattr(agent_instance, "actions", None):
+                    self.agents[name]["actions"] = agent.get("actions", {})
+                agent_instance.apply_config(**self.agents[name])
+                log.debug("Configured agent", name=name)
                 continue

             if name not in self.agents:
@@ -385,7 +403,7 @@ class WebsocketHandler(Receiver):
                 "status": emission.status,
                 "data": emission.data,
                 "max_token_length": client.max_token_length if client else 2048,
-                "apiUrl": getattr(client, "api_url_base", None) if client else None,
+                "apiUrl": getattr(client, "api_url", None) if client else None,
             }
         )
@@ -419,6 +437,14 @@ class WebsocketHandler(Receiver):
             }
         )

+    def handle_audio_queue(self, emission: Emission):
+        self.queue_put(
+            {
+                "type": "audio_queue",
+                "data": emission.data,
+            }
+        )
+
     def handle_request_input(self, emission: Emission):
         self.waiting_for_input = True

View File

@@ -43,6 +43,10 @@ __all__ = [
log = structlog.get_logger("talemate") log = structlog.get_logger("talemate")
async_signals.register("game_loop_start")
async_signals.register("game_loop")
async_signals.register("game_loop_actor_iter")
async_signals.register("game_loop_new_message")
class Character: class Character:
""" """
@@ -523,8 +527,6 @@ class Player(Actor):
return message return message
async_signals.register("game_loop_start")
async_signals.register("game_loop")
class Scene(Emitter): class Scene(Emitter):
""" """
@@ -575,6 +577,8 @@ class Scene(Emitter):
"character_state": signal("character_state"), "character_state": signal("character_state"),
"game_loop": async_signals.get("game_loop"), "game_loop": async_signals.get("game_loop"),
"game_loop_start": async_signals.get("game_loop_start"), "game_loop_start": async_signals.get("game_loop_start"),
"game_loop_actor_iter": async_signals.get("game_loop_actor_iter"),
"game_loop_new_message": async_signals.get("game_loop_new_message"),
} }
self.setup_emitter(scene=self) self.setup_emitter(scene=self)
@@ -702,6 +706,12 @@ class Scene(Emitter):
) )
) )
loop = asyncio.get_event_loop()
for message in messages:
loop.run_until_complete(self.signals["game_loop_new_message"].send(
events.GameLoopNewMessageEvent(scene=self, event_type="game_loop_new_message", message=message)
))
def push_archive(self, entry: data_objects.ArchiveEntry): def push_archive(self, entry: data_objects.ArchiveEntry):
""" """
@@ -1066,7 +1076,9 @@ class Scene(Emitter):
new_message = await narrator.agent.narrate_character(character) new_message = await narrator.agent.narrate_character(character)
elif source == "narrate_query": elif source == "narrate_query":
new_message = await narrator.agent.narrate_query(arg) new_message = await narrator.agent.narrate_query(arg)
elif source == "narrate_dialogue":
character = self.get_character(arg)
new_message = await narrator.agent.narrate_after_dialogue(character)
else: else:
fn = getattr(narrator.agent, source, None) fn = getattr(narrator.agent, source, None)
if not fn: if not fn:
@@ -1172,7 +1184,7 @@ class Scene(Emitter):
}, },
) )
self.log.debug("scene_status", scene=self.name, scene_time=self.ts, saved=self.saved) self.log.debug("scene_status", scene=self.name, scene_time=self.ts, human_ts=util.iso8601_duration_to_human(self.ts, suffix=""), saved=self.saved)
def set_environment(self, environment: str): def set_environment(self, environment: str):
""" """
@@ -1185,6 +1197,7 @@ class Scene(Emitter):
""" """
Accepts an iso6801 duration string and advances the scene's world state by that amount Accepts an iso6801 duration string and advances the scene's world state by that amount
""" """
log.debug("advance_time", ts=ts, scene_ts=self.ts, duration=isodate.parse_duration(ts), scene_duration=isodate.parse_duration(self.ts))
self.ts = isodate.duration_isoformat( self.ts = isodate.duration_isoformat(
isodate.parse_duration(self.ts) + isodate.parse_duration(ts) isodate.parse_duration(self.ts) + isodate.parse_duration(ts)
@@ -1208,8 +1221,11 @@ class Scene(Emitter):
                 self.ts = self.archived_history[i]["ts"]
                 break

+            end = self.archived_history[-1].get("end", 0)
+        else:
+            end = 0

-        for message in self.history:
+        for message in self.history[end:]:
             if isinstance(message, TimePassageMessage):
                 self.advance_time(message.ts)
@@ -1339,6 +1355,10 @@ class Scene(Emitter):
                     if await command.execute(message):
                         break
                 await self.call_automated_actions()
+
+                await self.signals["game_loop_actor_iter"].send(
+                    events.GameLoopActorIterEvent(scene=self, event_type="game_loop_actor_iter", actor=actor)
+                )
                 continue

             self.saved = False
@@ -1351,6 +1371,10 @@ class Scene(Emitter):
                     "character", item, character=actor.character
                 )

+            await self.signals["game_loop_actor_iter"].send(
+                events.GameLoopActorIterEvent(scene=self, event_type="game_loop_actor_iter", actor=actor)
+            )
+
             self.emit_status()
         except TalemateInterrupt:
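
Taken together, the hunks above register two new game-loop signals and fire them from the scene loop. A hypothetical subscriber sketch, assuming a blinker-style `connect()` on the objects returned by `async_signals.get()` (only `register`/`get`/`send` appear in this diff, and the import paths are assumptions):

```python
# Hypothetical listener; connect() and import paths are assumed,
# only register()/get()/send() are shown in the diff itself.
from talemate import events
from talemate.emit import async_signals

async def on_new_message(event: events.GameLoopNewMessageEvent):
    # Called once per message pushed through the game loop,
    # e.g. to hand dialogue lines to a TTS agent's queue.
    print("new message:", event.message)

async_signals.get("game_loop_new_message").connect(on_new_message)
```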

View File

@@ -303,6 +303,9 @@ def strip_partial_sentences(text:str) -> str:
     # Sentence ending characters
     sentence_endings = ['.', '!', '?', '"', "*"]

+    if not text:
+        return text
+
     # Check if the last character is already a sentence ending
     if text[-1] in sentence_endings:
         return text
@@ -487,30 +490,43 @@ def clean_attribute(attribute: str) -> str:

 def duration_to_timedelta(duration):
-    """Convert an isodate.Duration object to a datetime.timedelta object."""
+    """Convert an isodate.Duration object or a datetime.timedelta object to a datetime.timedelta object."""
+
+    # Check if the duration is already a timedelta object
+    if isinstance(duration, datetime.timedelta):
+        return duration
+
+    # Check if the duration is an isodate.Duration object with a tdelta attribute
+    if hasattr(duration, 'tdelta'):
+        return duration.tdelta
+
+    # If it's an isodate.Duration object with separate year, month, day, hour, minute, second attributes
     days = int(duration.years) * 365 + int(duration.months) * 30 + int(duration.days)
-    return datetime.timedelta(days=days)
+    seconds = int(duration.hours) * 3600 + int(duration.minutes) * 60 + int(duration.seconds)
+    return datetime.timedelta(days=days, seconds=seconds)

 def timedelta_to_duration(delta):
     """Convert a datetime.timedelta object to an isodate.Duration object."""
+    # Extract days and convert to years, months, and days
     days = delta.days
     years = days // 365
     days %= 365
     months = days // 30
     days %= 30
-    return isodate.duration.Duration(years=years, months=months, days=days)
+    # Convert remaining seconds to hours, minutes, and seconds
+    seconds = delta.seconds
+    hours = seconds // 3600
+    seconds %= 3600
+    minutes = seconds // 60
+    seconds %= 60
+    return isodate.Duration(years=years, months=months, days=days, hours=hours, minutes=minutes, seconds=seconds)

 def parse_duration_to_isodate_duration(duration_str):
     """Parse ISO 8601 duration string and ensure the result is an isodate.Duration."""
     parsed_duration = isodate.parse_duration(duration_str)
     if isinstance(parsed_duration, datetime.timedelta):
-        days = parsed_duration.days
-        years = days // 365
-        days %= 365
-        months = days // 30
-        days %= 30
-        return isodate.duration.Duration(years=years, months=months, days=days)
+        return timedelta_to_duration(parsed_duration)
     return parsed_duration

 def iso8601_diff(duration_str1, duration_str2):
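
The rewritten converters now carry the time-of-day part through in both directions (previously only whole days survived). A quick illustration, using the helpers exactly as named above and the code's own 365-day-year / 30-day-month approximations:

```python
# Round-trip sketch for the converters above; note that
# duration_to_timedelta() short-circuits on a .tdelta attribute,
# so the year/month arithmetic only runs for plain Duration values.
import datetime

delta = datetime.timedelta(days=400, seconds=3661)
dur = timedelta_to_duration(delta)
print(dur)  # years=1, months=1, days=5 plus 1:01:01 (approximation)
print(parse_duration_to_isodate_duration("P400DT1H1M1S"))  # same shape
```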
@@ -530,40 +546,50 @@ def iso8601_diff(duration_str1, duration_str2):

     return difference

-def iso8601_duration_to_human(iso_duration, suffix:str=" ago"):
+def iso8601_duration_to_human(iso_duration, suffix: str = " ago"):

     # Parse the ISO8601 duration string into an isodate duration object
-    if isinstance(iso_duration, isodate.Duration):
-        duration = iso_duration
-    else:
+    if not isinstance(iso_duration, isodate.Duration):
         duration = isodate.parse_duration(iso_duration)
+    else:
+        duration = iso_duration

+    # Extract years, months, days, and the time part as seconds
+    years, months, days, hours, minutes, seconds = 0, 0, 0, 0, 0, 0
+
     if isinstance(duration, isodate.Duration):
         years = duration.years
         months = duration.months
         days = duration.days
-        seconds = duration.tdelta.total_seconds()
-    else:
-        years, months = 0, 0
+        hours = duration.tdelta.seconds // 3600
+        minutes = (duration.tdelta.seconds % 3600) // 60
+        seconds = duration.tdelta.seconds % 60
+    elif isinstance(duration, datetime.timedelta):
         days = duration.days
-        seconds = duration.total_seconds() - days * 86400  # Extract time-only part
+        hours = duration.seconds // 3600
+        minutes = (duration.seconds % 3600) // 60
+        seconds = duration.seconds % 60

-    hours, seconds = divmod(seconds, 3600)
-    minutes, seconds = divmod(seconds, 60)
+    # Adjust for cases where duration is a timedelta object
+    # Convert days to weeks and days if applicable
+    weeks, days = divmod(days, 7)

+    # Build the human-readable components
     components = []
     if years:
         components.append(f"{years} Year{'s' if years > 1 else ''}")
     if months:
         components.append(f"{months} Month{'s' if months > 1 else ''}")
+    if weeks:
+        components.append(f"{weeks} Week{'s' if weeks > 1 else ''}")
     if days:
         components.append(f"{days} Day{'s' if days > 1 else ''}")
     if hours:
-        components.append(f"{int(hours)} Hour{'s' if hours > 1 else ''}")
+        components.append(f"{hours} Hour{'s' if hours > 1 else ''}")
     if minutes:
-        components.append(f"{int(minutes)} Minute{'s' if minutes > 1 else ''}")
+        components.append(f"{minutes} Minute{'s' if minutes > 1 else ''}")
     if seconds:
-        components.append(f"{int(seconds)} Second{'s' if seconds > 1 else ''}")
+        components.append(f"{seconds} Second{'s' if seconds > 1 else ''}")

     # Construct the human-readable string
     if len(components) > 1:
@@ -581,6 +607,7 @@ def iso8601_diff_to_human(start, end):
         return ""

     diff = iso8601_diff(start, end)
+
     return iso8601_duration_to_human(diff)
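
With weeks split out and integer time components, the human formatting changes shape slightly. The joining of components sits in the elided tail of the function, so the exact strings below are assumptions:

```python
# Expected shape of the output, assuming the elided tail joins the
# components before appending the suffix.
print(iso8601_duration_to_human("P10D"))              # ~ "1 Week and 3 Days ago"
print(iso8601_duration_to_human("PT90M", suffix=""))  # ~ "1 Hour and 30 Minutes"
```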
@@ -779,7 +806,11 @@ def ensure_dialog_format(line:str, talking_character:str=None) -> str:

     lines = []

     for _line in line.split("\n"):
-        _line = ensure_dialog_line_format(_line)
+        try:
+            _line = ensure_dialog_line_format(_line)
+        except Exception as exc:
+            log.error("ensure_dialog_format", msg="Error ensuring dialog line format", line=_line, exc_info=exc)
+            pass
+
         lines.append(_line)

View File

@@ -1,6 +1,7 @@
 from pydantic import BaseModel
 from talemate.emit import emit
 import structlog
+import traceback
 from typing import Union

 import talemate.instance as instance
@@ -59,7 +60,8 @@ class WorldState(BaseModel):
             world_state = await self.agent.request_world_state()
         except Exception as e:
             self.emit()
-            raise e
+            log.error("world_state.request_update", error=e, traceback=traceback.format_exc())
+            return

         previous_characters = self.characters
         previous_items = self.items
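
The failure path now logs instead of re-raising, so a bad world-state generation no longer aborts scene loading. The logging pattern itself, runnable in isolation:

```python
# Minimal reproduction of the new error handling above.
import traceback
import structlog

log = structlog.get_logger("talemate")

try:
    raise RuntimeError("llm returned an unusable world state")
except Exception as e:
    log.error("world_state.request_update", error=e, traceback=traceback.format_exc())
```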

View File

@@ -7,11 +7,12 @@
                 size="14"></v-progress-circular>
             <v-icon v-else-if="agent.status === 'uninitialized'" color="orange" size="14">mdi-checkbox-blank-circle</v-icon>
             <v-icon v-else-if="agent.status === 'disabled'" color="grey-darken-2" size="14">mdi-checkbox-blank-circle</v-icon>
+            <v-icon v-else-if="agent.status === 'error'" color="red" size="14">mdi-checkbox-blank-circle</v-icon>
             <v-icon v-else color="green" size="14">mdi-checkbox-blank-circle</v-icon>
             <span class="ml-1" v-if="agent.label"> {{ agent.label }}</span>
             <span class="ml-1" v-else> {{ agent.name }}</span>
         </v-list-item-title>
-        <v-list-item-subtitle>
+        <v-list-item-subtitle class="text-caption">
             {{ agent.client }}
         </v-list-item-subtitle>
         <v-chip class="mr-1" v-if="agent.status === 'disabled'" size="x-small">Disabled</v-chip>
@@ -65,7 +66,10 @@ export default {
             for(let i = 0; i < this.state.agents.length; i++) {
                 let agent = this.state.agents[i];
-                if(agent.status === 'warning' || agent.status === 'error') {
+                if(!agent.data.requires_llm_client)
+                    continue
+                if(agent.status === 'warning' || agent.status === 'error' || agent.status === 'uninitialized') {
                     console.log("agents: configuration required (1)", agent.status)
                     return true;
                 }
@@ -99,7 +103,6 @@ export default {
             } else {
                 this.state.agents[index] = agent;
             }
-            this.state.dialog = false;
             this.$emit('agents-updated', this.state.agents);
         },
         editAgent(index) {

View File

@@ -120,7 +120,7 @@ export default {
             this.state.currentClient = {
                 name: 'TextGenWebUI',
                 type: 'textgenwebui',
-                apiUrl: 'http://localhost:5000/api',
+                apiUrl: 'http://localhost:5000',
                 model_name: '',
                 max_token_length: 4096,
             };
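
The dropped `/api` suffix matches the new text-generation-webui API: with the `openai` extension enabled it serves OpenAI-compatible routes from the port root. A hedged sketch of what the backend presumably calls against that base URL (the exact route and payload are assumptions, not shown in this diff):

```python
# Assumed endpoint shape for the new default apiUrl.
import requests

resp = requests.post(
    "http://localhost:5000/v1/completions",
    json={"prompt": "Hello", "max_tokens": 16},
    timeout=60,
)
print(resp.json())
```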

View File

@@ -10,7 +10,7 @@
                 </v-col>
                 <v-col cols="3" class="text-right">
                     <v-checkbox :label="enabledLabel()" hide-details density="compact" color="green" v-model="agent.enabled"
-                        v-if="agent.data.has_toggle"></v-checkbox>
+                        v-if="agent.data.has_toggle" @update:modelValue="save(false)"></v-checkbox>
                 </v-col>
             </v-row>
@@ -18,7 +18,7 @@
         </v-card-title>
         <v-card-text class="scrollable-content">
-            <v-select v-model="agent.client" :items="agent.data.client" label="Client"></v-select>
+            <v-select v-if="agent.data.requires_llm_client" v-model="agent.client" :items="agent.data.client" label="Client" @update:modelValue="save(false)"></v-select>
             <v-alert type="warning" variant="tonal" density="compact" v-if="agent.data.experimental">
                 This agent is currently experimental and may significantly decrease performance and / or require
@@ -27,27 +27,25 @@
             <v-card v-for="(action, key) in agent.actions" :key="key" density="compact">
                 <v-card-subtitle>
-                    <v-checkbox :label="agent.data.actions[key].label" hide-details density="compact" color="green" v-model="action.enabled"></v-checkbox>
+                    <v-checkbox v-if="!actionAlwaysEnabled(key)" :label="agent.data.actions[key].label" hide-details density="compact" color="green" v-model="action.enabled" @update:modelValue="save(false)"></v-checkbox>
                 </v-card-subtitle>
                 <v-card-text>
-                    {{ agent.data.actions[key].description }}
+                    <div v-if="!actionAlwaysEnabled(key)">
+                        {{ agent.data.actions[key].description }}
+                    </div>
                     <div v-for="(action_config, config_key) in agent.data.actions[key].config" :key="config_key">
                         <div v-if="action.enabled">
                             <!-- render config widgets based on action_config.type (int, str, bool, float) -->
-                            <v-text-field v-if="action_config.type === 'text'" v-model="action.config[config_key].value" :label="action_config.label" :hint="action_config.description" density="compact"></v-text-field>
-                            <v-slider v-if="action_config.type === 'number' && action_config.step !== null" v-model="action.config[config_key].value" :label="action_config.label" :hint="action_config.description" :min="action_config.min" :max="action_config.max" :step="action_config.step" density="compact" thumb-label></v-slider>
-                            <v-checkbox v-if="action_config.type === 'bool'" v-model="action.config[config_key].value" :label="action_config.label" :hint="action_config.description" density="compact"></v-checkbox>
+                            <v-text-field v-if="action_config.type === 'text' && action_config.choices === null" v-model="action.config[config_key].value" :label="action_config.label" :hint="action_config.description" density="compact" @update:modelValue="save(true)"></v-text-field>
+                            <v-autocomplete v-else-if="action_config.type === 'text' && action_config.choices !== null" v-model="action.config[config_key].value" :items="action_config.choices" :label="action_config.label" :hint="action_config.description" density="compact" item-title="label" item-value="value" @update:modelValue="save(false)"></v-autocomplete>
+                            <v-slider v-if="action_config.type === 'number' && action_config.step !== null" v-model="action.config[config_key].value" :label="action_config.label" :hint="action_config.description" :min="action_config.min" :max="action_config.max" :step="action_config.step" density="compact" thumb-label @update:modelValue="save(true)"></v-slider>
+                            <v-checkbox v-if="action_config.type === 'bool'" v-model="action.config[config_key].value" :label="action_config.label" :hint="action_config.description" density="compact" @update:modelValue="save(false)"></v-checkbox>
                         </div>
                     </div>
                 </v-card-text>
             </v-card>
         </v-card-text>
-        <v-card-actions>
-            <v-spacer></v-spacer>
-            <v-btn color="primary" @click="close">Close</v-btn>
-            <v-btn color="primary" @click="save">Save</v-btn>
-        </v-card-actions>
     </v-card>
 </v-dialog>
</template>
@@ -58,9 +56,10 @@ export default {
         dialog: Boolean,
         formTitle: String
     },
-    inject: ['state'],
+    inject: ['state', 'getWebsocket'],
     data() {
         return {
+            saveTimeout: null,
             localDialog: this.state.dialog,
             agent: { ...this.state.currentAgent }
         };
@@ -90,12 +89,32 @@ export default {
                 return 'Disabled';
             }
         },
+        actionAlwaysEnabled(action) {
+            if (action.charAt(0) === '_') {
+                return true;
+            } else {
+                return false;
+            }
+        },
         close() {
             this.$emit('update:dialog', false);
         },
-        save() {
-            this.$emit('save', this.agent);
-            this.close();
+        save(delayed = false) {
+            console.log("save", delayed);
+            if(!delayed) {
+                this.$emit('save', this.agent);
+                return;
+            }
+            if(this.saveTimeout !== null)
+                clearTimeout(this.saveTimeout);
+            this.saveTimeout = setTimeout(() => {
+                this.$emit('save', this.agent);
+            }, 500);
+            //this.$emit('save', this.agent);
         }
     }
}
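
The `save(delayed)` change debounces text-input saves by 500 ms while committing toggles and selections immediately. The same pattern in Python, purely illustrative for readers outside the frontend code (`asyncio` stands in for `setTimeout`/`clearTimeout`):

```python
# Debounce sketch mirroring save(delayed) above.
import asyncio

class DebouncedSave:
    def __init__(self, emit, delay=0.5):
        self.emit = emit      # callback that performs the actual save
        self.delay = delay
        self._task = None

    def save(self, delayed=False):
        if not delayed:
            self.emit()
            return
        if self._task is not None:
            self._task.cancel()           # restart the timer on each edit
        self._task = asyncio.ensure_future(self._fire())

    async def _fire(self):
        await asyncio.sleep(self.delay)
        self.emit()
```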

View File

@@ -0,0 +1,96 @@
<template>
    <div class="audio-queue">
        <span>{{ queue.length }} sound(s) queued</span>
        <v-icon :color="isPlaying ? 'green' : ''" v-if="!isMuted" @click="toggleMute">mdi-volume-high</v-icon>
        <v-icon :color="isPlaying ? 'red' : ''" v-else @click="toggleMute">mdi-volume-off</v-icon>
        <v-icon class="ml-1" @click="stopAndClear">mdi-stop-circle-outline</v-icon>
    </div>
</template>

<script>
export default {
    name: 'AudioQueue',
    data() {
        return {
            queue: [],
            audioContext: null,
            isPlaying: false,
            isMuted: false,
            currentSource: null
        };
    },
    inject: ['getWebsocket', 'registerMessageHandler'],
    created() {
        this.audioContext = new (window.AudioContext || window.webkitAudioContext)();
        this.registerMessageHandler(this.handleMessage);
    },
    methods: {
        handleMessage(data) {
            if (data.type === 'audio_queue') {
                console.log('Received audio queue message', data)
                this.addToQueue(data.data.audio_data);
            }
        },
        addToQueue(base64Sound) {
            const soundBuffer = this.base64ToArrayBuffer(base64Sound);
            this.queue.push(soundBuffer);
            this.playNextSound();
        },
        base64ToArrayBuffer(base64) {
            const binaryString = window.atob(base64);
            const len = binaryString.length;
            const bytes = new Uint8Array(len);
            for (let i = 0; i < len; i++) {
                bytes[i] = binaryString.charCodeAt(i);
            }
            return bytes.buffer;
        },
        playNextSound() {
            if (this.isPlaying || this.queue.length === 0) {
                return;
            }
            this.isPlaying = true;
            const soundBuffer = this.queue.shift();
            this.audioContext.decodeAudioData(soundBuffer, (buffer) => {
                const source = this.audioContext.createBufferSource();
                source.buffer = buffer;
                this.currentSource = source;
                if (!this.isMuted) {
                    source.connect(this.audioContext.destination);
                }
                source.onended = () => {
                    this.isPlaying = false;
                    this.playNextSound();
                };
                source.start(0);
            }, (error) => {
                console.error('Error with decoding audio data', error);
            });
        },
        toggleMute() {
            this.isMuted = !this.isMuted;
            if (this.isMuted && this.currentSource) {
                this.currentSource.disconnect(this.audioContext.destination);
            } else if (this.currentSource) {
                this.currentSource.connect(this.audioContext.destination);
            }
        },
        stopAndClear() {
            if (this.currentSource) {
                this.currentSource.stop();
                this.currentSource.disconnect();
                this.currentSource = null;
            }
            this.queue = [];
            this.isPlaying = false;
        }
    }
};
</script>

<style scoped>
.audio-queue {
    display: flex;
    align-items: center;
}
</style>
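
This new component expects `audio_queue` websocket messages carrying base64-encoded audio. A hypothetical server-side counterpart producing that payload shape (the shape is taken from `handleMessage` above; the function name is illustrative):

```python
# Hypothetical emitter for the 'audio_queue' message consumed above.
import base64
import json

def audio_queue_message(raw_audio: bytes) -> str:
    return json.dumps({
        "type": "audio_queue",
        "data": {"audio_data": base64.b64encode(raw_audio).decode("ascii")},
    })
```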

View File

@@ -8,7 +8,7 @@
             <v-container>
                 <v-row>
                     <v-col cols="6">
-                        <v-select v-model="client.type" :items="['openai', 'textgenwebui']" label="Client Type"></v-select>
+                        <v-select v-model="client.type" :disabled="!typeEditable()" :items="['openai', 'textgenwebui', 'lmstudio']" label="Client Type"></v-select>
                     </v-col>
                     <v-col cols="6">
                         <v-text-field v-model="client.name" label="Client Name"></v-text-field>
@@ -17,13 +17,13 @@
                 </v-row>
                 <v-row>
                     <v-col cols="12">
-                        <v-text-field v-model="client.apiUrl" v-if="client.type === 'textgenwebui'" label="API URL"></v-text-field>
+                        <v-text-field v-model="client.apiUrl" v-if="isLocalApiClient(client)" label="API URL"></v-text-field>
                         <v-select v-model="client.model" v-if="client.type === 'openai'" :items="['gpt-4-1106-preview', 'gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k']" label="Model"></v-select>
                     </v-col>
                 </v-row>
                 <v-row>
                     <v-col cols="6">
-                        <v-text-field v-model="client.max_token_length" v-if="client.type === 'textgenwebui'" type="number" label="Context Length"></v-text-field>
+                        <v-text-field v-model="client.max_token_length" v-if="isLocalApiClient(client)" type="number" label="Context Length"></v-text-field>
                     </v-col>
                 </v-row>
             </v-container>
@@ -68,12 +68,18 @@ export default {
         }
     },
     methods: {
+        typeEditable() {
+            return this.state.formTitle === 'Add Client';
+        },
         close() {
             this.$emit('update:dialog', false);
         },
         save() {
             this.$emit('save', this.client); // Emit save event with client object
             this.close();
+        },
+        isLocalApiClient(client) {
+            return client.type === 'textgenwebui' || client.type === 'lmstudio';
         }
     }
}

View File

@@ -80,6 +80,12 @@ export default {
             this.getWebsocket().send(JSON.stringify({ type: 'request_scenes_list', query: this.sceneSearchInput }));
         },
         loadCreative() {
+            if(this.sceneSaved === false) {
+                if(!confirm("The current scene is not saved. Are you sure you want to load a new scene?")) {
+                    return;
+                }
+            }
             this.loading = true;
             this.getWebsocket().send(JSON.stringify({ type: 'load_scene', file_path: "environment:creative" }));
         },

View File

@@ -75,6 +75,8 @@
                 <span v-if="connecting" class="ml-1"><v-icon class="mr-1">mdi-progress-helper</v-icon>connecting</span>
                 <span v-else-if="connected" class="ml-1"><v-icon class="mr-1" color="green" size="14">mdi-checkbox-blank-circle</v-icon>connected</span>
                 <span v-else class="ml-1"><v-icon class="mr-1">mdi-progress-close</v-icon>disconnected</span>
+                <v-divider class="ml-1 mr-1" vertical></v-divider>
+                <AudioQueue ref="audioQueue" />
                 <v-spacer></v-spacer>
                 <span v-if="version !== null">v{{ version }}</span>
                 <span v-if="configurationRequired()">
@@ -161,6 +163,7 @@ import SceneHistory from './SceneHistory.vue';
 import CreativeEditor from './CreativeEditor.vue';
 import AppConfig from './AppConfig.vue';
 import DebugTools from './DebugTools.vue';
+import AudioQueue from './AudioQueue.vue';

 export default {
     components: {
@@ -177,6 +180,7 @@ export default {
         CreativeEditor,
         AppConfig,
         DebugTools,
+        AudioQueue,
     },
     name: 'TalemateApp',
     data() {

View File

@@ -0,0 +1,4 @@
{{ system_message }}
### Instruction:
{{ set_response(prompt, "\n\n### Response:\n") }}

View File

@@ -0,0 +1,3 @@
USER:
{{ system_message }}
{{ set_response(prompt, "\nASSISTANT:") }}

View File

@@ -0,0 +1 @@
Human: {{ system_message }} {{ set_response(prompt, "\n\nAssistant:") }}

View File

@@ -0,0 +1,4 @@
{{ system_message }}
### Instruction:
{{ set_response(prompt, "\n\n### Response:\n") }}

View File

@@ -0,0 +1,2 @@
SYSTEM: {{ system_message }}
USER: {{ set_response(prompt, "\nASSISTANT: ") }}

View File

@@ -0,0 +1,4 @@
<|im_start|>system
{{ system_message }}<|im_end|>
<|im_start|>user
{{ set_response(prompt, "<|im_end|>\n<|im_start|>assistant\n") }}

View File

@@ -0,0 +1,4 @@
<|im_start|>system
{{ system_message }}<|im_end|>
<|im_start|>user
{{ set_response(prompt, "<|im_end|>\n<|im_start|>assistant\n") }}

View File

@@ -0,0 +1,4 @@
<|im_start|>system
{{ system_message }}<|im_end|>
<|im_start|>user
{{ set_response(prompt, "<|im_end|>\n<|im_start|>assistant\n") }}
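
These new files are Jinja-style prompt templates for the added instruction formats (Alpaca-style, USER/ASSISTANT, Human/Assistant, ChatML). `set_response` is a Talemate template helper; the stub below assumes it simply appends the response prefix to the prompt, which is an assumption, not something this diff shows:

```python
# Rendering sketch for the ChatML template above; set_response is
# stubbed under the stated assumption.
from jinja2 import Template

def set_response(prompt, response_prefix):
    return prompt + response_prefix

chatml = Template(
    "<|im_start|>system\n{{ system_message }}<|im_end|>\n"
    "<|im_start|>user\n"
    "{{ set_response(prompt, '<|im_end|>\\n<|im_start|>assistant\\n') }}"
)
print(chatml.render(
    system_message="You are the narrator.",
    prompt="Describe the tavern.",
    set_response=set_response,
))
```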

View File

@@ -6,5 +6,5 @@ call talemate_env\Scripts\activate
 REM use poetry to install dependencies
 python -m poetry install

-echo Virtual environment re-created.
+echo Virtual environment updated
 pause