Mirror of https://github.com/vegu-ai/talemate.git (synced 2025-12-16 19:57:47 +01:00)

Compare commits (10 commits)
- 496eb469db
- b78fec3bac
- d250df8950
- 816f950afe
- 8fb72fdbe9
- 54297a4768
- d7e72d27c5
- f9b23f8705
- 37a5873330
- bc3f5d63c8
README.md (28 changed lines)
@@ -4,22 +4,26 @@ Allows you to play roleplay scenarios with large language models.

It does not run any large language models itself but relies on existing APIs. Currently supports **text-generation-webui** and **openai**.

-This means you need to either have an openai api key or know how to setup [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (locally or remotely via gpu renting. `--api` flag needs to be set)
+This means you need to either have an openai api key or know how to setup [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (locally or remotely via gpu renting. `--extension openai` flag needs to be set)
+
+As of version 0.13.0 the legacy text-generation-webui API `--extension api` is no longer supported, please use their new `--extension openai` api implementation instead.

## Current features

- responsive modern ui
- agents
-  - conversation
-  - narration
-  - summarization
-  - director
-  - creative
+  - conversation: handles character dialogue
+  - narration: handles narrative exposition
+  - summarization: handles summarization to compress context while maintaining history
+  - director: can be used to direct the story / characters
+  - editor: improves AI responses (very hit and miss at the moment)
+  - world state: generates world snapshot and handles passage of time (objects and characters)
+  - creator: character / scenario creator
- multi-client (agents can be connected to separate APIs)
-- long term memory (experimental)
+- long term memory
+  - chromadb integration
+- passage of time
+- narrative world state
@@ -36,6 +40,7 @@ Kinda making it up as i go along, but i want to lean more into gameplay through

In no particular order:

+- TTS support
- Extension support
  - modular agents and clients
- Improved world state
@@ -49,7 +54,7 @@ In no particular order:
- objectives
  - quests
  - win / lose conditions
-- Automatic1111 client
+- Automatic1111 client for in place visual generation

# Quickstart
@@ -113,6 +118,8 @@ https://www.reddit.com/r/LocalLLaMA/comments/17fhp9k/huge_llm_comparisontest_39_

On the right hand side click the "Add Client" button. If there is no button, you may need to toggle the client options by clicking this button:

+As of version 0.13.0 the legacy text-generation-webui API `--extension api` is no longer supported, please use their new `--extension openai` api implementation instead.

### Text-generation-webui
@@ -155,7 +162,10 @@ Make sure you save the scene after the character is loaded as it can then be loa

## Further documentation

Please read the documents in the `docs` folder for more advanced configuration and usage.

- Creative mode (docs WIP)
- Prompt template overrides
+- [Text-to-Speech (TTS)](docs/tts.md)
+- [ChromaDB (long term memory)](docs/chromadb.md)
- Runpod Integration
@@ -14,13 +14,32 @@ game:
    gender: male
    name: Elmer

## Long-term memory

#chromadb:
#  embeddings: instructor
#  instructor_device: cuda
#  instructor_model: hkunlp/instructor-xl

## Remote LLMs

#openai:
#  api_key: <API_KEY>

#runpod:
#  api_key: <API_KEY>

## TTS (Text-to-Speech)

#elevenlabs:
#  api_key: <API_KEY>

#coqui:
#  api_key: <API_KEY>

#tts:
#  device: cuda
#  model: tts_models/multilingual/multi-dataset/xtts_v2
#  voices:
#    - label: <name>
#      value: <path to .wav for voice sample>
docs/img/Screenshot_9.png: new binary file (551 KiB), not shown.
docs/tts.md (new file, 84 lines)
@@ -0,0 +1,84 @@
# Talemate Text-to-Speech (TTS) Configuration

Talemate supports Text-to-Speech (TTS) functionality, allowing users to convert text into spoken audio. This document outlines the steps required to configure TTS for Talemate using different providers, including ElevenLabs, Coqui, and a local TTS API.

## Configuring ElevenLabs TTS

To use ElevenLabs TTS with Talemate, follow these steps:

1. Visit [ElevenLabs](https://elevenlabs.com) and create an account if you don't already have one.
2. Click on your profile in the upper right corner of the ElevenLabs website to access your API key.
3. In the `config.yaml` file, under the `elevenlabs` section, set the `api_key` field with your ElevenLabs API key.

Example configuration snippet:

```yaml
elevenlabs:
  api_key: <YOUR_ELEVENLABS_API_KEY>
```

## Configuring Coqui TTS

To use Coqui TTS with Talemate, follow these steps:

1. Visit [Coqui](https://app.coqui.ai) and sign up for an account.
2. Go to the [account page](https://app.coqui.ai/account) and scroll to the bottom to find your API key.
3. In the `config.yaml` file, under the `coqui` section, set the `api_key` field with your Coqui API key.

Example configuration snippet:

```yaml
coqui:
  api_key: <YOUR_COQUI_API_KEY>
```

## Configuring Local TTS API

For running a local TTS API, Talemate requires specific dependencies to be installed.

### Windows Installation

Run `install-local-tts.bat` to install the necessary requirements.

### Linux Installation

Execute the following command:

```bash
pip install TTS
```
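
Before wiring this into Talemate, it can help to confirm the install took. A minimal sketch, assuming a recent coqui `TTS` release that exposes a package version attribute:

```python
# quick sanity check that the TTS package is importable
# (assumes __version__ is exposed; adjust if your release differs)
import TTS

print(TTS.__version__)
```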

### Model and Device Configuration

1. Choose a TTS model from the [Coqui TTS model list](https://github.com/coqui-ai/TTS).
2. Decide whether to use `cuda` or `cpu` for the device setting.
3. The first time you run TTS through the local API, it will download the specified model. Please note that this may take some time, and the download progress will be visible in the Talemate backend output.

Example configuration snippet:

```yaml
tts:
  device: cuda # or 'cpu'
  model: tts_models/multilingual/multi-dataset/xtts_v2
```
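
For orientation, the sketch below shows roughly what the local TTS agent in this changeset does with these two settings (compare `_generate_tts` in `src/talemate/agents/tts.py` further down); the voice sample path is a placeholder:

```python
# minimal sketch of the local synthesis path used by the TTS agent
from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to("cuda")  # or "cpu"
tts.tts_to_file(
    text="Hello from Talemate.",
    speaker_wav="path/to/voice_sample.wav",  # placeholder voice sample (.wav)
    language="en",
    file_path="out.wav",
)
```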

### Voice Samples Configuration

Configure voice samples by setting the `value` field to the path of a .wav file voice sample. Official samples can be downloaded from [Coqui XTTS-v2 samples](https://huggingface.co/coqui/XTTS-v2/tree/main/samples).

Example configuration snippet:

```yaml
tts:
  voices:
    - label: English Male
      value: path/to/english_male.wav
    - label: English Female
      value: path/to/english_female.wav
```

## Saving the Configuration

After configuring the `config.yaml` file, save your changes. Talemate will use the updated settings the next time it starts.

For more detailed information on configuring Talemate, refer to the `config.py` file in the Talemate source code and the `config.example.yaml` file for a barebones configuration example.
install-local-tts.bat (new file, 4 lines)

@@ -0,0 +1,4 @@
REM activate the virtual environment
call talemate_env\Scripts\activate

call pip install "TTS>=0.21.1"
@@ -7,10 +7,10 @@ REM activate the virtual environment
call talemate_env\Scripts\activate

REM install poetry
-python -m pip install poetry "rapidfuzz>=3" -U
+python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U

REM use poetry to install dependencies
-poetry install
+python -m poetry install

REM copy config.example.yaml to config.yaml only if config.yaml doesn't exist
IF NOT EXIST config.yaml copy config.example.yaml config.yaml
poetry.lock (generated, 1982 lines): file diff suppressed because it is too large.
@@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "talemate"
-version = "0.12.0"
+version = "0.14.0"
description = "AI-backed roleplay and narrative tools"
authors = ["FinalWombat"]
license = "GNU Affero General Public License v3.0"

@@ -37,11 +37,12 @@ nest_asyncio = "^1.5.7"
isodate = ">=0.6.1"
thefuzz = ">=0.20.0"
tiktoken = ">=0.5.1"
+nltk = ">=3.8.1"

# ChromaDB
-chromadb = ">=0.4,<1"
+chromadb = ">=0.4.17,<1"
InstructorEmbedding = "^1.0.1"
-torch = ">=2.0.0, !=2.0.1"
+torch = ">=2.1.0"
sentence-transformers="^2.2.2"

[tool.poetry.dev-dependencies]
@@ -9,7 +9,7 @@ REM activate the virtual environment
call talemate_env\Scripts\activate

REM install poetry
-python -m pip install poetry "rapidfuzz>=3" -U
+python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U

REM use poetry to install dependencies
python -m poetry install
@@ -2,4 +2,4 @@ from .agents import Agent
from .client import TextGeneratorWebuiClient
from .tale_mate import *

-VERSION = "0.12.0"
+VERSION = "0.14.0"
@@ -8,4 +8,5 @@ from .narrator import NarratorAgent
from .registry import AGENT_CLASSES, get_agent_class, register
from .summarize import SummarizeAgent
from .editor import EditorAgent
from .world_state import WorldStateAgent
+from .tts import TTSAgent
@@ -23,16 +23,31 @@ __all__ = [

log = structlog.get_logger("talemate.agents.base")

+class CallableConfigValue:
+    def __init__(self, fn):
+        self.fn = fn
+
+    def __str__(self):
+        return "CallableConfigValue"
+
+    def __repr__(self):
+        return "CallableConfigValue"
+
class AgentActionConfig(pydantic.BaseModel):
    type: str
    label: str
    description: str = ""
-    value: Union[int, float, str, bool]
+    value: Union[int, float, str, bool, None]
    default_value: Union[int, float, str, bool] = None
+    max: Union[int, float, None] = None
+    min: Union[int, float, None] = None
+    step: Union[int, float, None] = None
+    scope: str = "global"
+    choices: Union[list[dict[str, str]], None] = None
+
+    class Config:
+        arbitrary_types_allowed = True


class AgentAction(pydantic.BaseModel):
    enabled: bool = True
@@ -40,7 +55,6 @@ class AgentAction(pydantic.BaseModel):
    description: str = ""
    config: Union[dict[str, AgentActionConfig], None] = None


def set_processing(fn):
    """
    decorator that emits the agent status as processing while the function
@@ -70,6 +84,7 @@ class Agent(ABC):
    agent_type = "agent"
    verbose_name = None
    set_processing = set_processing
+    requires_llm_client = True

    @property
    def agent_details(self):
@@ -89,7 +104,7 @@ class Agent(ABC):
        if not getattr(self.client, "enabled", True):
            return False

-        if self.client.current_status in ["error", "warning"]:
+        if self.client and self.client.current_status in ["error", "warning"]:
            return False

        return self.client is not None
@@ -135,6 +150,7 @@ class Agent(ABC):
            "enabled": agent.enabled if agent else True,
            "has_toggle": agent.has_toggle if agent else False,
            "experimental": agent.experimental if agent else False,
+            "requires_llm_client": cls.requires_llm_client,
        }
        actions = getattr(agent, "actions", None)
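
Taken together, the new `AgentActionConfig` fields let agents declare bounded numeric settings on their actions. A minimal sketch of the shape this enables (this exact pattern appears in the narrator and summarizer changes below):

```python
# sketch: a toggleable action with a bounded numeric sub-setting
action = AgentAction(
    enabled=True,
    label="Narrate Dialogue",
    config={
        "ai_dialog": AgentActionConfig(
            type="number",
            label="AI Dialogue",
            value=0.3, min=0.0, max=1.0, step=0.1,  # uses the new min/max/step fields
        ),
    },
)
```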
@@ -406,7 +406,7 @@ class ConversationAgent(Agent):

        context = await memory.multi_query(history, max_tokens=500, iterate=5)

-        self.current_memory_context = "\n".join(context)
+        self.current_memory_context = "\n\n".join(context)

        return self.current_memory_context
@@ -10,7 +10,7 @@ import talemate.emit.async_signals
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, TimePassageMessage

-from .base import Agent, set_processing, AgentAction
+from .base import Agent, set_processing, AgentAction, AgentActionConfig
from .registry import register

import structlog
@@ -21,6 +21,7 @@ import re
if TYPE_CHECKING:
    from talemate.tale_mate import Actor, Character, Scene
    from talemate.agents.conversation import ConversationAgentEmission
+    from talemate.agents.narrator import NarratorAgentEmission

log = structlog.get_logger("talemate.agents.editor")

@@ -40,7 +41,9 @@ class EditorAgent(Agent):
        self.is_enabled = True
        self.actions = {
            "edit_dialogue": AgentAction(enabled=False, label="Edit dialogue", description="Will attempt to improve the quality of dialogue based on the character and scene. Runs automatically after each AI dialogue."),
-            "fix_exposition": AgentAction(enabled=True, label="Fix exposition", description="Will attempt to fix exposition and emotes, making sure they are displayed in italics. Runs automatically after each AI dialogue."),
+            "fix_exposition": AgentAction(enabled=True, label="Fix exposition", description="Will attempt to fix exposition and emotes, making sure they are displayed in italics. Runs automatically after each AI dialogue.", config={
+                "narrator": AgentActionConfig(type="bool", label="Fix narrator messages", description="Will attempt to fix exposition issues in narrator messages", value=True),
+            }),
            "add_detail": AgentAction(enabled=False, label="Add detail", description="Will attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.")
        }

@@ -59,6 +62,7 @@ class EditorAgent(Agent):
    def connect(self, scene):
        super().connect(scene)
        talemate.emit.async_signals.get("agent.conversation.generated").connect(self.on_conversation_generated)
+        talemate.emit.async_signals.get("agent.narrator.generated").connect(self.on_narrator_generated)

    async def on_conversation_generated(self, emission:ConversationAgentEmission):
        """
@@ -93,6 +97,24 @@ class EditorAgent(Agent):

        emission.generation = edited

+    async def on_narrator_generated(self, emission:NarratorAgentEmission):
+        """
+        Called when a narrator message is generated
+        """
+
+        if not self.enabled:
+            return
+
+        log.info("editing narrator", emission=emission)
+
+        edited = []
+
+        for text in emission.generation:
+            edit = await self.fix_exposition_on_narrator(text)
+            edited.append(edit)
+
+        emission.generation = edited
+
    @set_processing
    async def edit_conversation(self, content:str, character:Character):
@@ -127,12 +149,13 @@ class EditorAgent(Agent):
        if not self.actions["fix_exposition"].enabled:
            return content

-        #response = await Prompt.request("editor.fix-exposition", self.client, "edit_fix_exposition", vars={
-        #    "content": content,
-        #    "character": character,
-        #    "scene": self.scene,
-        #    "max_length": self.client.max_token_length
-        #})
+        if not character.is_player:
+            if '"' not in content and '*' not in content:
+                content = util.strip_partial_sentences(content)
+                character_prefix = f"{character.name}: "
+                message = content.split(character_prefix)[1]
+                content = f"{character_prefix}*{message.strip('*')}*"
+                return content

        content = util.clean_dialogue(content, main_name=character.name)
        content = util.strip_partial_sentences(content)
@@ -140,6 +163,24 @@ class EditorAgent(Agent):

        return content

+    @set_processing
+    async def fix_exposition_on_narrator(self, content:str):
+
+        if not self.actions["fix_exposition"].enabled:
+            return content
+
+        if not self.actions["fix_exposition"].config["narrator"].value:
+            return content
+
+        content = util.strip_partial_sentences(content)
+
+        if '"' not in content:
+            content = f"*{content.strip('*')}*"
+        else:
+            content = util.ensure_dialog_format(content)
+
+        return content
+
    @set_processing
    async def add_detail(self, content:str, character:Character):
        """
@@ -206,6 +206,7 @@ from .registry import register
@register(condition=lambda: chromadb is not None)
class ChromaDBMemoryAgent(MemoryAgent):

+    requires_llm_client = False

    @property
    def ready(self):
@@ -222,7 +223,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
    @property
    def agent_details(self):
        return f"ChromaDB: {self.embeddings}"

    @property
    def embeddings(self):
        """
@@ -328,9 +329,13 @@ class ChromaDBMemoryAgent(MemoryAgent):
                model_name=instructor_model, device=instructor_device
            )

+            log.info("chromadb", status="embedding function ready")
+
            self.db = self.db_client.get_or_create_collection(
                collection_name, embedding_function=ef
            )

+            log.info("chromadb", status="instructor db ready")
        else:
            log.info("chromadb", status="using default embeddings")
            self.db = self.db_client.get_or_create_collection(collection_name)
@@ -405,7 +410,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
        id = uid or f"__narrator__-{self.memory_tracker['__narrator__']}"
        ids = [id]

-        log.debug("chromadb agent add", text=text, meta=meta, id=id)
+        #log.debug("chromadb agent add", text=text, meta=meta, id=id)

        self.db.upsert(documents=[text], metadatas=metadatas, ids=ids)

@@ -461,6 +466,7 @@ class ChromaDBMemoryAgent(MemoryAgent):

        #import json
        #print(json.dumps(_results["ids"], indent=2))
+        #print(json.dumps(_results["distances"], indent=2))

        results = []

@@ -474,9 +480,10 @@ class ChromaDBMemoryAgent(MemoryAgent):
            if distance < 1:

                try:
+                    log.debug("chromadb agent get", ts=ts, scene_ts=self.scene.ts)
                    date_prefix = util.iso8601_diff_to_human(ts, self.scene.ts)
-                except Exception:
-                    log.error("chromadb agent", error="failed to get date prefix", ts=ts, scene_ts=self.scene.ts)
+                except Exception as e:
+                    log.error("chromadb agent", error="failed to get date prefix", details=e, ts=ts, scene_ts=self.scene.ts)
                    date_prefix = None

                if date_prefix:
@@ -1,22 +1,60 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, List, Optional, Union
+import dataclasses
import structlog
+import random
import talemate.util as util
from talemate.emit import emit
+import talemate.emit.async_signals
from talemate.prompts import Prompt
-from talemate.agents.base import set_processing, Agent, AgentAction, AgentActionConfig
+from talemate.agents.base import set_processing as _set_processing, Agent, AgentAction, AgentActionConfig, AgentEmission
from talemate.agents.world_state import TimePassageEmission
from talemate.scene_message import NarratorMessage
+from talemate.events import GameLoopActorIterEvent
import talemate.client as client

from .registry import register

if TYPE_CHECKING:
    from talemate.tale_mate import Actor, Player, Character

log = structlog.get_logger("talemate.agents.narrator")

+@dataclasses.dataclass
+class NarratorAgentEmission(AgentEmission):
+    generation: list[str] = dataclasses.field(default_factory=list)
+
+talemate.emit.async_signals.register(
+    "agent.narrator.generated"
+)
+
+def set_processing(fn):
+
+    """
+    Custom decorator that emits the agent status as processing while the function
+    is running and then emits the result of the function as a NarratorAgentEmission
+    """
+
+    @_set_processing
+    async def wrapper(self, *args, **kwargs):
+        response = await fn(self, *args, **kwargs)
+        emission = NarratorAgentEmission(
+            agent=self,
+            generation=[response],
+        )
+        await talemate.emit.async_signals.get("agent.narrator.generated").send(emission)
+        return emission.generation[0]
+    wrapper.__name__ = fn.__name__
+    return wrapper
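
Note: assigning `wrapper.__name__ = fn.__name__` preserves the decorated coroutine's name for logging and introspection; `functools.wraps(fn)` would be the more idiomatic equivalent (and would also carry over the docstring), but the manual assignment suffices here.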

@register()
class NarratorAgent(Agent):

    """
    Handles narration of the story
    """

    agent_type = "narrator"
    verbose_name = "Narrator"

@@ -27,31 +65,78 @@ class NarratorAgent(Agent):
    ):
        self.client = client

        # agent actions

        self.actions = {
-            "narrate_time_passage": AgentAction(enabled=False, label="Narrate Time Passage", description="Whenever you indicate passage of time, narrate right after"),
+            "narrate_time_passage": AgentAction(enabled=True, label="Narrate Time Passage", description="Whenever you indicate passage of time, narrate right after"),
+            "narrate_dialogue": AgentAction(
+                enabled=True,
+                label="Narrate Dialogue",
+                description="Narrator will get a chance to narrate after every line of dialogue",
+                config = {
+                    "ai_dialog": AgentActionConfig(
+                        type="number",
+                        label="AI Dialogue",
+                        description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
+                        value=0.3,
+                        min=0.0,
+                        max=1.0,
+                        step=0.1,
+                    ),
+                    "player_dialog": AgentActionConfig(
+                        type="number",
+                        label="Player Dialogue",
+                        description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
+                        value=0.3,
+                        min=0.0,
+                        max=1.0,
+                        step=0.1,
+                    ),
+                }
+            ),
        }

    def clean_result(self, result):

        """
        Cleans the result of a narration
        """

        result = result.strip().strip(":").strip()

        if "#" in result:
            result = result.split("#")[0]

        character_names = [c.name for c in self.scene.get_characters()]

        cleaned = []
        for line in result.split("\n"):
-            if ":" in line.strip():
-                break
+            for character_name in character_names:
+                if line.startswith(f"{character_name}:"):
+                    break
            cleaned.append(line)

-        return "\n".join(cleaned)
+        result = "\n".join(cleaned)
+        #result = util.strip_partial_sentences(result)
+        return result

    def connect(self, scene):

        """
        Connect to signals
        """

        super().connect(scene)
        talemate.emit.async_signals.get("agent.world_state.time").connect(self.on_time_passage)
+        talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.on_dialog)

    async def on_time_passage(self, event:TimePassageEmission):

        """
        Handles time passage narration, if enabled
        """

        if not self.actions["narrate_time_passage"].enabled:
            return

@@ -59,6 +144,31 @@ class NarratorAgent(Agent):
        narrator_message = NarratorMessage(response, source=f"narrate_time_passage:{event.duration};{event.narrative}")
        emit("narrator", narrator_message)
        self.scene.push_history(narrator_message)

+    async def on_dialog(self, event:GameLoopActorIterEvent):
+
+        """
+        Handles dialogue narration, if enabled
+        """
+
+        if not self.actions["narrate_dialogue"].enabled:
+            return
+
+        narrate_on_ai_chance = random.random() < self.actions["narrate_dialogue"].config["ai_dialog"].value
+        narrate_on_player_chance = random.random() < self.actions["narrate_dialogue"].config["player_dialog"].value
+
+        log.debug("narrate on dialog", narrate_on_ai_chance=narrate_on_ai_chance, narrate_on_player_chance=narrate_on_player_chance)
+
+        if event.actor.character.is_player and not narrate_on_player_chance:
+            return
+
+        if not event.actor.character.is_player and not narrate_on_ai_chance:
+            return
+
+        response = await self.narrate_after_dialogue(event.actor.character)
+        narrator_message = NarratorMessage(response, source=f"narrate_dialogue:{event.actor.character.name}")
+        emit("narrator", narrator_message)
+        self.scene.push_history(narrator_message)

    @set_processing
    async def narrate_scene(self):
@@ -155,8 +265,9 @@ class NarratorAgent(Agent):
                "as_narrative": as_narrative,
            }
        )

        log.info("narrate_query", response=response)
        response = self.clean_result(response.strip())
+        log.info("narrate_query (after clean)", response=response)
        if as_narrative:
            response = f"*{response}*"

@@ -265,4 +376,30 @@ class NarratorAgent(Agent):
        response = self.clean_result(response.strip())
        response = f"*{response}*"

        return response

+    @set_processing
+    async def narrate_after_dialogue(self, character:Character):
+        """
+        Narrate after a line of dialogue
+        """
+
+        response = await Prompt.request(
+            "narrator.narrate-after-dialogue",
+            self.client,
+            "narrate",
+            vars = {
+                "scene": self.scene,
+                "max_tokens": self.client.max_token_length,
+                "character": character,
+                "last_line": str(self.scene.history[-1])
+            }
+        )
+
+        log.info("narrate_after_dialogue", response=response)
+
+        response = self.clean_result(response.strip().strip("*"))
+        response = f"*{response}*"
+
+        return response
@@ -5,11 +5,13 @@ import traceback
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import talemate.data_objects as data_objects
+import talemate.emit.async_signals
import talemate.util as util
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, TimePassageMessage
+from talemate.events import GameLoopEvent

-from .base import Agent, set_processing
+from .base import Agent, set_processing, AgentAction, AgentActionConfig
from .registry import register

import structlog
@@ -34,14 +36,40 @@ class SummarizeAgent(Agent):

    def __init__(self, client, **kwargs):
        self.client = client

-    def on_history_add(self, event):
-        asyncio.ensure_future(self.build_archive(event.scene))
-
+        self.actions = {
+            "archive": AgentAction(
+                enabled=True,
+                label="Summarize to long-term memory archive",
+                description="Automatically summarize scene dialogue when the number of tokens in the history exceeds a threshold. This helps keep the context history from growing too large.",
+                config={
+                    "threshold": AgentActionConfig(
+                        type="number",
+                        label="Token Threshold",
+                        description="Will summarize when the number of tokens in the history exceeds this threshold",
+                        min=512,
+                        max=8192,
+                        step=256,
+                        value=1536,
+                    )
+                }
+            )
+        }

    def connect(self, scene):
        super().connect(scene)
-        scene.signals["history_add"].connect(self.on_history_add)
+        talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)

+    async def on_game_loop(self, emission:GameLoopEvent):
+        """
+        Called when a conversation is generated
+        """
+
+        await self.build_archive(self.scene)

    def clean_result(self, result):
        if "#" in result:
            result = result.split("#")[0]
@@ -53,21 +81,31 @@ class SummarizeAgent(Agent):
        return result

    @set_processing
-    async def build_archive(self, scene, token_threshold:int=1500):
+    async def build_archive(self, scene):
        end = None

+        if not self.actions["archive"].enabled:
+            return
+
        if not scene.archived_history:
            start = 0
            recent_entry = None
        else:
            recent_entry = scene.archived_history[-1]
-            start = recent_entry.get("end", 0) + 1
+            if "end" not in recent_entry:
+                # permanent historical archive entry, not tied to any specific history entry
+                # meaning we are still at the beginning of the scene
+                start = 0
+            else:
+                start = recent_entry.get("end", 0)+1

        tokens = 0
        dialogue_entries = []
        ts = "PT0S"
        time_passage_termination = False

+        token_threshold = self.actions["archive"].config["threshold"].value
+
        log.debug("build_archive", start=start, recent_entry=recent_entry)

        if recent_entry:
@@ -75,6 +113,9 @@ class SummarizeAgent(Agent):

        for i in range(start, len(scene.history)):
            dialogue = scene.history[i]

+            #log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")

            if isinstance(dialogue, DirectorMessage):
                if i == start:
                    start += 1
@@ -131,7 +172,7 @@ class SummarizeAgent(Agent):
                    break
                adjusted_dialogue.append(line)
            dialogue_entries = adjusted_dialogue
-            end = start + len(dialogue_entries)
+            end = start + len(dialogue_entries)-1
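
This looks like an off-by-one fix: `end` is stored as the inclusive index of the last summarized entry, and the next archive pass resumes at `recent_entry["end"] + 1` (see above), so the previous value would have skipped one history entry per pass.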

        if dialogue_entries:
            summarized = await self.summarize(
src/talemate/agents/tts.py (new file, 595 lines)
@@ -0,0 +1,595 @@
from __future__ import annotations

from typing import Union
import asyncio
import httpx
import io
import os
import pydantic
import nltk
import tempfile
import base64
import uuid
import functools
from nltk.tokenize import sent_tokenize

import talemate.config as config
import talemate.emit.async_signals
from talemate.emit import emit
from talemate.events import GameLoopNewMessageEvent
from talemate.scene_message import CharacterMessage, NarratorMessage

from .base import Agent, set_processing, AgentAction, AgentActionConfig
from .registry import register

import structlog

import time

try:
    from TTS.api import TTS
except ImportError:
    TTS = None

log = structlog.get_logger("talemate.agents.tts")

if not TTS:
    # TTS installation is massive and requires a lot of dependencies
    # so we don't want to require it unless the user wants to use it
    log.info("TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api")

nltk.download("punkt")

def parse_chunks(text):

    text = text.replace("...", "__ellipsis__")

    chunks = sent_tokenize(text)
    cleaned_chunks = []

    for chunk in chunks:
        chunk = chunk.replace("*","")
        if not chunk:
            continue
        cleaned_chunks.append(chunk)

    for i, chunk in enumerate(cleaned_chunks):
        chunk = chunk.replace("__ellipsis__", "...")
        cleaned_chunks[i] = chunk

    return cleaned_chunks

def rejoin_chunks(chunks:list[str], chunk_size:int=250):

    """
    Will combine chunks split by punctuation into a single chunk until
    max chunk size is reached
    """

    joined_chunks = []

    current_chunk = ""

    for chunk in chunks:

        if len(current_chunk) + len(chunk) > chunk_size:
            joined_chunks.append(current_chunk)
            current_chunk = ""

        current_chunk += chunk

    if current_chunk:
        joined_chunks.append(current_chunk)

    return joined_chunks
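
A quick illustration of how the two helpers above are meant to compose (chunk boundaries depend on NLTK's `punkt` tokenizer, so the exact split may vary):

```python
text = "She paused... then smiled. *He waited.* The rain kept falling."
chunks = rejoin_chunks(parse_chunks(text), chunk_size=60)
# parse_chunks: sentence-splits while protecting "..." and stripping "*"
# rejoin_chunks: greedily merges sentences up to ~60 characters each
print(chunks)
```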


class Voice(pydantic.BaseModel):
    value:str
    label:str

class VoiceLibrary(pydantic.BaseModel):

    api: str
    voices: list[Voice] = pydantic.Field(default_factory=list)
    last_synced: float = None


@register()
class TTSAgent(Agent):

    """
    Text to speech agent
    """

    agent_type = "tts"
    verbose_name = "Text to speech"
    requires_llm_client = False

    @classmethod
    def config_options(cls, agent=None):
        config_options = super().config_options(agent=agent)

        if agent:
            config_options["actions"]["_config"]["config"]["voice_id"]["choices"] = [
                voice.model_dump() for voice in agent.list_voices_sync()
            ]

        return config_options

    def __init__(self, **kwargs):

        self.is_enabled = False

        self.voices = {
            "elevenlabs": VoiceLibrary(api="elevenlabs"),
            "coqui": VoiceLibrary(api="coqui"),
            "tts": VoiceLibrary(api="tts"),
        }
        self.config = config.load_config()
        self.playback_done_event = asyncio.Event()
        self.actions = {
            "_config": AgentAction(
                enabled=True,
                label="Configure",
                description="TTS agent configuration",
                config={
                    "api": AgentActionConfig(
                        type="text",
                        choices=[
                            # TODO: add local TTS support
                            {"value": "tts", "label": "TTS (Local)"},
                            {"value": "elevenlabs", "label": "Eleven Labs"},
                            {"value": "coqui", "label": "Coqui Studio"},
                        ],
                        value="tts",
                        label="API",
                        description="Which TTS API to use",
                        onchange="emit",
                    ),
                    "voice_id": AgentActionConfig(
                        type="text",
                        value="default",
                        label="Narrator Voice",
                        description="Voice ID/Name to use for TTS",
                        choices=[]
                    ),
                    "generate_for_player": AgentActionConfig(
                        type="bool",
                        value=False,
                        label="Generate for player",
                        description="Generate audio for player messages",
                    ),
                    "generate_for_npc": AgentActionConfig(
                        type="bool",
                        value=True,
                        label="Generate for NPCs",
                        description="Generate audio for NPC messages",
                    ),
                    "generate_for_narration": AgentActionConfig(
                        type="bool",
                        value=True,
                        label="Generate for narration",
                        description="Generate audio for narration messages",
                    ),
                    "generate_chunks": AgentActionConfig(
                        type="bool",
                        value=True,
                        label="Split generation",
                        description="Generate audio chunks for each sentence - will be much more responsive but may lose context to inform inflection",
                    )
                }
            ),
        }

        self.actions["_config"].model_dump()

    @property
    def enabled(self):
        return self.is_enabled

    @property
    def has_toggle(self):
        return True

    @property
    def experimental(self):
        return False

    @property
    def not_ready_reason(self) -> str:
        """
        Returns a string explaining why the agent is not ready
        """

        if self.ready:
            return ""

        if self.api == "tts":
            if not TTS:
                return "TTS not installed"

        elif self.requires_token and not self.token:
            return "No API token"

        elif not self.default_voice_id:
            return "No voice selected"

    @property
    def agent_details(self):
        suffix = ""

        if not self.ready:
            suffix = f" - {self.not_ready_reason}"
        else:
            suffix = f" - {self.voice_id_to_label(self.default_voice_id)}"

        api = self.api
        choices = self.actions["_config"].config["api"].choices
        api_label = api
        for choice in choices:
            if choice["value"] == api:
                api_label = choice["label"]
                break

        return f"{api_label}{suffix}"

    @property
    def api(self):
        return self.actions["_config"].config["api"].value

    @property
    def token(self):
        api = self.api
        return self.config.get(api,{}).get("api_key")

    @property
    def default_voice_id(self):
        return self.actions["_config"].config["voice_id"].value

    @property
    def requires_token(self):
        return self.api != "tts"

    @property
    def ready(self):

        if self.api == "tts":
            if not TTS:
                return False
            return True

        return (not self.requires_token or self.token) and self.default_voice_id

    @property
    def status(self):
        if not self.enabled:
            return "disabled"
        if self.ready:
            return "active" if not getattr(self, "processing", False) else "busy"
        if self.requires_token and not self.token:
            return "error"
        if self.api == "tts":
            if not TTS:
                return "error"

    @property
    def max_generation_length(self):
        if self.api == "elevenlabs":
            return 1024
        elif self.api == "coqui":
            return 250

        return 250

    def apply_config(self, *args, **kwargs):

        try:
            api = kwargs["actions"]["_config"]["config"]["api"]["value"]
        except KeyError:
            api = self.api

        api_changed = api != self.api

        log.debug("apply_config", api=api, api_changed=api != self.api, current_api=self.api)

        super().apply_config(*args, **kwargs)

        if api_changed:
            try:
                self.actions["_config"].config["voice_id"].value = self.voices[api].voices[0].value
            except IndexError:
                self.actions["_config"].config["voice_id"].value = ""

    def connect(self, scene):
        super().connect(scene)
        talemate.emit.async_signals.get("game_loop_new_message").connect(self.on_game_loop_new_message)

    async def on_game_loop_new_message(self, emission:GameLoopNewMessageEvent):
        """
        Called when a conversation is generated
        """

        if not self.enabled:
            return

        if not isinstance(emission.message, (CharacterMessage, NarratorMessage)):
            return

        if isinstance(emission.message, NarratorMessage) and not self.actions["_config"].config["generate_for_narration"].value:
            return

        if isinstance(emission.message, CharacterMessage):

            if emission.message.source == "player" and not self.actions["_config"].config["generate_for_player"].value:
                return
            elif emission.message.source == "ai" and not self.actions["_config"].config["generate_for_npc"].value:
                return

        if isinstance(emission.message, CharacterMessage):
            character_prefix = emission.message.split(":", 1)[0]
        else:
            character_prefix = ""

        log.info("reactive tts", message=emission.message, character_prefix=character_prefix)

        await self.generate(str(emission.message).replace(character_prefix+": ", ""))

    def voice(self, voice_id:str) -> Union[Voice, None]:
        for voice in self.voices[self.api].voices:
            if voice.value == voice_id:
                return voice
        return None

    def voice_id_to_label(self, voice_id:str):
        for voice in self.voices[self.api].voices:
            if voice.value == voice_id:
                return voice.label
        return None

    def list_voices_sync(self):
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self.list_voices())

    async def list_voices(self):
        if self.requires_token and not self.token:
            return []

        library = self.voices[self.api]

        log.info("Listing voices", api=self.api, last_synced=library.last_synced)

        # TODO: allow re-syncing voices
        if library.last_synced:
            return library.voices

        list_fn = getattr(self, f"_list_voices_{self.api}")
        log.info("Listing voices", api=self.api)
        library.voices = await list_fn()
        library.last_synced = time.time()

        # if the current voice cannot be found, reset it
        if not self.voice(self.default_voice_id):
            self.actions["_config"].config["voice_id"].value = ""

        # set loading to false
        return library.voices

    @set_processing
    async def generate(self, text: str):
        if not self.enabled or not self.ready or not text:
            return

        self.playback_done_event.set()

        generate_fn = getattr(self, f"_generate_{self.api}")

        if self.actions["_config"].config["generate_chunks"].value:
            chunks = parse_chunks(text)
            chunks = rejoin_chunks(chunks)
        else:
            chunks = parse_chunks(text)
            chunks = rejoin_chunks(chunks, chunk_size=self.max_generation_length)

        # Start generating audio chunks in the background
        generation_task = asyncio.create_task(self.generate_chunks(generate_fn, chunks))

        # Wait for both tasks to complete
        await asyncio.gather(generation_task)

    async def generate_chunks(self, generate_fn, chunks):
        for chunk in chunks:
            chunk = chunk.replace("*","").strip()
            log.info("Generating audio", api=self.api, chunk=chunk)
            audio_data = await generate_fn(chunk)
            self.play_audio(audio_data)

    def play_audio(self, audio_data):
        # play audio through the python audio player
        #play(audio_data)

        emit("audio_queue", data={"audio_data": base64.b64encode(audio_data).decode("utf-8")})

        self.playback_done_event.set() # Signal that playback is finished

    # LOCAL

    async def _generate_tts(self, text: str) -> Union[bytes, None]:

        if not TTS:
            return

        tts_config = self.config.get("tts",{})
        model = tts_config.get("model")
        device = tts_config.get("device", "cpu")

        log.debug("tts local", model=model, device=device)

        if not hasattr(self, "tts_instance"):
            self.tts_instance = TTS(model).to(device)

        tts = self.tts_instance

        loop = asyncio.get_event_loop()

        voice = self.voice(self.default_voice_id)

        with tempfile.TemporaryDirectory() as temp_dir:
            file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")

            await loop.run_in_executor(None, functools.partial(tts.tts_to_file, text=text, speaker_wav=voice.value, language="en", file_path=file_path))
            #tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)

            with open(file_path, "rb") as f:
                return f.read()

    async def _list_voices_tts(self) -> dict[str, str]:
        return [Voice(**voice) for voice in self.config.get("tts",{}).get("voices",[])]

    # ELEVENLABS

    async def _generate_elevenlabs(self, text: str, chunk_size: int = 1024) -> Union[bytes, None]:
        api_key = self.token
        if not api_key:
            return

        async with httpx.AsyncClient() as client:
            url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.default_voice_id}"
            headers = {
                "Accept": "audio/mpeg",
                "Content-Type": "application/json",
                "xi-api-key": api_key,
            }
            data = {
                "text": text,
                "model_id": "eleven_monolingual_v1",
                "voice_settings": {
                    "stability": 0.5,
                    "similarity_boost": 0.5
                }
            }

            response = await client.post(url, json=data, headers=headers, timeout=300)

            if response.status_code == 200:
                bytes_io = io.BytesIO()
                for chunk in response.iter_bytes(chunk_size=chunk_size):
                    if chunk:
                        bytes_io.write(chunk)

                # Put the audio data in the queue for playback
                return bytes_io.getvalue()
            else:
                log.error(f"Error generating audio: {response.text}")

    async def _list_voices_elevenlabs(self) -> dict[str, str]:

        url_voices = "https://api.elevenlabs.io/v1/voices"

        voices = []

        async with httpx.AsyncClient() as client:
            headers = {
                "Accept": "application/json",
                "xi-api-key": self.token,
            }
            response = await client.get(url_voices, headers=headers, params={"per_page":1000})
            speakers = response.json()["voices"]
            voices.extend([Voice(value=speaker["voice_id"], label=speaker["name"]) for speaker in speakers])

        # sort by name
        voices.sort(key=lambda x: x.label)

        return voices

    # COQUI STUDIO

    async def _generate_coqui(self, text: str) -> Union[bytes, None]:
        api_key = self.token
        if not api_key:
            return

        async with httpx.AsyncClient() as client:
            url = "https://app.coqui.ai/api/v2/samples/xtts/render/"
            headers = {
                "Accept": "application/json",
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}"
            }
            data = {
                "voice_id": self.default_voice_id,
                "text": text,
                "language": "en" # Assuming English language for simplicity; this could be parameterized
            }

            # Make the POST request to Coqui API
            response = await client.post(url, json=data, headers=headers, timeout=300)
            if response.status_code in [200, 201]:
                # Parse the JSON response to get the audio URL
                response_data = response.json()
                audio_url = response_data.get('audio_url')
                if audio_url:
                    # Make a GET request to download the audio file
                    audio_response = await client.get(audio_url)
                    if audio_response.status_code == 200:
                        # delete the sample from Coqui Studio
                        # await self._cleanup_coqui(response_data.get('id'))
                        return audio_response.content
                    else:
                        log.error(f"Error downloading audio: {audio_response.text}")
                else:
                    log.error("No audio URL in response")
            else:
                log.error(f"Error generating audio: {response.text}")

    async def _cleanup_coqui(self, sample_id: str):
        api_key = self.token
        if not api_key or not sample_id:
            return

        async with httpx.AsyncClient() as client:
            url = f"https://app.coqui.ai/api/v2/samples/xtts/{sample_id}"
            headers = {
                "Authorization": f"Bearer {api_key}"
            }

            # Make the DELETE request to Coqui API
            response = await client.delete(url, headers=headers)

            if response.status_code == 204:
                log.info(f"Successfully deleted sample with ID: {sample_id}")
            else:
                log.error(f"Error deleting sample with ID: {sample_id}: {response.text}")

    async def _list_voices_coqui(self) -> dict[str, str]:

        url_speakers = "https://app.coqui.ai/api/v2/speakers"
        url_custom_voices = "https://app.coqui.ai/api/v2/voices"

        voices = []

        async with httpx.AsyncClient() as client:
            headers = {
                "Authorization": f"Bearer {self.token}"
            }
            response = await client.get(url_speakers, headers=headers, params={"per_page":1000})
            speakers = response.json()["result"]
            voices.extend([Voice(value=speaker["id"], label=speaker["name"]) for speaker in speakers])

            response = await client.get(url_custom_voices, headers=headers, params={"per_page":1000})
            custom_voices = response.json()["result"]
            voices.extend([Voice(value=voice["id"], label=voice["name"]) for voice in custom_voices])

        # sort by name
        voices.sort(key=lambda x: x.label)

        return voices
@@ -8,6 +8,7 @@ import talemate.util as util
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, TimePassageMessage
from talemate.emit import emit
+from talemate.events import GameLoopEvent

from .base import Agent, set_processing, AgentAction, AgentActionConfig, AgentEmission
from .registry import register
@@ -16,9 +17,6 @@ import structlog
import isodate
import time

-if TYPE_CHECKING:
-    from talemate.agents.conversation import ConversationAgentEmission
-

log = structlog.get_logger("talemate.agents.world_state")

@@ -74,7 +72,7 @@ class WorldStateAgent(Agent):

    def connect(self, scene):
        super().connect(scene)
-        talemate.emit.async_signals.get("agent.conversation.generated").connect(self.on_conversation_generated)
+        talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)

    async def advance_time(self, duration:str, narrative:str=None):
        """
@@ -96,7 +94,7 @@ class WorldStateAgent(Agent):
        )

-    async def on_conversation_generated(self, emission:ConversationAgentEmission):
+    async def on_game_loop(self, emission:GameLoopEvent):
        """
        Called when a conversation is generated
        """
@@ -104,8 +102,7 @@ class WorldStateAgent(Agent):
        if not self.enabled:
            return

-        for _ in emission.generation:
-            await self.update_world_state()
+        await self.update_world_state()

    async def update_world_state(self):
@@ -230,7 +227,7 @@ class WorldStateAgent(Agent):
    ):

        response = await Prompt.request(
-            "world_state.analyze-and-follow-instruction",
+            "world_state.analyze-text-and-follow-instruction",
            self.client,
            "analyze_freeform",
            vars = {
@@ -1,4 +1,6 @@
import os
from talemate.client.openai import OpenAIClient
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
from talemate.client.textgenwebui import TextGeneratorWebuiClient
+from talemate.client.lmstudio import LMStudioClient
import talemate.client.runpod
src/talemate/client/base.py (new file, 349 lines)
@@ -0,0 +1,349 @@
"""
A unified client base, based on the openai API
"""
import copy
import random
import time
from typing import Callable

import structlog
import logging
from openai import AsyncOpenAI

from talemate.emit import emit
import talemate.instance as instance
import talemate.client.presets as presets
import talemate.client.system_prompts as system_prompts
import talemate.util as util
from talemate.client.context import client_context_attribute
from talemate.client.model_prompts import model_prompt


# Set up logging level for httpx to WARNING to suppress debug logs.
logging.getLogger('httpx').setLevel(logging.WARNING)

REMOTE_SERVICES = [
    # TODO: runpod.py should add this to the list
    ".runpod.net"
]

STOPPING_STRINGS = ["<|im_end|>", "</s>"]

class ClientBase:

    api_url: str
    model_name: str
    name:str = None
    enabled: bool = True
    current_status: str = None
    max_token_length: int = 4096
    randomizable_inference_parameters: list[str] = ["temperature"]
    processing: bool = False
    connected: bool = False
    conversation_retries: int = 5

    client_type = "base"

    def __init__(
        self,
        api_url: str = None,
        name = None,
        **kwargs,
    ):
        self.api_url = api_url
        self.name = name or self.client_type
        self.log = structlog.get_logger(f"client.{self.client_type}")
        self.set_client()

    def __str__(self):
        return f"{self.client_type}Client[{self.api_url}][{self.model_name or ''}]"

    def set_client(self):
        self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")

    def prompt_template(self, sys_msg, prompt):

        """
        Applies the appropriate prompt template for the model.
        """

        if not self.model_name:
            self.log.warning("prompt template not applied", reason="no model loaded")
            return f"{sys_msg}\n{prompt}"

        return model_prompt(self.model_name, sys_msg, prompt)

    def reconfigure(self, **kwargs):

        """
        Reconfigures the client.

        Keyword Arguments:

        - api_url: the API URL to use
        - max_token_length: the max token length to use
        - enabled: whether the client is enabled
        """

        if "api_url" in kwargs:
            self.api_url = kwargs["api_url"]

        if "max_token_length" in kwargs:
            self.max_token_length = kwargs["max_token_length"]

        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

    def toggle_disabled_if_remote(self):

        """
        If the client is targeting a remote recognized service, this
        will disable the client.
        """

        for service in REMOTE_SERVICES:
            if service in self.api_url:
                if self.enabled:
                    self.log.warn("remote service unreachable, disabling client", client=self.name)
                    self.enabled = False

                return True

        return False

    def get_system_message(self, kind: str) -> str:

        """
        Returns the appropriate system message for the given kind of generation

        Arguments:

        - kind: the kind of generation
        """

        # TODO: make extensible

        if "narrate" in kind:
            return system_prompts.NARRATOR
        if "story" in kind:
            return system_prompts.NARRATOR
        if "director" in kind:
            return system_prompts.DIRECTOR
        if "create" in kind:
            return system_prompts.CREATOR
        if "roleplay" in kind:
            return system_prompts.ROLEPLAY
        if "conversation" in kind:
            return system_prompts.ROLEPLAY
        if "editor" in kind:
            return system_prompts.EDITOR
        if "world_state" in kind:
            return system_prompts.WORLD_STATE
        if "analyst" in kind:
            return system_prompts.ANALYST
        if "analyze" in kind:
            return system_prompts.ANALYST

        return system_prompts.BASIC
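
Since `kind` is matched by substring with the first hit winning, a few hypothetical lookups resolve as follows:

```python
# substring matching, first hit wins
client.get_system_message("narrate_query")     # -> system_prompts.NARRATOR
client.get_system_message("create_character")  # -> system_prompts.CREATOR
client.get_system_message("summarize")         # -> system_prompts.BASIC (no keyword matches)
```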
|
||||
|
||||
|
    def emit_status(self, processing: bool = None):
        """
        Sets and emits the client status.
        """

        if processing is not None:
            self.processing = processing

        if not self.enabled:
            status = "disabled"
            model_name = "Disabled"
        elif not self.connected:
            status = "error"
            model_name = "Could not connect"
        elif self.model_name:
            status = "busy" if self.processing else "idle"
            model_name = self.model_name
        else:
            model_name = "No model loaded"
            status = "warning"

        status_change = status != self.current_status
        self.current_status = status

        emit(
            "client_status",
            message=self.client_type,
            id=self.name,
            details=model_name,
            status=status,
        )

        if status_change:
            instance.emit_agent_status_by_client(self)
    async def get_model_name(self):
        models = await self.client.models.list()
        try:
            return models.data[0].id
        except IndexError:
            return None
    async def status(self):
        """
        Send a request to the API to retrieve the loaded AI model name.
        Logs a warning and marks the client disconnected if no model name
        is returned.
        :return: None
        """
        if self.processing:
            return

        if not self.enabled:
            self.connected = False
            self.emit_status()
            return

        try:
            self.model_name = await self.get_model_name()
        except Exception as e:
            self.log.warning("client status error", e=e, client=self.name)
            self.model_name = None
            self.connected = False
            self.toggle_disabled_if_remote()
            self.emit_status()
            return

        self.connected = True

        if not self.model_name or self.model_name == "None":
            self.log.warning("client model not loaded", client=self)
            self.emit_status()
            return

        self.emit_status()
    def generate_prompt_parameters(self, kind: str):
        parameters = {}
        self.tune_prompt_parameters(
            presets.configure(parameters, kind, self.max_token_length),
            kind
        )
        return parameters

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        parameters["stream"] = False
        if client_context_attribute("nuke_repetition") > 0.0 and self.jiggle_enabled_for(kind):
            self.jiggle_randomness(parameters, offset=client_context_attribute("nuke_repetition"))

        fn_tune_kind = getattr(self, f"tune_prompt_parameters_{kind}", None)
        if fn_tune_kind:
            fn_tune_kind(parameters)

    def tune_prompt_parameters_conversation(self, parameters: dict):
        conversation_context = client_context_attribute("conversation")
        parameters["max_tokens"] = conversation_context.get("length", 96)

        dialog_stopping_strings = [
            f"{character}:" for character in conversation_context["other_characters"]
        ]

        if "extra_stopping_strings" in parameters:
            parameters["extra_stopping_strings"] += dialog_stopping_strings
        else:
            parameters["extra_stopping_strings"] = dialog_stopping_strings
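    # For illustration, a "conversation" prompt with two other characters would
    # end up with parameters along these lines (a sketch; preset values come from
    # presets.py further below, the character names are hypothetical):
    #
    #   {
    #       "temperature": 0.65,            # PRESET_TALEMATE_CONVERSATION
    #       "top_p": 0.47,
    #       "top_k": 42,
    #       "repetition_penalty": 1.18,
    #       "repetition_penalty_range": 2048,
    #       "max_tokens": 96,               # conversation_context.get("length", 96)
    #       "stream": False,
    #       "extra_stopping_strings": ["Alice:", "Bob:"],
    #   }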
    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
        """

        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

        try:
            response = await self.client.completions.create(prompt=prompt.strip(), **parameters)
            # the openai SDK returns a Completion object, not a dict
            return response.choices[0].text if response.choices else ""
        except Exception as e:
            self.log.error("generate error", e=e)
            return ""
    async def send_prompt(
        self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x
    ) -> str:
        """
        Send a prompt to the AI and return its response.
        :param prompt: The text prompt to send.
        :return: The AI's response text.
        """

        try:
            self.emit_status(processing=True)
            await self.status()

            prompt_param = self.generate_prompt_parameters(kind)

            finalized_prompt = self.prompt_template(self.get_system_message(kind), prompt).strip()
            prompt_param = finalize(prompt_param)

            token_length = self.count_tokens(finalized_prompt)

            time_start = time.time()
            extra_stopping_strings = prompt_param.pop("extra_stopping_strings", [])

            self.log.debug("send_prompt", token_length=token_length, max_token_length=self.max_token_length, parameters=prompt_param)
            response = await self.generate(finalized_prompt, prompt_param, kind)

            time_end = time.time()

            # stopping strings sometimes get appended to the end of the response anyway;
            # split the response at the first stopping string and keep the first part

            for stopping_string in STOPPING_STRINGS + extra_stopping_strings:
                if stopping_string in response:
                    response = response.split(stopping_string)[0]
                    break

            emit("prompt_sent", data={
                "kind": kind,
                "prompt": finalized_prompt,
                "response": response,
                "prompt_tokens": token_length,
                "response_tokens": self.count_tokens(response),
                "time": time_end - time_start,
            })

            return response
        finally:
            self.emit_status(processing=False)
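    # Callers only need this high-level entry point; a minimal usage sketch
    # (the prompt text is hypothetical):
    #
    #   response = await client.send_prompt(
    #       "You are the narrator. Describe the tavern.",
    #       kind="narrate",
    #   )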
    def count_tokens(self, content: str):
        return util.count_tokens(content)

    def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3):
        """
        Adjusts temperature by a random value,
        using the base value as a center.
        """

        temp = prompt_config["temperature"]
        min_offset = offset * 0.3
        prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)

    def jiggle_enabled_for(self, kind: str):
        if kind in ["conversation", "story"]:
            return True

        if kind.startswith("narrate"):
            return True

        return False
src/talemate/client/lmstudio.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from talemate.client.base import ClientBase
from talemate.client.registry import register

from openai import AsyncOpenAI


@register()
class LMStudioClient(ClientBase):

    client_type = "lmstudio"
    conversation_retries = 5

    def set_client(self):
        self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        super().tune_prompt_parameters(parameters, kind)

        keys = list(parameters.keys())

        valid_keys = ["temperature", "top_p"]

        for key in keys:
            if key not in valid_keys:
                del parameters[key]

    async def get_model_name(self):
        model_name = await super().get_model_name()

        # the model name comes back as a file path, so extract the model name;
        # the path could be windows or linux, so handle both backslash and forward slash

        if model_name:
            model_name = model_name.replace("\\", "/").split("/")[-1]

        return model_name

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
        """
        human_message = {'role': 'user', 'content': prompt.strip()}

        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

        try:
            response = await self.client.chat.completions.create(
                model=self.model_name, messages=[human_message], **parameters
            )

            return response.choices[0].message.content
        except Exception as e:
            self.log.error("generate error", e=e)
            return ""
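# A minimal wiring sketch (constructor arguments assumed from the base client;
# the URL is hypothetical):
#
#   client = LMStudioClient(api_url="http://localhost:1234", name="lmstudio")
#   await client.status()          # resolves and stores the loaded model name
#   text = await client.send_prompt("Say hi.", kind="conversation")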
@@ -1,10 +1,9 @@
import asyncio
import os
import time
from typing import Callable

import json
from openai import AsyncOpenAI


from talemate.client.base import ClientBase
from talemate.client.registry import register
from talemate.emit import emit
from talemate.config import load_config
@@ -15,10 +14,9 @@ import tiktoken
__all__ = [
    "OpenAIClient",
]

log = structlog.get_logger("talemate")

def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"):
def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0613"):
    """Return the number of tokens used by a list of messages."""
    try:
        encoding = tiktoken.encoding_for_model(model)
@@ -70,7 +68,7 @@ def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"):
    return num_tokens
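# Quick sanity check for the token counter (message content is arbitrary):
#
#   messages = [{"role": "user", "content": "Hello there"}]
#   num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")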
@register()
class OpenAIClient:
class OpenAIClient(ClientBase):
    """
    OpenAI client for generating text.
    """
@@ -79,13 +77,10 @@ class OpenAIClient:
    conversation_retries = 0

    def __init__(self, model="gpt-4-1106-preview", **kwargs):
        self.name = kwargs.get("name", "openai")

        self.model_name = model
        self.last_token_length = 0
        self.max_token_length = 2048
        self.processing = False
        self.current_status = "idle"
        self.config = load_config()
        super().__init__(**kwargs)

        # if os.environ.get("OPENAI_API_KEY") is not set, look in the config file
        # and set it
@@ -94,7 +89,7 @@ class OpenAIClient:
        if self.config.get("openai", {}).get("api_key"):
            os.environ["OPENAI_API_KEY"] = self.config["openai"]["api_key"]

        self.set_client(model)
        self.set_client()


    @property
@@ -123,12 +118,14 @@ class OpenAIClient:
            status=status,
        )

    def set_client(self, model: str, max_token_length: int = None):
    def set_client(self, max_token_length: int = None):

        if not self.openai_api_key:
            log.error("No OpenAI API key set")
            return

        model = self.model_name

        self.client = AsyncOpenAI()
        if model == "gpt-3.5-turbo":
            self.max_token_length = min(max_token_length or 4096, 4096)
@@ -144,89 +141,72 @@
    def reconfigure(self, **kwargs):
        if "model" in kwargs:
            self.model_name = kwargs["model"]
            self.set_client(self.model_name, kwargs.get("max_token_length"))
            self.set_client(kwargs.get("max_token_length"))

    def count_tokens(self, content: str):
        return num_tokens_from_messages([{"content": content}], model=self.model_name)

    async def status(self):
        self.emit_status()

    def get_system_message(self, kind: str) -> str:

        if "narrate" in kind:
            return system_prompts.NARRATOR
        if "story" in kind:
            return system_prompts.NARRATOR
        if "director" in kind:
            return system_prompts.DIRECTOR
        if "create" in kind:
            return system_prompts.CREATOR
        if "roleplay" in kind:
            return system_prompts.ROLEPLAY
        if "conversation" in kind:
            return system_prompts.ROLEPLAY
        if "editor" in kind:
            return system_prompts.EDITOR
        if "world_state" in kind:
            return system_prompts.WORLD_STATE
        if "analyst" in kind:
            return system_prompts.ANALYST
        if "analyze" in kind:
            return system_prompts.ANALYST

        return system_prompts.BASIC
    async def send_prompt(
        self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x
    ) -> str:

        right = ""
        opts = {}

    def prompt_template(self, system_message: str, prompt: str):
        # only gpt-4-1106-preview supports json_object response coercion
        supports_json_object = self.model_name in ["gpt-4-1106-preview"]

        if "<|BOT|>" in prompt:
            _, right = prompt.split("<|BOT|>", 1)
            if right:
                prompt = prompt.replace("<|BOT|>", "\nContinue this response: ")
                expected_response = prompt.split("\nContinue this response: ")[1].strip()
                if expected_response.startswith("{") and supports_json_object:
                    opts["response_format"] = {"type": "json_object"}
            else:
                prompt = prompt.replace("<|BOT|>", "")

        self.emit_status(processing=True)
        await asyncio.sleep(0.1)
        return prompt

        sys_message = {'role': 'system', 'content': self.get_system_message(kind)}
    def tune_prompt_parameters(self, parameters: dict, kind: str):
        super().tune_prompt_parameters(parameters, kind)

        human_message = {'role': 'user', 'content': prompt}
        keys = list(parameters.keys())

        valid_keys = ["temperature", "top_p"]

        for key in keys:
            if key not in valid_keys:
                del parameters[key]

        log.debug("openai send", kind=kind, sys_message=sys_message, opts=opts)

        time_start = time.time()

        response = await self.client.chat.completions.create(model=self.model_name, messages=[sys_message, human_message], **opts)
    async def generate(self, prompt: str, parameters: dict, kind: str):

        time_end = time.time()
        """
        Generates text from the given prompt and parameters.
        """

        response = response.choices[0].message.content
        # only gpt-4-1106-preview supports json_object response coercion
        supports_json_object = self.model_name in ["gpt-4-1106-preview"]
        right = None
        try:
            _, right = prompt.split("\nContinue this response: ")
            expected_response = right.strip()
            if expected_response.startswith("{") and supports_json_object:
                parameters["response_format"] = {"type": "json_object"}
        except (ValueError, IndexError):
            # split-unpack raises ValueError when the marker is absent
            pass
        if right and response.startswith(right):
            response = response[len(right):].strip()
        human_message = {'role': 'user', 'content': prompt.strip()}
        system_message = {'role': 'system', 'content': self.get_system_message(kind)}

        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

        try:
            response = await self.client.chat.completions.create(
                model=self.model_name, messages=[system_message, human_message], **parameters
            )

            if kind == "conversation":
                response = response.replace("\n", " ").strip()

            log.debug("openai response", response=response)

            emit("prompt_sent", data={
                "kind": kind,
                "prompt": prompt,
                "response": response,
                "prompt_tokens": num_tokens_from_messages([sys_message, human_message], model=self.model_name),
                "response_tokens": num_tokens_from_messages([{"role": "assistant", "content": response}], model=self.model_name),
                "time": time_end - time_start,
            })

            self.emit_status(processing=False)
            return response
            response = response.choices[0].message.content

            if right and response.startswith(right):
                response = response[len(right):].strip()

            return response

        except Exception as e:
            self.log.error("generate error", e=e)
            return ""
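# To make the <|BOT|> handling concrete, a sketch of the coercion performed
# above (the prompt text is hypothetical):
#
#   prompt = 'Describe the scene.\n<|BOT|>{"mood":'
#   # becomes: 'Describe the scene.\n\nContinue this response: {"mood":'
#   # and since the expected response starts with "{" (and the model is
#   # gpt-4-1106-preview), response_format is set to {"type": "json_object"}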
src/talemate/client/presets.py (new file, 163 lines)
@@ -0,0 +1,163 @@
__all__ = [
    "configure",
    "set_max_tokens",
    "set_preset",
    "preset_for_kind",
    "max_tokens_for_kind",
    "PRESET_TALEMATE_CONVERSATION",
    "PRESET_TALEMATE_CREATOR",
    "PRESET_LLAMA_PRECISE",
    "PRESET_DIVINE_INTELLECT",
    "PRESET_SIMPLE_1",
]

PRESET_TALEMATE_CONVERSATION = {
    "temperature": 0.65,
    "top_p": 0.47,
    "top_k": 42,
    "repetition_penalty": 1.18,
    "repetition_penalty_range": 2048,
}

PRESET_TALEMATE_CREATOR = {
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 20,
    "repetition_penalty": 1.15,
    "repetition_penalty_range": 512,
}

PRESET_LLAMA_PRECISE = {
    'temperature': 0.7,
    'top_p': 0.1,
    'top_k': 40,
    'repetition_penalty': 1.18,
}

PRESET_DIVINE_INTELLECT = {
    'temperature': 1.31,
    'top_p': 0.14,
    'top_k': 49,
    "repetition_penalty_range": 1024,
    'repetition_penalty': 1.17,
}

PRESET_SIMPLE_1 = {
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 20,
    "repetition_penalty": 1.15,
}
def configure(config: dict, kind: str, total_budget: int):
    """
    Sets the config based on the kind of text to generate.
    """
    set_preset(config, kind)
    set_max_tokens(config, kind, total_budget)
    return config

def set_max_tokens(config: dict, kind: str, total_budget: int):
    """
    Sets the max_tokens in the config based on the kind of text to generate.
    """
    config["max_tokens"] = max_tokens_for_kind(kind, total_budget)
    return config

def set_preset(config: dict, kind: str):
    """
    Sets the preset in the config based on the kind of text to generate.
    """
    config.update(preset_for_kind(kind))
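# Usage sketch: building generation parameters for a conversation prompt.
#
#   params = configure({}, "conversation", total_budget=4096)
#   # params now holds PRESET_TALEMATE_CONVERSATION plus max_tokens=75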
def preset_for_kind(kind: str):
    if kind == "conversation":
        return PRESET_TALEMATE_CONVERSATION
    elif kind == "conversation_old":
        return PRESET_TALEMATE_CONVERSATION  # Assuming old conversation uses the same preset
    elif kind == "conversation_long":
        return PRESET_TALEMATE_CONVERSATION  # Assuming long conversation uses the same preset
    elif kind == "conversation_select_talking_actor":
        return PRESET_TALEMATE_CONVERSATION  # Assuming select talking actor uses the same preset
    elif kind == "summarize":
        return PRESET_LLAMA_PRECISE
    elif kind == "analyze":
        return PRESET_SIMPLE_1
    elif kind == "analyze_creative":
        return PRESET_DIVINE_INTELLECT
    elif kind == "analyze_long":
        return PRESET_SIMPLE_1  # Assuming long analysis uses the same preset as simple
    elif kind == "analyze_freeform":
        return PRESET_LLAMA_PRECISE
    elif kind == "analyze_freeform_short":
        return PRESET_LLAMA_PRECISE  # Assuming short freeform analysis uses the same preset as precise
    elif kind == "narrate":
        return PRESET_LLAMA_PRECISE
    elif kind == "story":
        return PRESET_DIVINE_INTELLECT
    elif kind == "create":
        return PRESET_TALEMATE_CREATOR
    elif kind == "create_concise":
        return PRESET_TALEMATE_CREATOR  # Assuming concise creation uses the same preset as creator
    elif kind == "create_precise":
        return PRESET_LLAMA_PRECISE
    elif kind == "director":
        return PRESET_SIMPLE_1
    elif kind == "director_short":
        return PRESET_SIMPLE_1  # Assuming short direction uses the same preset as simple
    elif kind == "director_yesno":
        return PRESET_SIMPLE_1  # Assuming yes/no direction uses the same preset as simple
    elif kind == "edit_dialogue":
        return PRESET_DIVINE_INTELLECT
    elif kind == "edit_add_detail":
        return PRESET_DIVINE_INTELLECT  # Assuming adding detail uses the same preset as divine intellect
    elif kind == "edit_fix_exposition":
        return PRESET_DIVINE_INTELLECT  # Assuming fixing exposition uses the same preset as divine intellect
    else:
        return PRESET_SIMPLE_1  # Default preset if none of the kinds match
def max_tokens_for_kind(kind: str, total_budget: int):
    if kind == "conversation":
        return 75  # Example value, adjust as needed
    elif kind == "conversation_old":
        return 75  # Example value, adjust as needed
    elif kind == "conversation_long":
        return 300  # Example value, adjust as needed
    elif kind == "conversation_select_talking_actor":
        return 30  # Example value, adjust as needed
    elif kind == "summarize":
        return 500  # Example value, adjust as needed
    elif kind == "analyze":
        return 500  # Example value, adjust as needed
    elif kind == "analyze_creative":
        return 1024  # Example value, adjust as needed
    elif kind == "analyze_long":
        return 2048  # Example value, adjust as needed
    elif kind == "analyze_freeform":
        return 500  # Example value, adjust as needed
    elif kind == "analyze_freeform_short":
        return 10  # Example value, adjust as needed
    elif kind == "narrate":
        return 500  # Example value, adjust as needed
    elif kind == "story":
        return 300  # Example value, adjust as needed
    elif kind == "create":
        return min(1024, int(total_budget * 0.35))  # Example calculation, adjust as needed
    elif kind == "create_concise":
        return min(400, int(total_budget * 0.25))  # Example calculation, adjust as needed
    elif kind == "create_precise":
        return min(400, int(total_budget * 0.25))  # Example calculation, adjust as needed
    elif kind == "director":
        return min(600, int(total_budget * 0.25))  # Example calculation, adjust as needed
    elif kind == "director_short":
        return 25  # Example value, adjust as needed
    elif kind == "director_yesno":
        return 2  # Example value, adjust as needed
    elif kind == "edit_dialogue":
        return 100  # Example value, adjust as needed
    elif kind == "edit_add_detail":
        return 200  # Example value, adjust as needed
    elif kind == "edit_fix_exposition":
        return 1024  # Example value, adjust as needed
    else:
        return 150  # Default value if none of the kinds match
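# Worked examples of the budget math above, with a 4096-token budget:
#
#   max_tokens_for_kind("create", 4096)          # min(1024, int(4096 * 0.35)) -> 1024
#   max_tokens_for_kind("create_concise", 4096)  # min(400, int(4096 * 0.25))  -> 400
#   max_tokens_for_kind("director_yesno", 4096)  # -> 2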
@@ -67,9 +67,9 @@ def _client_bootstrap(client_type: ClientType, pod):
    id = pod["id"]

    if client_type == ClientType.textgen:
        api_url = f"https://{id}-5000.proxy.runpod.net/api"
        api_url = f"https://{id}-5000.proxy.runpod.net"
    elif client_type == ClientType.automatic1111:
        api_url = f"https://{id}-5000.proxy.runpod.net/api"
        api_url = f"https://{id}-5000.proxy.runpod.net"

    return ClientBootstrap(
        client_type=client_type,
@@ -1,735 +1,65 @@
import asyncio
import random
import json
import copy
import structlog
import time
import httpx
from abc import ABC, abstractmethod
from typing import Callable, Union
import logging
import talemate.util as util
from talemate.client.base import ClientBase, STOPPING_STRINGS
from talemate.client.registry import register
import talemate.client.system_prompts as system_prompts
from talemate.emit import Emission, emit
from talemate.client.context import client_context_attribute
from talemate.client.model_prompts import model_prompt

import talemate.instance as instance

log = structlog.get_logger(__name__)

__all__ = [
    "TaleMateClient",
    "RestApiTaleMateClient",
    "TextGeneratorWebuiClient",
]

# Set the logging level for httpx to WARNING to suppress debug logs.
logging.getLogger('httpx').setLevel(logging.WARNING)
class DefaultContext(int):
    pass


PRESET_TALEMATE_LEGACY = {
    "temperature": 0.72,
    "top_p": 0.73,
    "top_k": 0,
    "top_a": 0,
    "repetition_penalty": 1.18,
    "repetition_penalty_range": 2048,
    "encoder_repetition_penalty": 1,
    #"encoder_repetition_penalty": 1.2,
    #"no_repeat_ngram_size": 2,
    "do_sample": True,
    "length_penalty": 1,
}

PRESET_TALEMATE_CONVERSATION = {
    "temperature": 0.65,
    "top_p": 0.47,
    "top_k": 42,
    "typical_p": 1,
    "top_a": 0,
    "tfs": 1,
    "epsilon_cutoff": 0,
    "eta_cutoff": 0,
    "repetition_penalty": 1.18,
    "repetition_penalty_range": 2048,
    "no_repeat_ngram_size": 0,
    "penalty_alpha": 0,
    "num_beams": 1,
    "length_penalty": 1,
    "min_length": 0,
    "encoder_rep_pen": 1,
    "do_sample": True,
    "early_stopping": False,
    "mirostat_mode": 0,
    "mirostat_tau": 5,
    "mirostat_eta": 0.1
}

PRESET_TALEMATE_CREATOR = {
    "temperature": 0.7,
    "top_p": 0.9,
    "repetition_penalty": 1.15,
    "repetition_penalty_range": 512,
    "top_k": 20,
    "do_sample": True,
    "length_penalty": 1,
}

PRESET_LLAMA_PRECISE = {
    'temperature': 0.7,
    'top_p': 0.1,
    'repetition_penalty': 1.18,
    'top_k': 40
}

PRESET_KOBOLD_GODLIKE = {
    'temperature': 0.7,
    'top_p': 0.5,
    'typical_p': 0.19,
    'repetition_penalty': 1.1,
    "repetition_penalty_range": 1024,
}

PRESET_DIVINE_INTELLECT = {
    'temperature': 1.31,
    'top_p': 0.14,
    "repetition_penalty_range": 1024,
    'repetition_penalty': 1.17,
    'top_k': 49,
    "mirostat_mode": 0,
    "mirostat_tau": 5,
    "mirostat_eta": 0.1,
    "tfs": 1,
}

PRESET_SIMPLE_1 = {
    "temperature": 0.7,
    "top_p": 0.9,
    "repetition_penalty": 1.15,
    "top_k": 20,
}

def jiggle_randomness(prompt_config: dict, offset: float = 0.3) -> dict:
    """
    adjusts temperature and repetition_penalty
    by random values using the base value as a center
    """

    temp = prompt_config["temperature"]
    rep_pen = prompt_config["repetition_penalty"]

    copied_config = copy.deepcopy(prompt_config)

    min_offset = offset * 0.3

    copied_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
    copied_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)

    return copied_config
class TaleMateClient:
    """
    An abstract TaleMate client that can be implemented for different communication methods with the AI.
    """
    def __init__(
        self,
        api_url: str,
        max_token_length: Union[int, DefaultContext] = int.__new__(
            DefaultContext, 2048
        ),
    ):
        self.api_url = api_url
        self.name = "generic_client"
        self.model_name = None
        self.last_token_length = 0
        self.max_token_length = max_token_length
        self.original_max_token_length = max_token_length
        self.enabled = True
        self.current_status = None

    @abstractmethod
    def send_message(self, message: dict) -> str:
        """
        Sends a message to the AI. Needs to be implemented by the subclass.
        :param message: The message to be sent.
        :return: The AI's response text.
        """
        pass

    @abstractmethod
    def send_prompt(self, prompt: str) -> str:
        """
        Sends a prompt to the AI. Needs to be implemented by the subclass.
        :param prompt: The text prompt to send.
        :return: The AI's response text.
        """
        pass

    def reconfigure(self, **kwargs):
        if "api_url" in kwargs:
            self.api_url = kwargs["api_url"]

        if "max_token_length" in kwargs:
            self.max_token_length = kwargs["max_token_length"]

        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

    def remaining_tokens(self, context: Union[str, list]) -> int:
        return self.max_token_length - util.count_tokens(context)

    def prompt_template(self, sys_msg, prompt):
        return model_prompt(self.model_name, sys_msg, prompt)
class RESTTaleMateClient(TaleMateClient, ABC):
    """
    A RESTful TaleMate client that connects to the REST API endpoint.
    """

    async def send_message(self, message: dict, url: str) -> str:
        """
        Sends a message to the REST API and returns the AI's response.
        :param message: The message to be sent.
        :return: The AI's response text.
        """

        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(url, json=message, timeout=None)
                response_data = response.json()
                return response_data["results"][0]["text"]
        except KeyError:
            return response_data["results"][0]["history"]["visible"][0][-1]
from openai import AsyncOpenAI
import httpx
import copy
import random
@register()
class TextGeneratorWebuiClient(RESTTaleMateClient):
    """
    Client that connects to the text-generator-webui api
    """

class TextGeneratorWebuiClient(ClientBase):

    client_type = "textgenwebui"
    conversation_retries = 5

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        super().tune_prompt_parameters(parameters, kind)
        parameters["stopping_strings"] = STOPPING_STRINGS + parameters.get("extra_stopping_strings", [])
        # is this needed?
        parameters["max_new_tokens"] = parameters["max_tokens"]

    def __init__(self, api_url: str, max_token_length: int = 2048, **kwargs):
    def set_client(self):
        self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
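    # Net effect for a textgenwebui request (a sketch; one other character
    # named "Alice" is assumed in the conversation context):
    #
    #   params["stopping_strings"] == STOPPING_STRINGS + ["Alice:"]
    #   params["max_new_tokens"] == params["max_tokens"]  # mirrored into the legacy field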
    async def get_model_name(self):
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{self.api_url}/v1/internal/model/info", timeout=2)
            if response.status_code == 404:
                raise Exception("Could not find model info (wrong api version?)")
            response_data = response.json()
            model_name = response_data.get("model_name")

        api_url = self.cleanup_api_url(api_url)
        if model_name == "None":
            model_name = None

        return model_name
    async def generate(self, prompt: str, parameters: dict, kind: str):

        self.api_url_base = api_url
        api_url = f"{api_url}/v1/chat"
        super().__init__(api_url, max_token_length=max_token_length)
        self.model_name = None
        self.limited_ram = False
        self.name = kwargs.get("name", "textgenwebui")
        self.processing = False
        self.connected = False

    def __str__(self):
        return f"TextGeneratorWebuiClient[{self.api_url_base}][{self.model_name or ''}]"

    def cleanup_api_url(self, api_url: str):
        """
        Strips trailing / and ensures endpoint is /api
        Generates text from the given prompt and parameters.
        """

        if api_url.endswith("/"):
            api_url = api_url[:-1]

        if not api_url.endswith("/api"):
            api_url = api_url + "/api"

        return api_url
    def reconfigure(self, **kwargs):
        super().reconfigure(**kwargs)
        if "api_url" in kwargs:
            log.debug("reconfigure", api_url=kwargs["api_url"])
            api_url = kwargs["api_url"]
            api_url = self.cleanup_api_url(api_url)
            self.api_url_base = api_url
            self.api_url = api_url
        headers = {}
        headers["Content-Type"] = "application/json"

    def toggle_disabled_if_remote(self):
        parameters["prompt"] = prompt.strip()

        remote_services = [
            ".runpod.net"
        ]

        for service in remote_services:
            if service in self.api_url_base:
                self.enabled = False
                return
    def emit_status(self, processing: bool = None):
        if processing is not None:
            self.processing = processing

        if not self.enabled:
            status = "disabled"
            model_name = "Disabled"
        elif not self.connected:
            status = "error"
            model_name = "Could not connect"
        elif self.model_name:
            status = "busy" if self.processing else "idle"
            model_name = self.model_name
        else:
            model_name = "No model loaded"
            status = "warning"

        status_change = status != self.current_status
        self.current_status = status

        emit(
            "client_status",
            message=self.client_type,
            id=self.name,
            details=model_name,
            status=status,
        )

        if status_change:
            instance.emit_agent_status_by_client(self)
    # Add the 'status' method
    async def status(self):
        """
        Send a request to the API to retrieve the loaded AI model name.
        Raises an error if no model name is returned.
        :return: None
        """

        if not self.enabled:
            self.connected = False
            self.emit_status()
            return

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(f"{self.api_url_base}/v1/model", timeout=2)

        except (
            httpx.TimeoutException,
            httpx.NetworkError,
        ):
            self.model_name = None
            self.connected = False
            self.toggle_disabled_if_remote()
            self.emit_status()
            return

        self.connected = True

        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(f"{self.api_url}/v1/completions", json=parameters, timeout=None, headers=headers)
                response_data = response.json()
            self.enabled = True
        except json.decoder.JSONDecodeError as e:
            self.connected = False
            self.toggle_disabled_if_remote()
            if not self.enabled:
                log.warn("remote service unreachable, disabling client", name=self.name)
            else:
                log.error("client response error", name=self.name, e=e)

            self.emit_status()
            return

        model_name = response_data.get("result")

        if not model_name or model_name == "None":
            log.warning("client model not loaded", client=self.name)
            self.emit_status()
            return

        model_changed = model_name != self.model_name

        self.model_name = model_name

        if model_changed:
            self.auto_context_length()

        log.info(f"{self} [{self.max_token_length} ctx]: ready")
        self.emit_status()
    def auto_context_length(self):
        return response_data["choices"][0]["text"]

    def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
        """
        Automatically sets the context length based on the LLM
        """

        if not isinstance(self.max_token_length, DefaultContext):
            # context length was specified manually
            return

        model_name = self.model_name.lower()

        if "longchat" in model_name:
            self.max_token_length = 16000
        elif "8k" in model_name:
            if not self.limited_ram or "13b" in model_name:
                self.max_token_length = 6000
            else:
                self.max_token_length = 4096
        elif "4k" in model_name:
            self.max_token_length = 4096
        else:
            self.max_token_length = self.original_max_token_length
    @property
    def instruction_template(self):
        if "vicuna" in self.model_name.lower():
            return "Vicuna-v1.1"
        if "camel" in self.model_name.lower():
            return "Vicuna-v1.1"
        return ""

    def prompt_url(self):
        return self.api_url_base + "/v1/generate"
    def prompt_config_conversation_old(self, prompt: str) -> dict:
        prompt = self.prompt_template(
            system_prompts.BASIC,
            prompt,
        )

        config = {
            "prompt": prompt,
            "max_new_tokens": 75,
            "truncation_length": self.max_token_length,
        }
        config.update(PRESET_TALEMATE_CONVERSATION)
        return config


    def prompt_config_conversation(self, prompt: str) -> dict:
        prompt = self.prompt_template(
            system_prompts.ROLEPLAY,
            prompt,
        )

        stopping_strings = ["<|end_of_turn|>"]

        conversation_context = client_context_attribute("conversation")

        stopping_strings += [
            f"{character}:" for character in conversation_context["other_characters"]
        ]

        max_new_tokens = conversation_context.get("length", 96)
        log.debug("prompt_config_conversation", stopping_strings=stopping_strings, conversation_context=conversation_context, max_new_tokens=max_new_tokens)

        config = {
            "prompt": prompt,
            "max_new_tokens": max_new_tokens,
            "truncation_length": self.max_token_length,
            "stopping_strings": stopping_strings,
        }
        config.update(PRESET_TALEMATE_CONVERSATION)

        jiggle_randomness(config)

        return config

    def prompt_config_conversation_long(self, prompt: str) -> dict:
        config = self.prompt_config_conversation(prompt)
        config["max_new_tokens"] = 300
        return config

    def prompt_config_conversation_select_talking_actor(self, prompt: str) -> dict:
        config = self.prompt_config_conversation(prompt)
        config["max_new_tokens"] = 30
        config["stopping_strings"] += [":"]
        return config


    def prompt_config_summarize(self, prompt: str) -> dict:
        prompt = self.prompt_template(
            system_prompts.NARRATOR,
            prompt,
        )

        config = {
            "prompt": prompt,
            "max_new_tokens": 500,
            "truncation_length": self.max_token_length,
        }

        config.update(PRESET_LLAMA_PRECISE)
        return config

    def prompt_config_analyze(self, prompt: str) -> dict:
        prompt = self.prompt_template(
            system_prompts.ANALYST,
            prompt,
        )

        config = {
            "prompt": prompt,
            "max_new_tokens": 500,
            "truncation_length": self.max_token_length,
        }

        config.update(PRESET_SIMPLE_1)
        return config

    def prompt_config_analyze_creative(self, prompt: str) -> dict:
        prompt = self.prompt_template(
            system_prompts.ANALYST,
            prompt,
        )

        config = {}
        config.update(PRESET_DIVINE_INTELLECT)
        config.update({
            "prompt": prompt,
            "max_new_tokens": 1024,
            "repetition_penalty_range": 1024,
            "truncation_length": self.max_token_length
        })

        return config

    def prompt_config_analyze_long(self, prompt: str) -> dict:
        config = self.prompt_config_analyze(prompt)
        config["max_new_tokens"] = 2048
        return config

    def prompt_config_analyze_freeform(self, prompt: str) -> dict:
        prompt = self.prompt_template(
            system_prompts.ANALYST_FREEFORM,
            prompt,
        )

        config = {
            "prompt": prompt,
            "max_new_tokens": 500,
            "truncation_length": self.max_token_length,
        }

        config.update(PRESET_LLAMA_PRECISE)
        return config


    def prompt_config_analyze_freeform_short(self, prompt: str) -> dict:
        config = self.prompt_config_analyze_freeform(prompt)
        config["max_new_tokens"] = 10
        return config

    def prompt_config_narrate(self, prompt: str) -> dict:
        prompt = self.prompt_template(
            system_prompts.NARRATOR,
            prompt,
        )

        config = {
            "prompt": prompt,
            "max_new_tokens": 500,
            "truncation_length": self.max_token_length,
        }
        config.update(PRESET_LLAMA_PRECISE)
        return config

    def prompt_config_story(self, prompt: str) -> dict:
        prompt = self.prompt_template(
            system_prompts.NARRATOR,
            prompt,
        )

        config = {
            "prompt": prompt,
            "max_new_tokens": 300,
            "seed": random.randint(0, 1000000000),
            "truncation_length": self.max_token_length
        }
        config.update(PRESET_DIVINE_INTELLECT)
        config.update({
            "repetition_penalty": 1.3,
            "repetition_penalty_range": 2048,
        })
        return config

    def prompt_config_create(self, prompt: str) -> dict:
        prompt = self.prompt_template(
            system_prompts.CREATOR,
            prompt,
        )
        config = {
            "prompt": prompt,
            "max_new_tokens": min(1024, self.max_token_length * 0.35),
            "truncation_length": self.max_token_length,
        }
        config.update(PRESET_TALEMATE_CREATOR)
        return config

    def prompt_config_create_concise(self, prompt: str) -> dict:
        prompt = self.prompt_template(
            system_prompts.CREATOR,
            prompt,
        )

        config = {
            "prompt": prompt,
            "max_new_tokens": min(400, self.max_token_length * 0.25),
            "truncation_length": self.max_token_length,
            "stopping_strings": ["<|DONE|>", "\n\n"]
        }
        config.update(PRESET_TALEMATE_CREATOR)
        return config

    def prompt_config_create_precise(self, prompt: str) -> dict:
        config = self.prompt_config_create_concise(prompt)
        config.update(PRESET_LLAMA_PRECISE)
        return config

    def prompt_config_director(self, prompt: str) -> dict:
        prompt = self.prompt_template(
            system_prompts.DIRECTOR,
            prompt,
        )

        config = {
            "prompt": prompt,
            "max_new_tokens": min(600, self.max_token_length * 0.25),
            "truncation_length": self.max_token_length,
        }
        config.update(PRESET_SIMPLE_1)
        return config


    def prompt_config_director_short(self, prompt: str) -> dict:
        config = self.prompt_config_director(prompt)
        config.update(max_new_tokens=25)
        return config

    def prompt_config_director_yesno(self, prompt: str) -> dict:
        config = self.prompt_config_director(prompt)
        config.update(max_new_tokens=2)
        return config

    def prompt_config_edit_dialogue(self, prompt: str) -> dict:
        prompt = self.prompt_template(
            system_prompts.EDITOR,
            prompt,
        )

        conversation_context = client_context_attribute("conversation")

        stopping_strings = [
            f"{character}:" for character in conversation_context["other_characters"]
        ]

        config = {
            "prompt": prompt,
            "max_new_tokens": 100,
            "truncation_length": self.max_token_length,
            "stopping_strings": stopping_strings,
        }

        config.update(PRESET_DIVINE_INTELLECT)

        return config

    def prompt_config_edit_add_detail(self, prompt: str) -> dict:
        config = self.prompt_config_edit_dialogue(prompt)
        config.update(max_new_tokens=200)
        return config


    def prompt_config_edit_fix_exposition(self, prompt: str) -> dict:
        config = self.prompt_config_edit_dialogue(prompt)
        config.update(max_new_tokens=1024)
        return config
    async def send_prompt(
        self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x
    ) -> str:
        """
        Send a prompt to the AI and return its response.
        :param prompt: The text prompt to send.
        :return: The AI's response text.
        adjusts temperature and repetition_penalty
        by random values using the base value as a center
        """

        #prompt = prompt.replace("<|BOT|>", "<|BOT|>Certainly! ")

        await self.status()
        self.emit_status(processing=True)

        await asyncio.sleep(0.01)

        fn_prompt_config = getattr(self, f"prompt_config_{kind}")
        fn_url = self.prompt_url
        message = fn_prompt_config(prompt)

        if client_context_attribute("nuke_repetition") > 0.0 and kind in ["conversation", "story"]:
            log.info("nuke repetition", offset=client_context_attribute("nuke_repetition"), temperature=message["temperature"], repetition_penalty=message["repetition_penalty"])
            message = jiggle_randomness(message, offset=client_context_attribute("nuke_repetition"))
            log.info("nuke repetition (applied)", offset=client_context_attribute("nuke_repetition"), temperature=message["temperature"], repetition_penalty=message["repetition_penalty"])
        temp = prompt_config["temperature"]
        rep_pen = prompt_config["repetition_penalty"]

        message = finalize(message)
        min_offset = offset * 0.3

        token_length = int(len(message["prompt"]) / 3.6)

        self.last_token_length = token_length

        log.debug("send_prompt", token_length=token_length, max_token_length=self.max_token_length)

        message["prompt"] = message["prompt"].strip()

        #print(f"prompt: |{message['prompt']}|")

        # add <|im_end|> to stopping strings
        if "stopping_strings" in message:
            message["stopping_strings"] += ["<|im_end|>", "</s>"]
        else:
            message["stopping_strings"] = ["<|im_end|>", "</s>"]

        #message["seed"] = -1

        #for k,v in message.items():
        #    if k == "prompt":
        #        continue
        #    print(f"{k}: {v}")

        time_start = time.time()

        response = await self.send_message(message, fn_url())

        time_end = time.time()

        response = response.split("#")[0]
        self.emit_status(processing=False)

        emit("prompt_sent", data={
            "kind": kind,
            "prompt": message["prompt"],
            "response": response,
            "prompt_tokens": token_length,
            "response_tokens": int(len(response) / 3.6),
            "time": time_end - time_start,
        })

        return response


class OpenAPIClient(RESTTaleMateClient):
    pass


class GPT3Client(OpenAPIClient):
    pass


class GPT4Client(OpenAPIClient):
    pass
        prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
        prompt_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)
src/talemate/client/utils.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import copy
import random

def jiggle_randomness(prompt_config: dict, offset: float = 0.3) -> dict:
    """
    adjusts temperature and repetition_penalty
    by random values using the base value as a center
    """

    temp = prompt_config["temperature"]
    rep_pen = prompt_config["repetition_penalty"]

    copied_config = copy.deepcopy(prompt_config)

    min_offset = offset * 0.3

    copied_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
    copied_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)

    return copied_config


def jiggle_enabled_for(kind: str):
    if kind in ["conversation", "story"]:
        return True

    if kind.startswith("narrate"):
        return True

    return False
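# Numeric sketch: with the default offset of 0.3, min_offset is 0.09, so a base
# temperature of 0.7 is redrawn uniformly from [0.79, 1.0) and a base
# repetition_penalty of 1.15 from roughly [1.177, 1.24); the input dict is left
# untouched and a modified copy is returned.
#
#   base = {"temperature": 0.7, "repetition_penalty": 1.15}
#   jiggled = jiggle_randomness(base, offset=0.3)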
@@ -23,6 +23,7 @@ from .cmd_save_as import CmdSaveAs
from .cmd_save_characters import CmdSaveCharacters
from .cmd_setenv import CmdSetEnvironmentToScene, CmdSetEnvironmentToCreative
from .cmd_time_util import *
from .cmd_tts import *
from .cmd_world_state import CmdWorldState
from .cmd_run_helios_test import CmdHeliosTest
from .manager import Manager
@@ -32,4 +32,5 @@ class CmdRebuildArchive(TalemateCommand):
            if not more:
                break

        self.scene.sync_time()
        await self.scene.commit_to_memory()
@@ -17,7 +17,26 @@ class CmdRename(TalemateCommand):
    aliases = []

    async def run(self):
        # collect list of characters in the scene

        if self.args:
            character_name = self.args[0]
        else:
            character_names = self.scene.character_names
            character_name = await wait_for_input("Which character do you want to rename?", data={
                "input_type": "select",
                "choices": character_names,
            })

        character = self.scene.get_character(character_name)

        if not character:
            self.system_message(f"Character {character_name} not found")
            return True

        name = await wait_for_input("Enter new name: ")

        self.scene.main_character.character.rename(name)
        character.rename(name)
        await asyncio.sleep(0)

        return True
src/talemate/commands/cmd_tts.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import asyncio
import logging

from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.prompts.base import set_default_sectioning_handler
from talemate.instance import get_agent

__all__ = [
    "CmdTestTTS",
]

@register
class CmdTestTTS(TalemateCommand):
    """
    Command class for the 'test_tts' command
    """

    name = "test_tts"
    description = "Test the TTS agent"
    aliases = []

    async def run(self):
        tts_agent = get_agent("tts")

        try:
            last_message = str(self.scene.history[-1])
        except IndexError:
            last_message = "Welcome to talemate!"

        await tts_agent.generate(last_message)
@@ -65,6 +65,21 @@ class OpenAIConfig(BaseModel):

class RunPodConfig(BaseModel):
    api_key: Union[str, None] = None

class ElevenLabsConfig(BaseModel):
    api_key: Union[str, None] = None

class CoquiConfig(BaseModel):
    api_key: Union[str, None] = None

class TTSVoiceSamples(BaseModel):
    label: str
    value: str

class TTSConfig(BaseModel):
    device: str = "cuda"
    model: str = "tts_models/multilingual/multi-dataset/xtts_v2"
    voices: list[TTSVoiceSamples] = pydantic.Field(default_factory=list)

class ChromaDB(BaseModel):
    instructor_device: str = "cpu"
@@ -85,6 +100,12 @@ class Config(BaseModel):

    chromadb: ChromaDB = ChromaDB()

    elevenlabs: ElevenLabsConfig = ElevenLabsConfig()

    coqui: CoquiConfig = CoquiConfig()

    tts: TTSConfig = TTSConfig()

    class Config:
        extra = "ignore"
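# Sketch of the resulting defaults (field values as declared above):
#
#   cfg = Config()
#   cfg.tts.device    # "cuda"
#   cfg.tts.model     # "tts_models/multilingual/multi-dataset/xtts_v2"
#   cfg.tts.voices    # [] until TTSVoiceSamples entries are added
#   cfg.elevenlabs.api_key, cfg.coqui.api_key  # None until configured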
@@ -24,6 +24,8 @@ CommandStatus = signal("command_status")
WorldState = signal("world_state")
ArchivedHistory = signal("archived_history")

AudioQueue = signal("audio_queue")

MessageEdited = signal("message_edited")

handlers = {
@@ -46,4 +48,5 @@ handlers = {
    "archived_history": ArchivedHistory,
    "message_edited": MessageEdited,
    "prompt_sent": PromptSent,
    "audio_queue": AudioQueue,
}
@@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from talemate.tale_mate import Scene
    from talemate.tale_mate import Scene, Actor, SceneMessage

__all__ = [
    "Event",
@@ -42,4 +42,12 @@ class GameLoopEvent(Event):

@dataclass
class GameLoopStartEvent(GameLoopEvent):
    pass
    pass

@dataclass
class GameLoopActorIterEvent(GameLoopEvent):
    actor: Actor

@dataclass
class GameLoopNewMessageEvent(GameLoopEvent):
    message: SceneMessage
@@ -190,8 +190,11 @@ async def load_scene_from_data(
        await scene.add_actor(actor)

    if scene.environment != "creative":
        await scene.world_state.request_update(initial_only=True)

        try:
            await scene.world_state.request_update(initial_only=True)
        except Exception as e:
            log.error("world_state.request_update", error=e)

    # the scene has been saved before (since we just loaded it), so we set the saved flag to True
    # as long as the scene has a memory_id.
    scene.saved = "memory_id" in scene_data
@@ -290,6 +290,7 @@ class Prompt:
        env.globals["query_scene"] = self.query_scene
        env.globals["query_memory"] = self.query_memory
        env.globals["query_text"] = self.query_text
        env.globals["instruct_text"] = self.instruct_text
        env.globals["retrieve_memories"] = self.retrieve_memories
        env.globals["uuidgen"] = lambda: str(uuid.uuid4())
        env.globals["to_int"] = lambda x: int(x)
@@ -394,9 +395,14 @@ class Prompt:
                f"Answer: " + loop.run_until_complete(memory.query(query, **kwargs)),
            ])
        else:
            return loop.run_until_complete(memory.multi_query([query], **kwargs))

            return loop.run_until_complete(memory.multi_query(query.split("\n"), **kwargs))

    def instruct_text(self, instruction: str, text: str):
        loop = asyncio.get_event_loop()
        world_state = instance.get_agent("world_state")
        instruction = instruction.format(**self.vars)

        return loop.run_until_complete(world_state.analyze_and_follow_instruction(text, instruction))

    def retrieve_memories(self, lines: list[str], goal: str = None):
@@ -467,8 +473,6 @@ class Prompt:

        # remove all duplicate whitespace
        cleaned = re.sub(r"\s+", " ", cleaned)
        print("set_json_response", cleaned)

        return self.set_prepared_response(cleaned)
@@ -0,0 +1,19 @@
{% block rendered_context -%}
<|SECTION:CONTEXT|>
Content Context: This is a specific scene from {{ scene.context }}
Scenario Premise: {{ scene.description }}
{% for memory in query_memory(last_line, as_question_answer=False, iterate=10) -%}
{{ memory }}

{% endfor %}
{% endblock -%}
<|CLOSE_SECTION|>
{% for scene_context in scene.context_history(budget=max_tokens-200-count_tokens(self.rendered_context())) -%}
{{ scene_context }}
{% endfor %}
<|SECTION:TASK|>
Based on the previous line '{{ last_line }}', create the next line of narration. This line should focus solely on describing sensory details (like sounds, sights, smells, tactile sensations) or external actions that move the story forward. Avoid including any character's internal thoughts, feelings, or dialogue. Your narration should directly respond to '{{ last_line }}', either by elaborating on the immediate scene or by subtly advancing the plot. Generate exactly one sentence of new narration. If the character is trying to determine some state, truth or situation, try to answer as part of the narration.

Be creative and generate something new and interesting, but stay true to the setting and context of the story so far.
<|CLOSE_SECTION|>
{{ set_prepared_response('*') }}
@@ -8,13 +8,13 @@
{% if query.endswith("?") -%}
Question: {{ query }}
Extra context: {{ query_memory(query, as_question_answer=False) }}
Instruction: Analyze Context, History and Dialogue. Be factual and truthful. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context. Respect the scene progression and answer in the context of the end of the dialogue.
Instruction: Analyze Context, History and Dialogue. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context. Respect the scene progression and answer in the context of the end of the dialogue.
{% else -%}
Instruction: {{ query }}
Extra context: {{ query_memory(query, as_question_answer=False) }}
Answer based on Context, History and Dialogue. Be factual and truthful. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context.
Answer based on Context, History and Dialogue. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context.
{% endif -%}
Content Context: This is a specific scene from {{ scene.context }}
Narration style: point and click adventure game from the 90s
Your answer should be in the style of short narration that fits the context of the scene.
<|CLOSE_SECTION|>
Narrator answers: {% if at_the_end %}{{ bot_token }}At the end of the dialogue, {% endif %}
@@ -8,9 +8,10 @@
<|SECTION:TASK|>
Answer the following questions:

{{ query_text("What are 1 to 3 questions to ask the narrator of the story to gather more context from the past for the continuation of this conversation? If a character is asking about a status, location or information about an item or another character, make sure to include question(s) that help gather context for this. Don't explain your reasoning. Don't ask the actors directly.", text, as_question_answer=False) }}
{{ instruct_text("Ask the narrator three (3) questions to gather more context from the past for the continuation of this conversation. If a character is asking about a state, location or information about an item or another character, make sure to include question(s) that help gather context for this.", text) }}

Your answers should be precise, truthful and short.
Your answers should be precise, truthful and short. Pay close attention to timestamps when retrieving information from the context.

<|CLOSE_SECTION|>
<|SECTION:RELEVANT CONTEXT|>
{{ bot_token }}Answers:
@@ -0,0 +1,5 @@

{{ text }}

<|SECTION:TASK|>
{{ instruction }}
@@ -34,7 +34,7 @@ No dialogue so far
{% endif -%}
<|CLOSE_SECTION|>
<|SECTION:SCENE PROGRESS|>
-{% for scene_context in scene.context_history(budget=300, min_dialogue=5, add_archieved_history=False, max_dialogue=5) -%}
+{% for scene_context in scene.context_history(budget=500, min_dialogue=5, add_archieved_history=False, max_dialogue=5) -%}
{{ scene_context }}
{% endfor -%}
<|CLOSE_SECTION|>
@@ -110,7 +110,6 @@ async def websocket_endpoint(websocket, path):
        elif action_type == "request_scenes_list":
            query = data.get("query", "")
            handler.request_scenes_list(query)

        elif action_type == "configure_clients":
            handler.configure_clients(data.get("clients"))
        elif action_type == "configure_agents":
@@ -1,3 +1,5 @@
+import os
+
import argparse
import asyncio
import sys
26  src/talemate/server/tts.py  Normal file
@@ -0,0 +1,26 @@
import structlog

import talemate.instance as instance

log = structlog.get_logger("talemate.server.tts")

class TTSPlugin:
    router = "tts"

    def __init__(self, websocket_handler):
        self.websocket_handler = websocket_handler
        self.tts = None

    async def handle(self, data: dict):

        action = data.get("action")


        if action == "test":
            return await self.handle_test(data)

    async def handle_test(self, data: dict):

        tts_agent = instance.get_agent("tts")

        await tts_agent.generate("Welcome to talemate!")
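For reference, a minimal sketch of driving this plugin from a websocket client. The payload shape (routing by "type" to the plugin's router, plus an "action" field) is inferred from the dispatch code above; the server URL and path are assumptions that depend on your local setup.

import asyncio
import json

import websockets

async def trigger_tts_test():
    # assumed local talemate websocket endpoint; adjust host/port/path as needed
    async with websockets.connect("ws://localhost:5050/ws") as ws:
        await ws.send(json.dumps({"type": "tts", "action": "test"}))

asyncio.run(trigger_tts_test())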
@@ -91,7 +91,7 @@ class WebsocketHandler(Receiver):
        for agent_typ, agent_config in self.agents.items():
            try:
                client = self.llm_clients.get(agent_config.get("client"))["client"]
-           except TypeError:
+           except TypeError as e:
                client = None

            if not client:
@@ -167,19 +167,28 @@ class WebsocketHandler(Receiver):
        log.info("Configuring clients", clients=clients)

        for client in clients:
-           if client["type"] == "textgenwebui":
+
+           client.pop("status", None)
+
+           if client["type"] in ["textgenwebui", "lmstudio"]:
                try:
                    max_token_length = int(client.get("max_token_length", 2048))
                except ValueError:
                    continue
+
+               client.pop("model", None)
+
                self.llm_clients[client["name"]] = {
-                   "type": "textgenwebui",
+                   "type": client["type"],
                    "api_url": client["apiUrl"],
                    "name": client["name"],
                    "max_token_length": max_token_length,
                }
            elif client["type"] == "openai":
+
+               client.pop("model_name", None)
+               client.pop("apiUrl", None)
+
                self.llm_clients[client["name"]] = {
                    "type": "openai",
                    "name": client["name"],
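For reference, a hypothetical clients payload matching the keys this handler reads (name, type, apiUrl, max_token_length); the values here are placeholders:

clients = [
    {
        "name": "TextGenWebUI",
        "type": "textgenwebui",  # "lmstudio" is now accepted as well
        "apiUrl": "http://localhost:5000",
        "max_token_length": 4096,
    },
    {
        "name": "OpenAI",
        "type": "openai",
    },
]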
@@ -213,16 +222,25 @@ class WebsocketHandler(Receiver):
    def configure_agents(self, agents):
        self.agents = {typ: {} for typ in instance.agent_types()}

-       log.debug("Configuring agents", agents=agents)
+       log.debug("Configuring agents")

        for agent in agents:
            name = agent["name"]

            # special case for memory agent
-           if name == "memory":
+           if name == "memory" or name == "tts":
                self.agents[name] = {
                    "name": name,
                }
+               agent_instance = instance.get_agent(name, **self.agents[name])
+               if agent_instance.has_toggle:
+                   self.agents[name]["enabled"] = agent["enabled"]
+
+               if getattr(agent_instance, "actions", None):
+                   self.agents[name]["actions"] = agent.get("actions", {})
+
+               agent_instance.apply_config(**self.agents[name])
+               log.debug("Configured agent", name=name)
                continue

            if name not in self.agents:
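Likewise, a hypothetical agents payload using only the keys this handler reads (name, enabled, actions); the agent names are examples:

agents = [
    {"name": "memory", "enabled": True, "actions": {}},
    {"name": "tts", "enabled": False, "actions": {}},
]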
@@ -385,7 +403,7 @@ class WebsocketHandler(Receiver):
                "status": emission.status,
                "data": emission.data,
                "max_token_length": client.max_token_length if client else 2048,
-               "apiUrl": getattr(client, "api_url_base", None) if client else None,
+               "apiUrl": getattr(client, "api_url", None) if client else None,
            }
        )

@@ -419,6 +437,14 @@ class WebsocketHandler(Receiver):
            }
        )

+   def handle_audio_queue(self, emission: Emission):
+       self.queue_put(
+           {
+               "type": "audio_queue",
+               "data": emission.data,
+           }
+       )
+
    def handle_request_input(self, emission: Emission):
        self.waiting_for_input = True

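The emission data is forwarded verbatim to the frontend, where the new AudioQueue component (further down in this diff) reads a base64 string from data.audio_data. A sketch of building such a payload, assuming the TTS agent produces raw audio bytes (make_audio_queue_data is a hypothetical helper, not part of this commit):

import base64

def make_audio_queue_data(audio_bytes: bytes) -> dict:
    # base64-encode raw audio so it survives the JSON websocket transport
    return {"audio_data": base64.b64encode(audio_bytes).decode("ascii")}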
@@ -43,6 +43,10 @@ __all__ = [

log = structlog.get_logger("talemate")

+async_signals.register("game_loop_start")
+async_signals.register("game_loop")
+async_signals.register("game_loop_actor_iter")
+async_signals.register("game_loop_new_message")

class Character:
    """
@@ -523,8 +527,6 @@ class Player(Actor):

        return message

-async_signals.register("game_loop_start")
-async_signals.register("game_loop")

class Scene(Emitter):
    """
@@ -575,6 +577,8 @@ class Scene(Emitter):
            "character_state": signal("character_state"),
            "game_loop": async_signals.get("game_loop"),
            "game_loop_start": async_signals.get("game_loop_start"),
+           "game_loop_actor_iter": async_signals.get("game_loop_actor_iter"),
+           "game_loop_new_message": async_signals.get("game_loop_new_message"),
        }

        self.setup_emitter(scene=self)
@@ -701,6 +705,12 @@ class Scene(Emitter):
                messages=messages,
            )
        )

+       loop = asyncio.get_event_loop()
+       for message in messages:
+           loop.run_until_complete(self.signals["game_loop_new_message"].send(
+               events.GameLoopNewMessageEvent(scene=self, event_type="game_loop_new_message", message=message)
+           ))

    def push_archive(self, entry: data_objects.ArchiveEntry):
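async_signals is talemate-internal; only register, get and send are visible in this diff, so the module path and the connect call below are assumptions, sketched after the usual signal connect/send convention:

# hypothetical listener for the new game_loop_new_message signal;
# the import path and .connect() API are assumed, not shown in this diff
import talemate.emit.async_signals as async_signals

async def on_new_message(event):
    print("new message:", event.message)  # event is a GameLoopNewMessageEvent

async_signals.get("game_loop_new_message").connect(on_new_message)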
@@ -1066,7 +1076,9 @@ class Scene(Emitter):
            new_message = await narrator.agent.narrate_character(character)
        elif source == "narrate_query":
            new_message = await narrator.agent.narrate_query(arg)
+       elif source == "narrate_dialogue":
+           character = self.get_character(arg)
+           new_message = await narrator.agent.narrate_after_dialogue(character)
        else:
            fn = getattr(narrator.agent, source, None)
            if not fn:
@@ -1172,7 +1184,7 @@ class Scene(Emitter):
            },
        )

-       self.log.debug("scene_status", scene=self.name, scene_time=self.ts, saved=self.saved)
+       self.log.debug("scene_status", scene=self.name, scene_time=self.ts, human_ts=util.iso8601_duration_to_human(self.ts, suffix=""), saved=self.saved)

    def set_environment(self, environment: str):
        """
@@ -1185,6 +1197,7 @@ class Scene(Emitter):
        """
        Accepts an ISO 8601 duration string and advances the scene's world state by that amount
        """
+       log.debug("advance_time", ts=ts, scene_ts=self.ts, duration=isodate.parse_duration(ts), scene_duration=isodate.parse_duration(self.ts))

        self.ts = isodate.duration_isoformat(
            isodate.parse_duration(self.ts) + isodate.parse_duration(ts)
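The arithmetic above works because isodate parses plain day/time durations into datetime.timedelta objects, which add naturally and serialize back to ISO 8601. A small sketch:

import isodate

scene_ts = isodate.parse_duration("P1DT2H")   # 1 day, 2 hours
step = isodate.parse_duration("PT30M")        # 30 minutes
print(isodate.duration_isoformat(scene_ts + step))  # "P1DT2H30M"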
@@ -1207,9 +1220,12 @@ class Scene(Emitter):
            if self.archived_history[i].get("ts"):
                self.ts = self.archived_history[i]["ts"]
                break

+           end = self.archived_history[-1].get("end", 0)
+       else:
+           end = 0

-       for message in self.history:
+       for message in self.history[end:]:
            if isinstance(message, TimePassageMessage):
                self.advance_time(message.ts)
@@ -1339,6 +1355,10 @@ class Scene(Emitter):
                    if await command.execute(message):
                        break
                    await self.call_automated_actions()
+
+                   await self.signals["game_loop_actor_iter"].send(
+                       events.GameLoopActorIterEvent(scene=self, event_type="game_loop_actor_iter", actor=actor)
+                   )
                    continue

                self.saved = False
@@ -1350,6 +1370,10 @@ class Scene(Emitter):
                    emit(
                        "character", item, character=actor.character
                    )

+               await self.signals["game_loop_actor_iter"].send(
+                   events.GameLoopActorIterEvent(scene=self, event_type="game_loop_actor_iter", actor=actor)
+               )

            self.emit_status()

@@ -303,6 +303,9 @@ def strip_partial_sentences(text:str) -> str:
    # Sentence ending characters
    sentence_endings = ['.', '!', '?', '"', "*"]

+   if not text:
+       return text
+
    # Check if the last character is already a sentence ending
    if text[-1] in sentence_endings:
        return text
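With the new guard, empty input returns immediately instead of raising an IndexError on text[-1]. Presumed behavior, based on the function's name and the endings list above (the trimming logic itself is not shown in this hunk):

# sketch, assuming strip_partial_sentences from the hunk above
strip_partial_sentences("")                    # "" — no IndexError
strip_partial_sentences("She nods. He turns")  # expected: "She nods." (trailing fragment dropped)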
@@ -487,30 +490,43 @@ def clean_attribute(attribute: str) -> str:


def duration_to_timedelta(duration):
-   """Convert an isodate.Duration object to a datetime.timedelta object."""
+   """Convert an isodate.Duration object or a datetime.timedelta object to a datetime.timedelta object."""
+   # Check if the duration is already a timedelta object
+   if isinstance(duration, datetime.timedelta):
+       return duration
+
+   # Check if the duration is an isodate.Duration object with a tdelta attribute
+   if hasattr(duration, 'tdelta'):
+       return duration.tdelta
+
+   # If it's an isodate.Duration object with separate year, month, day, hour, minute, second attributes
    days = int(duration.years) * 365 + int(duration.months) * 30 + int(duration.days)
-   return datetime.timedelta(days=days)
+   seconds = int(duration.hours) * 3600 + int(duration.minutes) * 60 + int(duration.seconds)
+   return datetime.timedelta(days=days, seconds=seconds)

def timedelta_to_duration(delta):
    """Convert a datetime.timedelta object to an isodate.Duration object."""
    # Extract days and convert to years, months, and days
    days = delta.days
    years = days // 365
    days %= 365
    months = days // 30
    days %= 30
-   return isodate.duration.Duration(years=years, months=months, days=days)
+   # Convert remaining seconds to hours, minutes, and seconds
+   seconds = delta.seconds
+   hours = seconds // 3600
+   seconds %= 3600
+   minutes = seconds // 60
+   seconds %= 60
+   return isodate.Duration(years=years, months=months, days=days, hours=hours, minutes=minutes, seconds=seconds)

def parse_duration_to_isodate_duration(duration_str):
    """Parse ISO 8601 duration string and ensure the result is an isodate.Duration."""
    parsed_duration = isodate.parse_duration(duration_str)
    if isinstance(parsed_duration, datetime.timedelta):
-       days = parsed_duration.days
-       years = days // 365
-       days %= 365
-       months = days // 30
-       days %= 30
-       return isodate.duration.Duration(years=years, months=months, days=days)
+       return timedelta_to_duration(parsed_duration)
    return parsed_duration

def iso8601_diff(duration_str1, duration_str2):
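A round-trip sketch for the helpers above. Note the deliberate approximation baked into both directions (a year is treated as 365 days, a month as 30 days), and that the conversion back is lossy: per the hasattr check above, duration_to_timedelta returns dur.tdelta when it exists, so the year/month part of an isodate.Duration is not folded back into the timedelta.

import datetime

delta = datetime.timedelta(days=428, hours=4)
dur = timedelta_to_duration(delta)   # roughly: 1 year, 2 months, 3 days, 4 hours
back = duration_to_timedelta(dur)    # returns dur.tdelta (3 days, 4:00:00), not the original 428 days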
@@ -530,40 +546,50 @@ def iso8601_diff(duration_str1, duration_str2):

    return difference

-def iso8601_duration_to_human(iso_duration, suffix:str=" ago"):
-   # Parse the ISO8601 duration string into an isodate duration object
-   if isinstance(iso_duration, isodate.Duration):
-       duration = iso_duration
-   else:
-       duration = isodate.parse_duration(iso_duration)
+def iso8601_duration_to_human(iso_duration, suffix: str = " ago"):
+
+   # Parse the ISO8601 duration string into an isodate duration object
+   if not isinstance(iso_duration, isodate.Duration):
+       duration = isodate.parse_duration(iso_duration)
+   else:
+       duration = iso_duration

    # Extract years, months, days, and the time part as seconds
+   years, months, days, hours, minutes, seconds = 0, 0, 0, 0, 0, 0

    if isinstance(duration, isodate.Duration):
        years = duration.years
        months = duration.months
        days = duration.days
-       hours = duration.tdelta.seconds // 3600
-       minutes = (duration.tdelta.seconds % 3600) // 60
-       seconds = duration.tdelta.seconds % 60
-   else:
-       years, months = 0, 0
-       hours = duration.seconds // 3600
-       minutes = (duration.seconds % 3600) // 60
-       seconds = duration.seconds % 60
+       seconds = duration.tdelta.total_seconds()
+   elif isinstance(duration, datetime.timedelta):
+       days = duration.days
+       seconds = duration.total_seconds() - days * 86400  # Extract time-only part
+
+   hours, seconds = divmod(seconds, 3600)
+   minutes, seconds = divmod(seconds, 60)

-   # Adjust for cases where duration is a timedelta object
+   # Convert days to weeks and days if applicable
    weeks, days = divmod(days, 7)

    # Build the human-readable components
    components = []
    if years:
        components.append(f"{years} Year{'s' if years > 1 else ''}")
    if months:
        components.append(f"{months} Month{'s' if months > 1 else ''}")
    if weeks:
        components.append(f"{weeks} Week{'s' if weeks > 1 else ''}")
    if days:
        components.append(f"{days} Day{'s' if days > 1 else ''}")
    if hours:
-       components.append(f"{int(hours)} Hour{'s' if hours > 1 else ''}")
+       components.append(f"{hours} Hour{'s' if hours > 1 else ''}")
    if minutes:
-       components.append(f"{int(minutes)} Minute{'s' if minutes > 1 else ''}")
+       components.append(f"{minutes} Minute{'s' if minutes > 1 else ''}")
    if seconds:
-       components.append(f"{int(seconds)} Second{'s' if seconds > 1 else ''}")
+       components.append(f"{seconds} Second{'s' if seconds > 1 else ''}")

    # Construct the human-readable string
    if len(components) > 1:
@@ -573,7 +599,7 @@ def iso8601_duration_to_human(iso_duration, suffix:str=" ago"):
        human_str = components[0]
    else:
        human_str = "Moments"


    return f"{human_str}{suffix}"
def iso8601_diff_to_human(start, end):
|
||||
@@ -581,6 +607,7 @@ def iso8601_diff_to_human(start, end):
|
||||
return ""
|
||||
|
||||
diff = iso8601_diff(start, end)
|
||||
|
||||
return iso8601_duration_to_human(diff)
|
||||
|
||||
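Expected flavor of the output, based on the component-building code above (the exact joining of multiple components happens in the len(components) > 1 branch shown earlier):

# sketch, assuming the functions from the hunks above
iso8601_duration_to_human("P1DT2H")   # something like "1 Day and 2 Hours ago"
iso8601_diff_to_human("P1D", "P3D")   # something like "2 Days ago"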
|
||||
@@ -779,7 +806,11 @@ def ensure_dialog_format(line:str, talking_character:str=None) -> str:
|
||||
lines = []
|
||||
|
||||
for _line in line.split("\n"):
|
||||
_line = ensure_dialog_line_format(_line)
|
||||
try:
|
||||
_line = ensure_dialog_line_format(_line)
|
||||
except Exception as exc:
|
||||
log.error("ensure_dialog_format", msg="Error ensuring dialog line format", line=_line, exc_info=exc)
|
||||
pass
|
||||
|
||||
lines.append(_line)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
from pydantic import BaseModel
from talemate.emit import emit
import structlog
+import traceback
from typing import Union

import talemate.instance as instance
@@ -59,7 +60,8 @@ class WorldState(BaseModel):
            world_state = await self.agent.request_world_state()
        except Exception as e:
            self.emit()
-           raise e
+           log.error("world_state.request_update", error=e, traceback=traceback.format_exc())
+           return

        previous_characters = self.characters
        previous_items = self.items
@@ -7,11 +7,12 @@
                size="14"></v-progress-circular>
            <v-icon v-else-if="agent.status === 'uninitialized'" color="orange" size="14">mdi-checkbox-blank-circle</v-icon>
+           <v-icon v-else-if="agent.status === 'disabled'" color="grey-darken-2" size="14">mdi-checkbox-blank-circle</v-icon>
            <v-icon v-else-if="agent.status === 'error'" color="red" size="14">mdi-checkbox-blank-circle</v-icon>
            <v-icon v-else color="green" size="14">mdi-checkbox-blank-circle</v-icon>
            <span class="ml-1" v-if="agent.label"> {{ agent.label }}</span>
            <span class="ml-1" v-else> {{ agent.name }}</span>
        </v-list-item-title>
-       <v-list-item-subtitle>
+       <v-list-item-subtitle class="text-caption">
            {{ agent.client }}
        </v-list-item-subtitle>
        <v-chip class="mr-1" v-if="agent.status === 'disabled'" size="x-small">Disabled</v-chip>
@@ -65,7 +66,10 @@ export default {
        for(let i = 0; i < this.state.agents.length; i++) {
            let agent = this.state.agents[i];

-           if(agent.status === 'warning' || agent.status === 'error') {
+           if(!agent.data.requires_llm_client)
+               continue
+
+           if(agent.status === 'warning' || agent.status === 'error' || agent.status === 'uninitialized') {
                console.log("agents: configuration required (1)", agent.status)
                return true;
            }
@@ -99,7 +103,6 @@ export default {
            } else {
                this.state.agents[index] = agent;
            }
-           this.state.dialog = false;
            this.$emit('agents-updated', this.state.agents);
        },
        editAgent(index) {

@@ -120,7 +120,7 @@ export default {
            this.state.currentClient = {
                name: 'TextGenWebUI',
                type: 'textgenwebui',
-               apiUrl: 'http://localhost:5000/api',
+               apiUrl: 'http://localhost:5000',
                model_name: '',
                max_token_length: 4096,
            };

@@ -10,7 +10,7 @@
        </v-col>
        <v-col cols="3" class="text-right">
            <v-checkbox :label="enabledLabel()" hide-details density="compact" color="green" v-model="agent.enabled"
-               v-if="agent.data.has_toggle"></v-checkbox>
+               v-if="agent.data.has_toggle" @update:modelValue="save(false)"></v-checkbox>
        </v-col>
    </v-row>

@@ -18,7 +18,7 @@

    </v-card-title>
    <v-card-text class="scrollable-content">
-       <v-select v-model="agent.client" :items="agent.data.client" label="Client"></v-select>
+       <v-select v-if="agent.data.requires_llm_client" v-model="agent.client" :items="agent.data.client" label="Client" @update:modelValue="save(false)"></v-select>

        <v-alert type="warning" variant="tonal" density="compact" v-if="agent.data.experimental">
            This agent is currently experimental and may significantly decrease performance and / or require
@@ -27,27 +27,25 @@

        <v-card v-for="(action, key) in agent.actions" :key="key" density="compact">
            <v-card-subtitle>
-               <v-checkbox :label="agent.data.actions[key].label" hide-details density="compact" color="green" v-model="action.enabled"></v-checkbox>
+               <v-checkbox v-if="!actionAlwaysEnabled(key)" :label="agent.data.actions[key].label" hide-details density="compact" color="green" v-model="action.enabled" @update:modelValue="save(false)"></v-checkbox>
            </v-card-subtitle>
            <v-card-text>
-               {{ agent.data.actions[key].description }}
+               <div v-if="!actionAlwaysEnabled(key)">
+                   {{ agent.data.actions[key].description }}
+               </div>
                <div v-for="(action_config, config_key) in agent.data.actions[key].config" :key="config_key">
                    <div v-if="action.enabled">
                        <!-- render config widgets based on action_config.type (int, str, bool, float) -->
-                       <v-text-field v-if="action_config.type === 'text'" v-model="action.config[config_key].value" :label="action_config.label" :hint="action_config.description" density="compact"></v-text-field>
-                       <v-slider v-if="action_config.type === 'number' && action_config.step !== null" v-model="action.config[config_key].value" :label="action_config.label" :hint="action_config.description" :min="action_config.min" :max="action_config.max" :step="action_config.step" density="compact" thumb-label></v-slider>
-                       <v-checkbox v-if="action_config.type === 'bool'" v-model="action.config[config_key].value" :label="action_config.label" :hint="action_config.description" density="compact"></v-checkbox>
+                       <v-text-field v-if="action_config.type === 'text' && action_config.choices === null" v-model="action.config[config_key].value" :label="action_config.label" :hint="action_config.description" density="compact" @update:modelValue="save(true)"></v-text-field>
+                       <v-autocomplete v-else-if="action_config.type === 'text' && action_config.choices !== null" v-model="action.config[config_key].value" :items="action_config.choices" :label="action_config.label" :hint="action_config.description" density="compact" item-title="label" item-value="value" @update:modelValue="save(false)"></v-autocomplete>
+                       <v-slider v-if="action_config.type === 'number' && action_config.step !== null" v-model="action.config[config_key].value" :label="action_config.label" :hint="action_config.description" :min="action_config.min" :max="action_config.max" :step="action_config.step" density="compact" thumb-label @update:modelValue="save(true)"></v-slider>
+                       <v-checkbox v-if="action_config.type === 'bool'" v-model="action.config[config_key].value" :label="action_config.label" :hint="action_config.description" density="compact" @update:modelValue="save(false)"></v-checkbox>
                    </div>
                </div>
            </v-card-text>
        </v-card>

    </v-card-text>
    <v-card-actions>
        <v-spacer></v-spacer>
        <v-btn color="primary" @click="close">Close</v-btn>
-       <v-btn color="primary" @click="save">Save</v-btn>
    </v-card-actions>
</v-card>
</v-dialog>
</template>
@@ -58,9 +56,10 @@ export default {
        dialog: Boolean,
        formTitle: String
    },
-   inject: ['state'],
+   inject: ['state', 'getWebsocket'],
    data() {
        return {
+           saveTimeout: null,
            localDialog: this.state.dialog,
            agent: { ...this.state.currentAgent }
        };
@@ -90,12 +89,32 @@ export default {
                return 'Disabled';
            }
        },
+       actionAlwaysEnabled(action) {
+           if (action.charAt(0) === '_') {
+               return true;
+           } else {
+               return false;
+           }
+       },

        close() {
            this.$emit('update:dialog', false);
        },
-       save() {
-           this.$emit('save', this.agent);
-           this.close();
+       save(delayed = false) {
+           console.log("save", delayed);
+           if(!delayed) {
+               this.$emit('save', this.agent);
+               return;
+           }
+
+           if(this.saveTimeout !== null)
+               clearTimeout(this.saveTimeout);
+
+           this.saveTimeout = setTimeout(() => {
+               this.$emit('save', this.agent);
+           }, 500);
+
+           //this.$emit('save', this.agent);
        }
    }
}
96  talemate_frontend/src/components/AudioQueue.vue  Normal file
@@ -0,0 +1,96 @@
<template>
    <div class="audio-queue">
        <span>{{ queue.length }} sound(s) queued</span>
        <v-icon :color="isPlaying ? 'green' : ''" v-if="!isMuted" @click="toggleMute">mdi-volume-high</v-icon>
        <v-icon :color="isPlaying ? 'red' : ''" v-else @click="toggleMute">mdi-volume-off</v-icon>
        <v-icon class="ml-1" @click="stopAndClear">mdi-stop-circle-outline</v-icon>
    </div>
</template>

<script>
export default {
    name: 'AudioQueue',
    data() {
        return {
            queue: [],
            audioContext: null,
            isPlaying: false,
            isMuted: false,
            currentSource: null
        };
    },
    inject: ['getWebsocket', 'registerMessageHandler'],
    created() {
        this.audioContext = new (window.AudioContext || window.webkitAudioContext)();
        this.registerMessageHandler(this.handleMessage);
    },
    methods: {
        handleMessage(data) {
            if (data.type === 'audio_queue') {
                console.log('Received audio queue message', data)
                this.addToQueue(data.data.audio_data);
            }
        },
        addToQueue(base64Sound) {
            const soundBuffer = this.base64ToArrayBuffer(base64Sound);
            this.queue.push(soundBuffer);
            this.playNextSound();
        },
        base64ToArrayBuffer(base64) {
            const binaryString = window.atob(base64);
            const len = binaryString.length;
            const bytes = new Uint8Array(len);
            for (let i = 0; i < len; i++) {
                bytes[i] = binaryString.charCodeAt(i);
            }
            return bytes.buffer;
        },
        playNextSound() {
            if (this.isPlaying || this.queue.length === 0) {
                return;
            }
            this.isPlaying = true;
            const soundBuffer = this.queue.shift();
            this.audioContext.decodeAudioData(soundBuffer, (buffer) => {
                const source = this.audioContext.createBufferSource();
                source.buffer = buffer;
                this.currentSource = source;
                if (!this.isMuted) {
                    source.connect(this.audioContext.destination);
                }
                source.onended = () => {
                    this.isPlaying = false;
                    this.playNextSound();
                };
                source.start(0);
            }, (error) => {
                console.error('Error with decoding audio data', error);
            });
        },
        toggleMute() {
            this.isMuted = !this.isMuted;
            if (this.isMuted && this.currentSource) {
                this.currentSource.disconnect(this.audioContext.destination);
            } else if (this.currentSource) {
                this.currentSource.connect(this.audioContext.destination);
            }
        },
        stopAndClear() {
            if (this.currentSource) {
                this.currentSource.stop();
                this.currentSource.disconnect();
                this.currentSource = null;
            }
            this.queue = [];
            this.isPlaying = false;
        }
    }
};
</script>

<style scoped>
.audio-queue {
    display: flex;
    align-items: center;
}
</style>
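The component consumes exactly the message shape emitted by handle_audio_queue on the server side earlier in this diff, i.e. (shown as a Python literal for reference):

{"type": "audio_queue", "data": {"audio_data": "<base64-encoded audio>"}}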
@@ -8,7 +8,7 @@
    <v-container>
        <v-row>
            <v-col cols="6">
-               <v-select v-model="client.type" :items="['openai', 'textgenwebui']" label="Client Type"></v-select>
+               <v-select v-model="client.type" :disabled="!typeEditable()" :items="['openai', 'textgenwebui', 'lmstudio']" label="Client Type"></v-select>
            </v-col>
            <v-col cols="6">
                <v-text-field v-model="client.name" label="Client Name"></v-text-field>
@@ -17,13 +17,13 @@
        </v-row>
        <v-row>
            <v-col cols="12">
-               <v-text-field v-model="client.apiUrl" v-if="client.type === 'textgenwebui'" label="API URL"></v-text-field>
+               <v-text-field v-model="client.apiUrl" v-if="isLocalApiClient(client)" label="API URL"></v-text-field>
                <v-select v-model="client.model" v-if="client.type === 'openai'" :items="['gpt-4-1106-preview', 'gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k']" label="Model"></v-select>
            </v-col>
        </v-row>
        <v-row>
            <v-col cols="6">
-               <v-text-field v-model="client.max_token_length" v-if="client.type === 'textgenwebui'" type="number" label="Context Length"></v-text-field>
+               <v-text-field v-model="client.max_token_length" v-if="isLocalApiClient(client)" type="number" label="Context Length"></v-text-field>
            </v-col>
        </v-row>
    </v-container>
@@ -68,12 +68,18 @@ export default {
        }
    },
    methods: {
+       typeEditable() {
+           return this.state.formTitle === 'Add Client';
+       },
        close() {
            this.$emit('update:dialog', false);
        },
        save() {
            this.$emit('save', this.client); // Emit save event with client object
            this.close();
        },
+       isLocalApiClient(client) {
+           return client.type === 'textgenwebui' || client.type === 'lmstudio';
+       }
    }
}
@@ -80,6 +80,12 @@ export default {
            this.getWebsocket().send(JSON.stringify({ type: 'request_scenes_list', query: this.sceneSearchInput }));
        },
        loadCreative() {
+           if(this.sceneSaved === false) {
+               if(!confirm("The current scene is not saved. Are you sure you want to load a new scene?")) {
+                   return;
+               }
+           }
+
            this.loading = true;
            this.getWebsocket().send(JSON.stringify({ type: 'load_scene', file_path: "environment:creative" }));
        },
@@ -75,6 +75,8 @@
        <span v-if="connecting" class="ml-1"><v-icon class="mr-1">mdi-progress-helper</v-icon>connecting</span>
        <span v-else-if="connected" class="ml-1"><v-icon class="mr-1" color="green" size="14">mdi-checkbox-blank-circle</v-icon>connected</span>
        <span v-else class="ml-1"><v-icon class="mr-1">mdi-progress-close</v-icon>disconnected</span>
+       <v-divider class="ml-1 mr-1" vertical></v-divider>
+       <AudioQueue ref="audioQueue" />
        <v-spacer></v-spacer>
        <span v-if="version !== null">v{{ version }}</span>
        <span v-if="configurationRequired()">
@@ -161,6 +163,7 @@ import SceneHistory from './SceneHistory.vue';
import CreativeEditor from './CreativeEditor.vue';
import AppConfig from './AppConfig.vue';
import DebugTools from './DebugTools.vue';
+import AudioQueue from './AudioQueue.vue';

export default {
    components: {
@@ -177,6 +180,7 @@ export default {
        CreativeEditor,
        AppConfig,
        DebugTools,
+       AudioQueue,
    },
    name: 'TalemateApp',
    data() {
||||
4
templates/llm-prompt/Cat.jinja2
Normal file
4
templates/llm-prompt/Cat.jinja2
Normal file
@@ -0,0 +1,4 @@
|
||||
{{ system_message }}
|
||||
|
||||
### Instruction:
|
||||
{{ set_response(prompt, "\n\n### Response:\n") }}
|
||||
3
templates/llm-prompt/Nous-Capybara.jinja2
Normal file
3
templates/llm-prompt/Nous-Capybara.jinja2
Normal file
@@ -0,0 +1,3 @@
|
||||
USER:
|
||||
{{ system_message }}
|
||||
{{ set_response(prompt, "\nASSISTANT:") }}
|
||||
1
templates/llm-prompt/OrionStar.jinja2
Normal file
1
templates/llm-prompt/OrionStar.jinja2
Normal file
@@ -0,0 +1 @@
|
||||
Human: {{ system_message }} {{ set_response(prompt, "\n\nAssistant:") }}
|
||||
4
templates/llm-prompt/Psyfighter2.jinja2
Normal file
4
templates/llm-prompt/Psyfighter2.jinja2
Normal file
@@ -0,0 +1,4 @@
|
||||
{{ system_message }}
|
||||
|
||||
### Instruction:
|
||||
{{ set_response(prompt, "\n\n### Response:\n") }}
|
||||
2
templates/llm-prompt/Tess-Medium.jinja2
Normal file
2
templates/llm-prompt/Tess-Medium.jinja2
Normal file
@@ -0,0 +1,2 @@
|
||||
SYSTEM: {{ system_message }}
|
||||
USER: {{ set_response(prompt, "\nASSISTANT: ") }}
|
||||
4
templates/llm-prompt/dolphin-2.2.1-mistral.jinja2
Normal file
4
templates/llm-prompt/dolphin-2.2.1-mistral.jinja2
Normal file
@@ -0,0 +1,4 @@
|
||||
<|im_start|>system
|
||||
{{ system_message }}<|im_end|>
|
||||
<|im_start|>user
|
||||
{{ set_response(prompt, "<|im_end|>\n<|im_start|>assistant\n") }}
|
||||
4
templates/llm-prompt/dolphin-2_2-yi.jinja2
Normal file
4
templates/llm-prompt/dolphin-2_2-yi.jinja2
Normal file
@@ -0,0 +1,4 @@
|
||||
<|im_start|>system
|
||||
{{ system_message }}<|im_end|>
|
||||
<|im_start|>user
|
||||
{{ set_response(prompt, "<|im_end|>\n<|im_start|>assistant\n") }}
|
||||
4
templates/llm-prompt/orca-2.jinja2
Normal file
4
templates/llm-prompt/orca-2.jinja2
Normal file
@@ -0,0 +1,4 @@
|
||||
<|im_start|>system
|
||||
{{ system_message }}<|im_end|>
|
||||
<|im_start|>user
|
||||
{{ set_response(prompt, "<|im_end|>\n<|im_start|>assistant\n") }}
|
||||
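These new prompt templates fall into a few visible families: Alpaca-style (### Instruction / ### Response — Cat and Psyfighter2), ChatML (<|im_start|> blocks — the dolphin and orca-2 templates), and simple USER/ASSISTANT turn formats. A minimal rendering sketch, assuming set_response simply appends the response prefix to the prompt (its real implementation is talemate-internal and not shown in this diff):

from jinja2 import Template

def set_response(prompt: str, prefix: str) -> str:
    # hypothetical stand-in for talemate's set_response helper
    return prompt + prefix

with open("templates/llm-prompt/Cat.jinja2") as f:
    tpl = Template(f.read())

print(tpl.render(
    system_message="You are the narrator.",
    prompt="Describe the tavern.",
    set_response=set_response,
))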
@@ -6,5 +6,5 @@ call talemate_env\Scripts\activate
REM use poetry to install dependencies
python -m poetry install

-echo Virtual environment re-created.
+echo Virtual environment updated
pause