Compare commits

...

8 Commits

Author SHA1 Message Date
veguAI
4ba635497b Prep 0.18.1 (#72)
* prevent client ctx from being unset

* fix issue with LMStudio client ctx size not sticking

* 0.18.1
2024-01-31 09:46:51 +02:00
veguAI
bdbf14c1ed Update README.md 2024-01-31 01:47:52 +02:00
veguAI
c340fc085c Update README.md 2024-01-31 01:47:29 +02:00
veguAI
94f8d0f242 Update README.md 2024-01-31 01:00:59 +02:00
veguAI
1d8a9b113c Update README.md 2024-01-30 08:08:45 +02:00
vegu-ai-tools
1837796852 readme 2024-01-26 14:41:59 +02:00
vegu-ai-tools
c5c53c056e readme updates 2024-01-26 13:29:21 +02:00
veguAI
f1b1190f0b linting (#63) 2024-01-26 12:46:55 +02:00
93 changed files with 6971 additions and 5091 deletions

View File

@@ -7,13 +7,16 @@ Allows you to play roleplay scenarios with large language models.
|------------------------------------------|------------------------------------------|
|![Screenshot 1](docs/img/0.17.0/ss-4.png)|![Screenshot 2](docs/img/0.17.0/ss-3.png)|
> :warning: **It does not run any large language models itself but relies on existing APIs. Currently supports OpenAI, text-generation-webui and LMStudio.**
> :warning: **It does not run any large language models itself but relies on existing APIs. Currently supports OpenAI, text-generation-webui and LMStudio. 0.18.0 also adds support for generic OpenAI api implementations, but generation quality on that will vary.**
This means you need to either have:
- an [OpenAI](https://platform.openai.com/overview) api key
- OR setup local (or remote via runpod) LLM inference via one of these options:
- setup local (or remote via runpod) LLM inference via:
- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui)
- [LMStudio](https://lmstudio.ai/)
- Any other OpenAI api implementation that implements the v1/completions endpoint
- tested llamacpp with the `api_like_OAI.py` wrapper
- let me know if you have tested any other implementations and they failed / worked or landed somewhere in between
## Current features
@@ -35,6 +38,7 @@ This means you need to either have:
- Automatically keep track and reinforce selected character and world truths / states.
- narrative tools
- creative tools
- manage multiple NPCs
- AI backed character creation with template support (jinja2)
- AI backed scenario creation
- context managegement
@@ -77,7 +81,7 @@ There is also a [troubleshooting guide](docs/troubleshoot.md) that might help.
### Windows
1. Download and install Python 3.10 or Python 3.11 from the [official Python website](https://www.python.org/downloads/windows/). :warning: python3.12 is currently not supported.
1. Download and install Node.js from the [official Node.js website](https://nodejs.org/en/download/). This will also install npm.
1. Download and install Node.js v20 from the [official Node.js website](https://nodejs.org/en/download/). This will also install npm. :warning: v21 is currently not supported.
1. Download the Talemate project to your local machine. Download from [the Releases page](https://github.com/vegu-ai/talemate/releases).
1. Unpack the download and run `install.bat` by double clicking it. This will set up the project on your local machine.
1. Once the installation is complete, you can start the backend and frontend servers by running `start.bat`.
@@ -87,45 +91,14 @@ There is also a [troubleshooting guide](docs/troubleshoot.md) that might help.
`python 3.10` or `python 3.11` is required. :warning: `python 3.12` not supported yet.
`nodejs v19 or v20` :warning: `v21` not supported yet.
1. `git clone git@github.com:vegu-ai/talemate`
1. `cd talemate`
1. `source install.sh`
1. Start the backend: `python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050`.
1. Open a new terminal, navigate to the `talemate_frontend` directory, and start the frontend server by running `npm run serve`.
## Configuration
### OpenAI
To set your openai api key, open `config.yaml` in any text editor and uncomment / add
```yaml
openai:
api_key: sk-my-api-key-goes-here
```
You will need to restart the backend for this change to take effect.
### RunPod
To set your runpod api key, open `config.yaml` in any text editor and uncomment / add
```yaml
runpod:
api_key: my-api-key-goes-here
```
You will need to restart the backend for this change to take effect.
Once the api key is set Pods loaded from text-generation-webui templates (or the bloke's runpod llm template) will be autoamtically added to your client list in talemate.
**ATTENTION**: Talemate is not a suitable for way for you to determine whether your pod is currently running or not. **Always** check the runpod dashboard to see if your pod is running or not.
## Recommended Models
(as of2023.10.25)
Any of the top models in any of the size classes here should work well:
https://www.reddit.com/r/LocalLLaMA/comments/17fhp9k/huge_llm_comparisontest_39_models_tested_7b70b/
## Connecting to an LLM
On the right hand side click the "Add Client" button. If there is no button, you may need to toggle the client options by clicking this button:
@@ -140,13 +113,33 @@ In the modal if you're planning to connect to text-generation-webui, you can lik
![Add client modal](docs/img/client-setup-0.13.png)
#### Recommended Models
Any of the top models in any of the size classes here should work well (i wouldn't recommend going lower than 7B):
https://www.reddit.com/r/LocalLLaMA/comments/18yp9u4/llm_comparisontest_api_edition_gpt4_vs_gemini_vs/
### OpenAI
If you want to add an OpenAI client, just change the client type and select the apropriate model.
![Add client modal](docs/img/add-client-modal-openai.png)
### Ready to go
If you are setting this up for the first time, you should now see the client, but it will have a red dot next to it, stating that it requires an API key.
![OpenAI API Key missing](docs/img/0.18.0/openai-api-key-1.png)
Click the `SET API KEY` button. This will open a modal where you can enter your API key.
![OpenAI API Key missing](docs/img/0.18.0/openai-api-key-2.png)
Click `Save` and after a moment the client should have a green dot next to it, indicating that it is ready to go.
![OpenAI API Key set](docs/img/0.18.0/openai-api-key-3.png)
## Ready to go
You will know you are good to go when the client and all the agents have a green dot next to them.

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.7 KiB

View File

@@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"
[tool.poetry]
name = "talemate"
version = "0.18.0"
version = "0.18.1"
description = "AI-backed roleplay and narrative tools"
authors = ["FinalWombat"]
license = "GNU Affero General Public License v3.0"

View File

@@ -2,4 +2,4 @@ from .agents import Agent
from .client import TextGeneratorWebuiClient
from .tale_mate import *
VERSION = "0.18.0"
VERSION = "0.18.1"

View File

@@ -1,11 +1,11 @@
from .base import Agent
from .creator import CreatorAgent
from .conversation import ConversationAgent
from .creator import CreatorAgent
from .director import DirectorAgent
from .editor import EditorAgent
from .memory import ChromaDBMemoryAgent, MemoryAgent
from .narrator import NarratorAgent
from .registry import AGENT_CLASSES, get_agent_class, register
from .summarize import SummarizeAgent
from .editor import EditorAgent
from .tts import TTSAgent
from .world_state import WorldStateAgent
from .tts import TTSAgent

View File

@@ -1,21 +1,21 @@
from __future__ import annotations
import asyncio
import dataclasses
import re
from abc import ABC
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import pydantic
import structlog
from blinker import signal
import talemate.emit.async_signals
import talemate.instance as instance
import talemate.util as util
from talemate.agents.context import ActiveAgent
from talemate.emit import emit
from talemate.events import GameLoopStartEvent
import talemate.emit.async_signals
import dataclasses
import pydantic
import structlog
__all__ = [
"Agent",
@@ -37,26 +37,27 @@ class AgentActionConfig(pydantic.BaseModel):
scope: str = "global"
choices: Union[list[dict[str, str]], None] = None
note: Union[str, None] = None
class Config:
arbitrary_types_allowed = True
class AgentAction(pydantic.BaseModel):
enabled: bool = True
label: str
description: str = ""
config: Union[dict[str, AgentActionConfig], None] = None
def set_processing(fn):
"""
decorator that emits the agent status as processing while the function
is running.
Done via a try - final block to ensure the status is reset even if
the function fails.
"""
async def wrapper(self, *args, **kwargs):
with ActiveAgent(self, fn):
try:
@@ -69,9 +70,9 @@ def set_processing(fn):
# not sure why this happens
# some concurrency error?
log.error("error emitting agent status", exc=exc)
wrapper.__name__ = fn.__name__
return wrapper
@@ -97,16 +98,14 @@ class Agent(ABC):
def verbose_name(self):
return self.agent_type.capitalize()
@property
def ready(self):
if not getattr(self.client, "enabled", True):
return False
if self.client and self.client.current_status in ["error", "warning"]:
return False
return self.client is not None
@property
@@ -123,20 +122,20 @@ class Agent(ABC):
# by default, agents are enabled, an agent class that
# is disableable should override this property
return True
@property
def disable(self):
# by default, agents are enabled, an agent class that
# is disableable should override this property to
# is disableable should override this property to
# disable the agent
pass
@property
def has_toggle(self):
# by default, agents do not have toggles to enable / disable
# an agent class that is disableable should override this property
return False
@property
def experimental(self):
# by default, agents are not experimental, an agent class that
@@ -153,85 +152,92 @@ class Agent(ABC):
"requires_llm_client": cls.requires_llm_client,
}
actions = getattr(agent, "actions", None)
if actions:
config_options["actions"] = {k: v.model_dump() for k, v in actions.items()}
else:
config_options["actions"] = {}
return config_options
def apply_config(self, *args, **kwargs):
if self.has_toggle and "enabled" in kwargs:
self.is_enabled = kwargs.get("enabled", False)
if not getattr(self, "actions", None):
return
for action_key, action in self.actions.items():
if not kwargs.get("actions"):
continue
action.enabled = kwargs.get("actions", {}).get(action_key, {}).get("enabled", False)
action.enabled = (
kwargs.get("actions", {}).get(action_key, {}).get("enabled", False)
)
if not action.config:
continue
for config_key, config in action.config.items():
try:
config.value = kwargs.get("actions", {}).get(action_key, {}).get("config", {}).get(config_key, {}).get("value", config.value)
config.value = (
kwargs.get("actions", {})
.get(action_key, {})
.get("config", {})
.get(config_key, {})
.get("value", config.value)
)
except AttributeError:
pass
async def on_game_loop_start(self, event:GameLoopStartEvent):
async def on_game_loop_start(self, event: GameLoopStartEvent):
"""
Finds all ActionConfigs that have a scope of "scene" and resets them to their default values
"""
if not getattr(self, "actions", None):
return
for _, action in self.actions.items():
if not action.config:
continue
for _, config in action.config.items():
if config.scope == "scene":
# if default_value is None, just use the `type` of the current
# if default_value is None, just use the `type` of the current
# value
if config.default_value is None:
default_value = type(config.value)()
else:
default_value = config.default_value
log.debug("resetting config", config=config, default_value=default_value)
log.debug(
"resetting config", config=config, default_value=default_value
)
config.value = default_value
await self.emit_status()
async def emit_status(self, processing: bool = None):
# should keep a count of processing requests, and when the
# number is 0 status is "idle", if the number is greater than 0
# status is "busy"
#
# increase / decrease based on value of `processing`
if getattr(self, "processing", None) is None:
self.processing = 0
if not processing:
self.processing -= 1
self.processing = max(0, self.processing)
else:
self.processing += 1
status = "busy" if self.processing > 0 else "idle"
if not self.enabled:
status = "disabled"
emit(
"agent_status",
message=self.verbose_name or "",
@@ -245,8 +251,9 @@ class Agent(ABC):
def connect(self, scene):
self.scene = scene
talemate.emit.async_signals.get("game_loop_start").connect(self.on_game_loop_start)
talemate.emit.async_signals.get("game_loop_start").connect(
self.on_game_loop_start
)
def clean_result(self, result):
if "#" in result:
@@ -291,23 +298,28 @@ class Agent(ABC):
current_memory_context.append(memory)
return current_memory_context
# LLM client related methods. These are called during or after the client
# sends the prompt to the API.
def inject_prompt_paramters(self, prompt_param:dict, kind:str, agent_function_name:str):
def inject_prompt_paramters(
self, prompt_param: dict, kind: str, agent_function_name: str
):
"""
Injects prompt parameters before the client sends off the prompt
Override as needed.
"""
pass
def allow_repetition_break(self, kind:str, agent_function_name:str, auto:bool=False):
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
"""
Returns True if repetition breaking is allowed, False otherwise.
"""
return False
@dataclasses.dataclass
class AgentEmission:
agent: Agent
agent: Agent

View File

@@ -1,3 +1,3 @@
"""
Code has been moved.
"""
"""

View File

@@ -1,6 +1,6 @@
from typing import Callable, TYPE_CHECKING
import contextvars
from typing import TYPE_CHECKING, Callable
import pydantic
__all__ = [
@@ -9,25 +9,26 @@ __all__ = [
active_agent = contextvars.ContextVar("active_agent", default=None)
class ActiveAgentContext(pydantic.BaseModel):
agent: object
fn: Callable
class Config:
arbitrary_types_allowed=True
arbitrary_types_allowed = True
@property
def action(self):
return self.fn.__name__
class ActiveAgent:
def __init__(self, agent, fn):
self.agent = ActiveAgentContext(agent=agent, fn=fn)
def __enter__(self):
self.token = active_agent.set(self.agent)
def __exit__(self, *args, **kwargs):
active_agent.reset(self.token)
return False

View File

@@ -1,40 +1,48 @@
from __future__ import annotations
import dataclasses
import re
import random
import re
from datetime import datetime
from typing import TYPE_CHECKING, Optional, Union
import structlog
import talemate.client as client
import talemate.emit.async_signals
import talemate.instance as instance
import talemate.util as util
import structlog
from talemate.client.context import (
client_context_attribute,
set_client_context_attribute,
set_conversation_context_attribute,
)
from talemate.emit import emit
import talemate.emit.async_signals
from talemate.scene_message import CharacterMessage, DirectorMessage
from talemate.prompts import Prompt
from talemate.events import GameLoopEvent
from talemate.client.context import set_conversation_context_attribute, client_context_attribute, set_client_context_attribute
from talemate.prompts import Prompt
from talemate.scene_message import CharacterMessage, DirectorMessage
from .base import Agent, AgentEmission, set_processing, AgentAction, AgentActionConfig
from .base import Agent, AgentAction, AgentActionConfig, AgentEmission, set_processing
from .registry import register
if TYPE_CHECKING:
from talemate.tale_mate import Character, Scene, Actor
from talemate.tale_mate import Actor, Character, Scene
log = structlog.get_logger("talemate.agents.conversation")
@dataclasses.dataclass
class ConversationAgentEmission(AgentEmission):
actor: Actor
character: Character
generation: list[str]
talemate.emit.async_signals.register(
"agent.conversation.before_generate",
"agent.conversation.generated"
"agent.conversation.before_generate", "agent.conversation.generated"
)
@register()
class ConversationAgent(Agent):
"""
@@ -45,7 +53,7 @@ class ConversationAgent(Agent):
agent_type = "conversation"
verbose_name = "Conversation"
min_dialogue_length = 75
def __init__(
@@ -60,28 +68,28 @@ class ConversationAgent(Agent):
self.logging_enabled = logging_enabled
self.logging_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
self.current_memory_context = None
# several agents extend this class, but we only want to initialize
# these actions for the conversation agent
if self.agent_type != "conversation":
return
self.actions = {
"generation_override": AgentAction(
enabled = True,
label = "Generation Override",
description = "Override generation parameters",
config = {
enabled=True,
label="Generation Override",
description="Override generation parameters",
config={
"length": AgentActionConfig(
type="number",
label="Generation Length (tokens)",
description="Maximum number of tokens to generate for a conversation response.",
value=96,
value=96,
min=32,
max=512,
step=32,
),#
), #
"instructions": AgentActionConfig(
type="text",
label="Instructions",
@@ -96,24 +104,24 @@ class ConversationAgent(Agent):
min=0.0,
max=1.0,
step=0.1,
)
}
),
},
),
"auto_break_repetition": AgentAction(
enabled = True,
label = "Auto Break Repetition",
description = "Will attempt to automatically break AI repetition.",
enabled=True,
label="Auto Break Repetition",
description="Will attempt to automatically break AI repetition.",
),
"natural_flow": AgentAction(
enabled = True,
label = "Natural Flow",
description = "Will attempt to generate a more natural flow of conversation between multiple characters.",
config = {
enabled=True,
label="Natural Flow",
description="Will attempt to generate a more natural flow of conversation between multiple characters.",
config={
"max_auto_turns": AgentActionConfig(
type="number",
label="Max. Auto Turns",
description="The maximum number of turns the AI is allowed to generate before it stops and waits for the player to respond.",
value=4,
value=4,
min=1,
max=100,
step=1,
@@ -122,31 +130,40 @@ class ConversationAgent(Agent):
type="number",
label="Max. Idle Turns",
description="The maximum number of turns a character can go without speaking before they are considered overdue to speak.",
value=8,
value=8,
min=1,
max=100,
step=1,
),
}
},
),
"use_long_term_memory": AgentAction(
enabled = True,
label = "Long Term Memory",
description = "Will augment the conversation prompt with long term memory.",
config = {
enabled=True,
label="Long Term Memory",
description="Will augment the conversation prompt with long term memory.",
config={
"retrieval_method": AgentActionConfig(
type="text",
label="Context Retrieval Method",
description="How relevant context is retrieved from the long term memory.",
value="direct",
choices=[
{"label": "Context queries based on recent dialogue (fast)", "value": "direct"},
{"label": "Context queries generated by AI", "value": "queries"},
{"label": "AI compiled question and answers (slow)", "value": "questions"},
]
{
"label": "Context queries based on recent dialogue (fast)",
"value": "direct",
},
{
"label": "Context queries generated by AI",
"value": "queries",
},
{
"label": "AI compiled question and answers (slow)",
"value": "questions",
},
],
),
}
),
},
),
}
def connect(self, scene):
@@ -154,40 +171,37 @@ class ConversationAgent(Agent):
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
def last_spoken(self):
"""
Returns the last time each character spoke
"""
last_turn = {}
turns = 0
character_names = self.scene.character_names
max_idle_turns = self.actions["natural_flow"].config["max_idle_turns"].value
for idx in range(len(self.scene.history) - 1, -1, -1):
if isinstance(self.scene.history[idx], CharacterMessage):
if turns >= max_idle_turns:
break
character = self.scene.history[idx].character_name
if character in character_names:
last_turn[character] = turns
character_names.remove(character)
if not character_names:
break
turns += 1
if character_names and turns >= max_idle_turns:
for character in character_names:
last_turn[character] = max_idle_turns
last_turn[character] = max_idle_turns
return last_turn
def repeated_speaker(self):
"""
Counts the amount of times the most recent speaker has spoken in a row
@@ -203,125 +217,164 @@ class ConversationAgent(Agent):
else:
break
return count
async def on_game_loop(self, event:GameLoopEvent):
async def on_game_loop(self, event: GameLoopEvent):
await self.apply_natural_flow()
async def apply_natural_flow(self, force: bool = False, npcs_only: bool = False):
"""
If the natural flow action is enabled, this will attempt to determine
the ideal character to talk next.
This will let the AI pick a character to talk to, but if the AI can't figure
it out it will apply rules based on max_idle_turns and max_auto_turns.
If all fails it will just pick a random character.
Repetition is also taken into account, so if a character has spoken twice in a row
they will not be picked again until someone else has spoken.
"""
scene = self.scene
if not scene.auto_progress and not force:
# we only apply natural flow if auto_progress is enabled
return
if self.actions["natural_flow"].enabled and len(scene.character_names) > 2:
# last time each character spoke (turns ago)
max_idle_turns = self.actions["natural_flow"].config["max_idle_turns"].value
max_auto_turns = self.actions["natural_flow"].config["max_auto_turns"].value
last_turn = self.last_spoken()
player_name = scene.get_player_character().name
last_turn_player = last_turn.get(player_name, 0)
if last_turn_player >= max_auto_turns and not npcs_only:
self.scene.next_actor = scene.get_player_character().name
log.debug("conversation_agent.natural_flow", next_actor="player", overdue=True, player_character=scene.get_player_character().name)
log.debug(
"conversation_agent.natural_flow",
next_actor="player",
overdue=True,
player_character=scene.get_player_character().name,
)
return
log.debug("conversation_agent.natural_flow", last_turn=last_turn)
# determine random character to talk, this will be the fallback in case
# the AI can't figure out who should talk next
if scene.prev_actor:
# we dont want to talk to the same person twice in a row
character_names = scene.character_names
character_names.remove(scene.prev_actor)
if npcs_only:
character_names = [c for c in character_names if c != player_name]
random_character_name = random.choice(character_names)
else:
character_names = scene.character_names
character_names = scene.character_names
# no one has talked yet, so we just pick a random character
if npcs_only:
character_names = [c for c in character_names if c != player_name]
random_character_name = random.choice(scene.character_names)
overdue_characters = [character for character, turn in last_turn.items() if turn >= max_idle_turns]
overdue_characters = [
character
for character, turn in last_turn.items()
if turn >= max_idle_turns
]
if npcs_only:
overdue_characters = [c for c in overdue_characters if c != player_name]
if overdue_characters and self.scene.history:
# Pick a random character from the overdue characters
scene.next_actor = random.choice(overdue_characters)
elif scene.history:
scene.next_actor = None
# AI will attempt to figure out who should talk next
next_actor = await self.select_talking_actor(character_names)
next_actor = next_actor.strip().strip('"').strip(".")
for character_name in scene.character_names:
if next_actor.lower() in character_name.lower() or character_name.lower() in next_actor.lower():
if (
next_actor.lower() in character_name.lower()
or character_name.lower() in next_actor.lower()
):
scene.next_actor = character_name
break
if not scene.next_actor:
# AI couldn't figure out who should talk next, so we just pick a random character
log.debug("conversation_agent.natural_flow", next_actor="random", random_character_name=random_character_name)
log.debug(
"conversation_agent.natural_flow",
next_actor="random",
random_character_name=random_character_name,
)
scene.next_actor = random_character_name
else:
log.debug("conversation_agent.natural_flow", next_actor="picked", ai_next_actor=scene.next_actor)
log.debug(
"conversation_agent.natural_flow",
next_actor="picked",
ai_next_actor=scene.next_actor,
)
else:
# always start with main character (TODO: configurable?)
player_character = scene.get_player_character()
log.debug("conversation_agent.natural_flow", next_actor="main_character", main_character=player_character)
scene.next_actor = player_character.name if player_character else random_character_name
scene.log.debug("conversation_agent.natural_flow", next_actor=scene.next_actor)
log.debug(
"conversation_agent.natural_flow",
next_actor="main_character",
main_character=player_character,
)
scene.next_actor = (
player_character.name if player_character else random_character_name
)
scene.log.debug(
"conversation_agent.natural_flow", next_actor=scene.next_actor
)
# same character cannot go thrice in a row, if this is happening, pick a random character that
# isnt the same as the last character
if self.repeated_speaker() >= 2 and self.scene.prev_actor == self.scene.next_actor:
scene.next_actor = random.choice([c for c in scene.character_names if c != scene.prev_actor])
scene.log.debug("conversation_agent.natural_flow", next_actor="random (repeated safeguard)", random_character_name=scene.next_actor)
if (
self.repeated_speaker() >= 2
and self.scene.prev_actor == self.scene.next_actor
):
scene.next_actor = random.choice(
[c for c in scene.character_names if c != scene.prev_actor]
)
scene.log.debug(
"conversation_agent.natural_flow",
next_actor="random (repeated safeguard)",
random_character_name=scene.next_actor,
)
else:
scene.next_actor = None
@set_processing
async def select_talking_actor(self, character_names: list[str]=None):
result = await Prompt.request("conversation.select-talking-actor", self.client, "conversation_select_talking_actor", vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"character_names": character_names or self.scene.character_names,
"character_names_formatted": ", ".join(character_names or self.scene.character_names),
})
async def select_talking_actor(self, character_names: list[str] = None):
result = await Prompt.request(
"conversation.select-talking-actor",
self.client,
"conversation_select_talking_actor",
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"character_names": character_names or self.scene.character_names,
"character_names_formatted": ", ".join(
character_names or self.scene.character_names
),
},
)
return result
async def build_prompt_default(
self,
@@ -335,17 +388,17 @@ class ConversationAgent(Agent):
# we subtract 200 to account for the response
scene = character.actor.scene
total_token_budget = self.client.max_token_length - 200
scene_and_dialogue_budget = total_token_budget - 500
long_term_memory_budget = min(int(total_token_budget * 0.05), 200)
scene_and_dialogue = scene.context_history(
budget=scene_and_dialogue_budget,
budget=scene_and_dialogue_budget,
keep_director=True,
sections=False,
)
memory = await self.build_prompt_default_memory(character)
main_character = scene.main_character.character
@@ -360,36 +413,39 @@ class ConversationAgent(Agent):
)
else:
formatted_names = character_names[0] if character_names else ""
try:
director_message = isinstance(scene_and_dialogue[-1], DirectorMessage)
except IndexError:
director_message = False
extra_instructions = ""
if self.actions["generation_override"].enabled:
extra_instructions = self.actions["generation_override"].config["instructions"].value
extra_instructions = (
self.actions["generation_override"].config["instructions"].value
)
prompt = Prompt.get(
"conversation.dialogue",
vars={
"scene": scene,
"max_tokens": self.client.max_token_length,
"scene_and_dialogue_budget": scene_and_dialogue_budget,
"scene_and_dialogue": scene_and_dialogue,
"memory": memory,
"characters": list(scene.get_characters()),
"main_character": main_character,
"formatted_names": formatted_names,
"talking_character": character,
"partial_message": char_message,
"director_message": director_message,
"extra_instructions": extra_instructions,
},
)
prompt = Prompt.get("conversation.dialogue", vars={
"scene": scene,
"max_tokens": self.client.max_token_length,
"scene_and_dialogue_budget": scene_and_dialogue_budget,
"scene_and_dialogue": scene_and_dialogue,
"memory": memory,
"characters": list(scene.get_characters()),
"main_character": main_character,
"formatted_names": formatted_names,
"talking_character": character,
"partial_message": char_message,
"director_message": director_message,
"extra_instructions": extra_instructions,
})
return str(prompt)
async def build_prompt_default_memory(
self, character: Character
):
async def build_prompt_default_memory(self, character: Character):
"""
Builds long term memory for the conversation prompt
@@ -404,39 +460,56 @@ class ConversationAgent(Agent):
if not self.actions["use_long_term_memory"].enabled:
return []
if self.current_memory_context:
return self.current_memory_context
self.current_memory_context = ""
retrieval_method = self.actions["use_long_term_memory"].config["retrieval_method"].value
retrieval_method = (
self.actions["use_long_term_memory"].config["retrieval_method"].value
)
if retrieval_method != "direct":
world_state = instance.get_agent("world_state")
history = self.scene.context_history(min_dialogue=3, max_dialogue=15, keep_director=False, sections=False, add_archieved_history=False)
history = self.scene.context_history(
min_dialogue=3,
max_dialogue=15,
keep_director=False,
sections=False,
add_archieved_history=False,
)
text = "\n".join(history)
log.debug("conversation_agent.build_prompt_default_memory", direct=False, version=retrieval_method)
log.debug(
"conversation_agent.build_prompt_default_memory",
direct=False,
version=retrieval_method,
)
if retrieval_method == "questions":
self.current_memory_context = (await world_state.analyze_text_and_extract_context(
text, f"continue the conversation as {character.name}"
)).split("\n")
self.current_memory_context = (
await world_state.analyze_text_and_extract_context(
text, f"continue the conversation as {character.name}"
)
).split("\n")
elif retrieval_method == "queries":
self.current_memory_context = await world_state.analyze_text_and_extract_context_via_queries(
text, f"continue the conversation as {character.name}"
self.current_memory_context = (
await world_state.analyze_text_and_extract_context_via_queries(
text, f"continue the conversation as {character.name}"
)
)
else:
history = list(map(str, self.scene.collect_messages(max_iterations=3)))
log.debug("conversation_agent.build_prompt_default_memory", history=history, direct=True)
log.debug(
"conversation_agent.build_prompt_default_memory",
history=history,
direct=True,
)
memory = instance.get_agent("memory")
context = await memory.multi_query(history, max_tokens=500, iterate=5)
self.current_memory_context = context
return self.current_memory_context
async def build_prompt(self, character, char_message: str = ""):
@@ -445,10 +518,9 @@ class ConversationAgent(Agent):
return await fn(character, char_message=char_message)
def clean_result(self, result, character):
if "#" in result:
result = result.split("#")[0]
result = result.replace(" :", ":")
result = result.replace("[", "*").replace("]", "*")
result = result.replace("(", "*").replace(")", "*")
@@ -459,15 +531,19 @@ class ConversationAgent(Agent):
def set_generation_overrides(self):
if not self.actions["generation_override"].enabled:
return
set_conversation_context_attribute("length", self.actions["generation_override"].config["length"].value)
set_conversation_context_attribute(
"length", self.actions["generation_override"].config["length"].value
)
if self.actions["generation_override"].config["jiggle"].value > 0.0:
nuke_repetition = client_context_attribute("nuke_repetition")
if nuke_repetition == 0.0:
# we only apply the agent override if some other mechanism isn't already
# setting the nuke_repetition value
nuke_repetition = self.actions["generation_override"].config["jiggle"].value
nuke_repetition = (
self.actions["generation_override"].config["jiggle"].value
)
set_client_context_attribute("nuke_repetition", nuke_repetition)
@set_processing
@@ -479,10 +555,14 @@ class ConversationAgent(Agent):
self.current_memory_context = None
character = actor.character
emission = ConversationAgentEmission(agent=self, generation="", actor=actor, character=character)
await talemate.emit.async_signals.get("agent.conversation.before_generate").send(emission)
emission = ConversationAgentEmission(
agent=self, generation="", actor=actor, character=character
)
await talemate.emit.async_signals.get(
"agent.conversation.before_generate"
).send(emission)
self.set_generation_overrides()
result = await self.client.send_prompt(await self.build_prompt(character))
@@ -505,7 +585,7 @@ class ConversationAgent(Agent):
result = self.clean_result(result, character)
total_result += " "+result
total_result += " " + result
if len(total_result) == 0 and max_loops < 10:
max_loops += 1
@@ -529,7 +609,7 @@ class ConversationAgent(Agent):
# Removes partial sentence at the end
total_result = util.clean_dialogue(total_result, main_name=character.name)
# Remove "{character.name}:" - all occurences
total_result = total_result.replace(f"{character.name}:", "")
@@ -548,13 +628,17 @@ class ConversationAgent(Agent):
)
response_message = util.parse_messages_from_str(total_result, [character.name])
log.info("conversation agent", result=response_message)
emission = ConversationAgentEmission(agent=self, generation=response_message, actor=actor, character=character)
await talemate.emit.async_signals.get("agent.conversation.generated").send(emission)
#log.info("conversation agent", generation=emission.generation)
log.info("conversation agent", result=response_message)
emission = ConversationAgentEmission(
agent=self, generation=response_message, actor=actor, character=character
)
await talemate.emit.async_signals.get("agent.conversation.generated").send(
emission
)
# log.info("conversation agent", generation=emission.generation)
messages = [CharacterMessage(message) for message in emission.generation]
@@ -563,15 +647,17 @@ class ConversationAgent(Agent):
return messages
def allow_repetition_break(self, kind: str, agent_function_name: str, auto: bool = False):
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
if auto and not self.actions["auto_break_repetition"].enabled:
return False
return agent_function_name == "converse"
def inject_prompt_paramters(self, prompt_param: dict, kind: str, agent_function_name: str):
def inject_prompt_paramters(
self, prompt_param: dict, kind: str, agent_function_name: str
):
if prompt_param.get("extra_stopping_strings") is None:
prompt_param["extra_stopping_strings"] = []
prompt_param["extra_stopping_strings"] += ['[']
prompt_param["extra_stopping_strings"] += ["["]

View File

@@ -3,22 +3,23 @@ from __future__ import annotations
import json
import os
import talemate.client as client
from talemate.agents.base import Agent, set_processing
from talemate.agents.registry import register
from talemate.emit import emit
from talemate.prompts import Prompt
import talemate.client as client
from .character import CharacterCreatorMixin
from .scenario import ScenarioCreatorMixin
@register()
class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
"""
Creates characters and scenarios and other fun stuff!
"""
agent_type = "creator"
verbose_name = "Creator"
@@ -78,12 +79,14 @@ class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
# Remove duplicates while preserving the order for list type keys
for key, value in merged_data.items():
if isinstance(value, list):
merged_data[key] = [x for i, x in enumerate(value) if x not in value[:i]]
merged_data[key] = [
x for i, x in enumerate(value) if x not in value[:i]
]
merged_data["context"] = context
return merged_data
def load_templates_old(self, names: list, template_type: str = "character") -> dict:
"""
Loads multiple character creation templates from ./templates/character and merges them in order.
@@ -128,8 +131,10 @@ class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
if "context" in template_data["instructions"]:
context = template_data["instructions"]["context"]
merged_instructions[name]["questions"] = [q[0] for q in template_data.get("questions", [])]
merged_instructions[name]["questions"] = [
q[0] for q in template_data.get("questions", [])
]
# Remove duplicates while preserving the order
merged_template = [
@@ -158,24 +163,33 @@ class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
return rv
@set_processing
async def generate_json_list(
self,
text:str,
count:int=20,
first_item:str=None,
text: str,
count: int = 20,
first_item: str = None,
):
_, json_list = await Prompt.request(f"creator.generate-json-list", self.client, "create", vars={
"text": text,
"first_item": first_item,
"count": count,
})
return json_list.get("items",[])
_, json_list = await Prompt.request(
f"creator.generate-json-list",
self.client,
"create",
vars={
"text": text,
"first_item": first_item,
"count": count,
},
)
return json_list.get("items", [])
@set_processing
async def generate_title(self, text:str):
title = await Prompt.request(f"creator.generate-title", self.client, "create_short", vars={
"text": text,
})
return title
async def generate_title(self, text: str):
title = await Prompt.request(
f"creator.generate-title",
self.client,
"create_short",
vars={
"text": text,
},
)
return title

View File

@@ -1,42 +1,48 @@
from __future__ import annotations
import re
import asyncio
import random
import structlog
import re
from typing import TYPE_CHECKING, Callable
import structlog
import talemate.util as util
from talemate.emit import emit
from talemate.prompts import Prompt, LoopedPrompt
from talemate.exceptions import LLMAccuracyError
from talemate.agents.base import set_processing
from talemate.emit import emit
from talemate.exceptions import LLMAccuracyError
from talemate.prompts import LoopedPrompt, Prompt
if TYPE_CHECKING:
from talemate.tale_mate import Character
log = structlog.get_logger("talemate.agents.creator.character")
def validate(k,v):
def validate(k, v):
if k and k.lower() == "gender":
return v.lower().strip()
if k and k.lower() == "age":
try:
return int(v.split("\n")[0].strip())
except (ValueError, TypeError):
raise LLMAccuracyError("Was unable to get a valid age from the response", model_name=None)
raise LLMAccuracyError(
"Was unable to get a valid age from the response", model_name=None
)
return v.strip().strip("\n")
DEFAULT_CONTENT_CONTEXT="a fun and engaging adventure aimed at an adult audience."
DEFAULT_CONTENT_CONTEXT = "a fun and engaging adventure aimed at an adult audience."
class CharacterCreatorMixin:
"""
Adds character creation functionality to the creator agent
"""
## NEW
@set_processing
async def create_character_attributes(
self,
@@ -48,8 +54,6 @@ class CharacterCreatorMixin:
custom_attributes: dict[str, str] = dict(),
predefined_attributes: dict[str, str] = dict(),
):
def spice(prompt, spices):
# generate number from 0 to 1 and if its smaller than use_spice
# select a random spice from the list and return it formatted
@@ -57,69 +61,74 @@ class CharacterCreatorMixin:
if random.random() < use_spice:
spice = random.choice(spices)
return prompt.format(spice=spice)
return ""
return ""
# drop any empty attributes from predefined_attributes
predefined_attributes = {k:v for k,v in predefined_attributes.items() if v}
prompt = Prompt.get(f"creator.character-attributes-{template}", vars={
"character_prompt": character_prompt,
"template": template,
"spice": spice,
"content_context": content_context,
"custom_attributes": custom_attributes,
"character_sheet": LoopedPrompt(
validate_value=validate,
on_update=attribute_callback,
generated=predefined_attributes,
),
})
predefined_attributes = {k: v for k, v in predefined_attributes.items() if v}
prompt = Prompt.get(
f"creator.character-attributes-{template}",
vars={
"character_prompt": character_prompt,
"template": template,
"spice": spice,
"content_context": content_context,
"custom_attributes": custom_attributes,
"character_sheet": LoopedPrompt(
validate_value=validate,
on_update=attribute_callback,
generated=predefined_attributes,
),
},
)
await prompt.loop(self.client, "character_sheet", kind="create_concise")
return prompt.vars["character_sheet"].generated
@set_processing
async def create_character_description(
self,
character:Character,
self,
character: Character,
content_context: str = DEFAULT_CONTENT_CONTEXT,
):
description = await Prompt.request(f"creator.character-description", self.client, "create", vars={
"character": character,
"content_context": content_context,
})
description = await Prompt.request(
f"creator.character-description",
self.client,
"create",
vars={
"character": character,
"content_context": content_context,
},
)
return description.strip()
@set_processing
async def create_character_details(
self,
self,
character: Character,
template: str,
detail_callback: Callable = lambda question, answer: None,
questions: list[str] = None,
content_context: str = DEFAULT_CONTENT_CONTEXT,
):
prompt = Prompt.get(f"creator.character-details-{template}", vars={
"character_details": LoopedPrompt(
validate_value=validate,
on_update=detail_callback,
),
"template": template,
"content_context": content_context,
"character": character,
"custom_questions": questions or [],
})
prompt = Prompt.get(
f"creator.character-details-{template}",
vars={
"character_details": LoopedPrompt(
validate_value=validate,
on_update=detail_callback,
),
"template": template,
"content_context": content_context,
"character": character,
"custom_questions": questions or [],
},
)
await prompt.loop(self.client, "character_details", kind="create_concise")
return prompt.vars["character_details"].generated
@set_processing
async def create_character_example_dialogue(
self,
@@ -131,97 +140,116 @@ class CharacterCreatorMixin:
example_callback: Callable = lambda example: None,
rules_callback: Callable = lambda rules: None,
):
dialogue_rules = await Prompt.request(f"creator.character-dialogue-rules", self.client, "create", vars={
"guide": guide,
"character": character,
"examples": examples or [],
"content_context": content_context,
})
dialogue_rules = await Prompt.request(
f"creator.character-dialogue-rules",
self.client,
"create",
vars={
"guide": guide,
"character": character,
"examples": examples or [],
"content_context": content_context,
},
)
log.info("dialogue_rules", dialogue_rules=dialogue_rules)
if rules_callback:
rules_callback(dialogue_rules)
example_dialogue_prompt = Prompt.get(f"creator.character-example-dialogue-{template}", vars={
"guide": guide,
"character": character,
"examples": examples or [],
"content_context": content_context,
"dialogue_rules": dialogue_rules,
"generated_examples": LoopedPrompt(
validate_value=validate,
on_update=example_callback,
),
})
await example_dialogue_prompt.loop(self.client, "generated_examples", kind="create")
example_dialogue_prompt = Prompt.get(
f"creator.character-example-dialogue-{template}",
vars={
"guide": guide,
"character": character,
"examples": examples or [],
"content_context": content_context,
"dialogue_rules": dialogue_rules,
"generated_examples": LoopedPrompt(
validate_value=validate,
on_update=example_callback,
),
},
)
await example_dialogue_prompt.loop(
self.client, "generated_examples", kind="create"
)
return example_dialogue_prompt.vars["generated_examples"].generated
@set_processing
async def determine_content_context_for_character(
self,
character: Character,
):
content_context = await Prompt.request(f"creator.determine-content-context", self.client, "create", vars={
"character": character,
})
content_context = await Prompt.request(
f"creator.determine-content-context",
self.client,
"create",
vars={
"character": character,
},
)
return content_context.strip()
@set_processing
async def determine_character_attributes(
self,
character: Character,
):
attributes = await Prompt.request(f"creator.determine-character-attributes", self.client, "analyze_long", vars={
"character": character,
})
attributes = await Prompt.request(
f"creator.determine-character-attributes",
self.client,
"analyze_long",
vars={
"character": character,
},
)
return attributes
@set_processing
async def determine_character_description(
self,
character: Character,
text:str=""
self, character: Character, text: str = ""
):
description = await Prompt.request(f"creator.determine-character-description", self.client, "create", vars={
"character": character,
"scene": self.scene,
"text": text,
"max_tokens": self.client.max_token_length,
})
description = await Prompt.request(
f"creator.determine-character-description",
self.client,
"create",
vars={
"character": character,
"scene": self.scene,
"text": text,
"max_tokens": self.client.max_token_length,
},
)
return description.strip()
@set_processing
async def determine_character_goals(
self,
character: Character,
goal_instructions: str,
):
goals = await Prompt.request(f"creator.determine-character-goals", self.client, "create", vars={
"character": character,
"scene": self.scene,
"goal_instructions": goal_instructions,
"npc_name": character.name,
"player_name": self.scene.get_player_character().name,
"max_tokens": self.client.max_token_length,
})
goals = await Prompt.request(
f"creator.determine-character-goals",
self.client,
"create",
vars={
"character": character,
"scene": self.scene,
"goal_instructions": goal_instructions,
"npc_name": character.name,
"player_name": self.scene.get_player_character().name,
"max_tokens": self.client.max_token_length,
},
)
log.debug("determine_character_goals", goals=goals, character=character)
await character.set_detail("goals", goals.strip())
return goals.strip()
@set_processing
async def generate_character_from_text(
self,
@@ -229,11 +257,8 @@ class CharacterCreatorMixin:
template: str,
content_context: str = DEFAULT_CONTENT_CONTEXT,
):
base_attributes = await self.create_character_attributes(
character_prompt=text,
template=template,
content_context=content_context,
)

View File

@@ -1,36 +1,36 @@
from talemate.emit import emit, wait_for_input_yesno
import re
import random
import re
from talemate.prompts import Prompt
from talemate.agents.base import set_processing
from talemate.emit import emit, wait_for_input_yesno
from talemate.prompts import Prompt
class ScenarioCreatorMixin:
"""
Adds scenario creation functionality to the creator agent
"""
@set_processing
async def create_scene_description(
self,
prompt:str,
content_context:str,
prompt: str,
content_context: str,
):
"""
Creates a new scene.
Arguments:
prompt (str): The prompt to use to create the scene.
content_context (str): The content context to use for the scene.
callback (callable): A callback to call when the scene has been created.
"""
scene = self.scene
description = await Prompt.request(
"creator.scenario-description",
self.client,
@@ -40,35 +40,32 @@ class ScenarioCreatorMixin:
"content_context": content_context,
"max_tokens": self.client.max_token_length,
"scene": scene,
}
},
)
description = description.strip()
return description
@set_processing
async def create_scene_name(
self,
prompt:str,
content_context:str,
description:str,
prompt: str,
content_context: str,
description: str,
):
"""
Generates a scene name.
Arguments:
prompt (str): The prompt to use to generate the scene name.
content_context (str): The content context to use for the scene.
description (str): The description of the scene.
"""
scene = self.scene
name = await Prompt.request(
"creator.scenario-name",
self.client,
@@ -78,37 +75,35 @@ class ScenarioCreatorMixin:
"content_context": content_context,
"description": description,
"scene": scene,
}
},
)
name = name.strip().strip('.!').replace('"','')
name = name.strip().strip(".!").replace('"', "")
return name
@set_processing
async def create_scene_intro(
self,
prompt:str,
content_context:str,
description:str,
name:str,
prompt: str,
content_context: str,
description: str,
name: str,
):
"""
Generates a scene introduction.
Arguments:
prompt (str): The prompt to use to generate the scene introduction.
content_context (str): The content context to use for the scene.
description (str): The description of the scene.
name (str): The name of the scene.
"""
scene = self.scene
intro = await Prompt.request(
"creator.scenario-intro",
self.client,
@@ -119,17 +114,19 @@ class ScenarioCreatorMixin:
"description": description,
"name": name,
"scene": scene,
}
},
)
intro = intro.strip()
return intro
@set_processing
async def determine_scenario_description(
self,
text:str
):
description = await Prompt.request(f"creator.determine-scenario-description", self.client, "analyze_long", vars={
"text": text,
})
return description
async def determine_scenario_description(self, text: str):
description = await Prompt.request(
f"creator.determine-scenario-description",
self.client,
"analyze_long",
vars={
"text": text,
},
)
return description

View File

@@ -1,200 +1,251 @@
from __future__ import annotations
import asyncio
import re
import random
import structlog
import re
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import talemate.util as util
from talemate.emit import wait_for_input, emit
import talemate.emit.async_signals
from talemate.prompts import Prompt
from talemate.scene_message import NarratorMessage, DirectorMessage
from talemate.automated_action import AutomatedAction
import structlog
import talemate.automated_action as automated_action
from talemate.agents.conversation import ConversationAgentEmission
from .registry import register
from .base import set_processing, AgentAction, AgentActionConfig, Agent
from talemate.events import GameLoopActorIterEvent, GameLoopStartEvent, SceneStateEvent
import talemate.emit.async_signals
import talemate.instance as instance
import talemate.util as util
from talemate.agents.conversation import ConversationAgentEmission
from talemate.automated_action import AutomatedAction
from talemate.emit import emit, wait_for_input
from talemate.events import GameLoopActorIterEvent, GameLoopStartEvent, SceneStateEvent
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, NarratorMessage
from .base import Agent, AgentAction, AgentActionConfig, set_processing
from .registry import register
if TYPE_CHECKING:
from talemate import Actor, Character, Player, Scene
log = structlog.get_logger("talemate.agent.director")
@register()
class DirectorAgent(Agent):
agent_type = "director"
verbose_name = "Director"
def __init__(self, client, **kwargs):
self.is_enabled = True
self.client = client
self.next_direct_character = {}
self.next_direct_scene = 0
self.actions = {
"direct": AgentAction(enabled=True, label="Direct", description="Will attempt to direct the scene. Runs automatically after AI dialogue (n turns).", config={
"turns": AgentActionConfig(type="number", label="Turns", description="Number of turns to wait before directing the sceen", value=5, min=1, max=100, step=1),
"direct_scene": AgentActionConfig(type="bool", label="Direct Scene", description="If enabled, the scene will be directed through narration", value=True),
"direct_actors": AgentActionConfig(type="bool", label="Direct Actors", description="If enabled, direction will be given to actors based on their goals.", value=True),
}),
"direct": AgentAction(
enabled=True,
label="Direct",
description="Will attempt to direct the scene. Runs automatically after AI dialogue (n turns).",
config={
"turns": AgentActionConfig(
type="number",
label="Turns",
description="Number of turns to wait before directing the sceen",
value=5,
min=1,
max=100,
step=1,
),
"direct_scene": AgentActionConfig(
type="bool",
label="Direct Scene",
description="If enabled, the scene will be directed through narration",
value=True,
),
"direct_actors": AgentActionConfig(
type="bool",
label="Direct Actors",
description="If enabled, direction will be given to actors based on their goals.",
value=True,
),
},
),
}
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return True
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.conversation.before_generate").connect(self.on_conversation_before_generate)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.on_player_dialog)
talemate.emit.async_signals.get("agent.conversation.before_generate").connect(
self.on_conversation_before_generate
)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(
self.on_player_dialog
)
talemate.emit.async_signals.get("scene_init").connect(self.on_scene_init)
async def on_scene_init(self, event: SceneStateEvent):
"""
If game state instructions specify to be run at the start of the game loop
we will run them here.
"""
if not self.enabled:
if self.scene.game_state.has_scene_instructions:
self.is_enabled = True
log.warning("on_scene_init - enabling director", scene=self.scene)
self.is_enabled = True
log.warning("on_scene_init - enabling director", scene=self.scene)
else:
return
if not self.scene.game_state.has_scene_instructions:
return
if not self.scene.game_state.ops.run_on_start:
return
log.info("on_game_loop_start - running game state instructions")
await self.run_gamestate_instructions()
async def on_conversation_before_generate(self, event:ConversationAgentEmission):
async def on_conversation_before_generate(self, event: ConversationAgentEmission):
log.info("on_conversation_before_generate", director_enabled=self.enabled)
if not self.enabled:
return
await self.direct(event.character)
async def on_player_dialog(self, event:GameLoopActorIterEvent):
async def on_player_dialog(self, event: GameLoopActorIterEvent):
if not self.enabled:
return
if not self.scene.game_state.has_scene_instructions:
return
if not event.actor.character.is_player:
return
if event.game_loop.had_passive_narration:
log.debug("director.on_player_dialog", skip=True, had_passive_narration=event.game_loop.had_passive_narration)
log.debug(
"director.on_player_dialog",
skip=True,
had_passive_narration=event.game_loop.had_passive_narration,
)
return
event.game_loop.had_passive_narration = await self.direct(None)
async def direct(self, character: Character) -> bool:
if not self.actions["direct"].enabled:
return False
if character:
if not self.actions["direct"].config["direct_actors"].value:
log.info("direct", skip=True, reason="direct_actors disabled", character=character)
log.info(
"direct",
skip=True,
reason="direct_actors disabled",
character=character,
)
return False
# character direction, see if there are character goals
# character direction, see if there are character goals
# defined
character_goals = character.get_detail("goals")
if not character_goals:
log.info("direct", skip=True, reason="no goals", character=character)
return False
next_direct = self.next_direct_character.get(character.name, 0)
if next_direct % self.actions["direct"].config["turns"].value != 0 or next_direct == 0:
log.info("direct", skip=True, next_direct=next_direct, character=character)
if (
next_direct % self.actions["direct"].config["turns"].value != 0
or next_direct == 0
):
log.info(
"direct", skip=True, next_direct=next_direct, character=character
)
self.next_direct_character[character.name] = next_direct + 1
return False
self.next_direct_character[character.name] = 0
await self.direct_scene(character, character_goals)
return True
else:
if not self.actions["direct"].config["direct_scene"].value:
log.info("direct", skip=True, reason="direct_scene disabled")
return False
# no character, see if there are NPC characters at all
# if not we always want to direct narration
always_direct = (not self.scene.npc_character_names)
always_direct = not self.scene.npc_character_names
next_direct = self.next_direct_scene
if next_direct % self.actions["direct"].config["turns"].value != 0 or next_direct == 0:
if (
next_direct % self.actions["direct"].config["turns"].value != 0
or next_direct == 0
):
if not always_direct:
log.info("direct", skip=True, next_direct=next_direct)
self.next_direct_scene += 1
return False
self.next_direct_scene = 0
await self.direct_scene(None, None)
return True
@set_processing
async def run_gamestate_instructions(self):
"""
Run game state instructions, if they exist.
"""
if not self.scene.game_state.has_scene_instructions:
return
await self.direct_scene(None, None)
@set_processing
async def direct_scene(self, character: Character, prompt:str):
async def direct_scene(self, character: Character, prompt: str):
if not character and self.scene.game_state.game_won:
# we are not directing a character, and the game has been won
# so we don't need to direct the scene any further
return
if character:
# direct a character
response = await Prompt.request("director.direct-character", self.client, "director", vars={
"max_tokens": self.client.max_token_length,
"scene": self.scene,
"prompt": prompt,
"character": character,
"player_character": self.scene.get_player_character(),
"game_state": self.scene.game_state,
})
response = await Prompt.request(
"director.direct-character",
self.client,
"director",
vars={
"max_tokens": self.client.max_token_length,
"scene": self.scene,
"prompt": prompt,
"character": character,
"player_character": self.scene.get_player_character(),
"game_state": self.scene.game_state,
},
)
if "#" in response:
response = response.split("#")[0]
log.info("direct_character", character=character, prompt=prompt, response=response)
log.info(
"direct_character",
character=character,
prompt=prompt,
response=response,
)
response = response.strip().split("\n")[0].strip()
#response += f" (current story goal: {prompt})"
# response += f" (current story goal: {prompt})"
message = DirectorMessage(response, source=character.name)
emit("director", message, character=character)
self.scene.push_history(message)
@@ -204,24 +255,38 @@ class DirectorAgent(Agent):
@set_processing
async def persist_character(
self,
name:str,
content:str = None,
attributes:str = None,
self,
name: str,
content: str = None,
attributes: str = None,
):
world_state = instance.get_agent("world_state")
creator = instance.get_agent("creator")
self.scene.log.debug("persist_character", name=name)
character = self.scene.Character(name=name)
character.color = random.choice(['#F08080', '#FFD700', '#90EE90', '#ADD8E6', '#DDA0DD', '#FFB6C1', '#FAFAD2', '#D3D3D3', '#B0E0E6', '#FFDEAD'])
character.color = random.choice(
[
"#F08080",
"#FFD700",
"#90EE90",
"#ADD8E6",
"#DDA0DD",
"#FFB6C1",
"#FAFAD2",
"#D3D3D3",
"#B0E0E6",
"#FFDEAD",
]
)
if not attributes:
attributes = await world_state.extract_character_sheet(name=name, text=content)
attributes = await world_state.extract_character_sheet(
name=name, text=content
)
else:
attributes = world_state._parse_character_sheet(attributes)
self.scene.log.debug("persist_character", attributes=attributes)
character.base_attributes = attributes
@@ -232,35 +297,54 @@ class DirectorAgent(Agent):
self.scene.log.debug("persist_character", description=description)
actor = self.scene.Actor(character=character, agent=instance.get_agent("conversation"))
actor = self.scene.Actor(
character=character, agent=instance.get_agent("conversation")
)
await self.scene.add_actor(actor)
self.scene.emit_status()
return character
@set_processing
async def update_content_context(self, content:str=None, extra_choices:list[str]=None):
async def update_content_context(
self, content: str = None, extra_choices: list[str] = None
):
if not content:
content = "\n".join(self.scene.context_history(sections=False, min_dialogue=25, budget=2048))
response = await Prompt.request("world_state.determine-content-context", self.client, "analyze_freeform", vars={
"content": content,
"extra_choices": extra_choices or [],
})
content = "\n".join(
self.scene.context_history(sections=False, min_dialogue=25, budget=2048)
)
response = await Prompt.request(
"world_state.determine-content-context",
self.client,
"analyze_freeform",
vars={
"content": content,
"extra_choices": extra_choices or [],
},
)
self.scene.context = response.strip()
self.scene.emit_status()
def inject_prompt_paramters(self, prompt_param: dict, kind: str, agent_function_name: str):
log.debug("inject_prompt_paramters", prompt_param=prompt_param, kind=kind, agent_function_name=agent_function_name)
def inject_prompt_paramters(
self, prompt_param: dict, kind: str, agent_function_name: str
):
log.debug(
"inject_prompt_paramters",
prompt_param=prompt_param,
kind=kind,
agent_function_name=agent_function_name,
)
character_names = [f"\n{c.name}:" for c in self.scene.get_characters()]
if prompt_param.get("extra_stopping_strings") is None:
prompt_param["extra_stopping_strings"] = []
prompt_param["extra_stopping_strings"] += character_names + ["#"]
if agent_function_name == "update_content_context":
prompt_param["extra_stopping_strings"] += ["\n"]
def allow_repetition_break(self, kind: str, agent_function_name: str, auto:bool=False):
return True
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
return True

View File

@@ -1,30 +1,30 @@
from __future__ import annotations
import asyncio
import re
import time
import traceback
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import structlog
import talemate.data_objects as data_objects
import talemate.util as util
import talemate.emit.async_signals
import talemate.util as util
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, TimePassageMessage
from .base import Agent, set_processing, AgentAction, AgentActionConfig
from .base import Agent, AgentAction, AgentActionConfig, set_processing
from .registry import register
import structlog
import time
import re
if TYPE_CHECKING:
from talemate.tale_mate import Actor, Character, Scene
from talemate.agents.conversation import ConversationAgentEmission
from talemate.agents.narrator import NarratorAgentEmission
from talemate.tale_mate import Actor, Character, Scene
log = structlog.get_logger("talemate.agents.editor")
@register()
class EditorAgent(Agent):
"""
@@ -35,175 +35,195 @@ class EditorAgent(Agent):
agent_type = "editor"
verbose_name = "Editor"
def __init__(self, client, **kwargs):
self.client = client
self.is_enabled = True
self.actions = {
"edit_dialogue": AgentAction(enabled=False, label="Edit dialogue", description="Will attempt to improve the quality of dialogue based on the character and scene. Runs automatically after each AI dialogue."),
"fix_exposition": AgentAction(enabled=True, label="Fix exposition", description="Will attempt to fix exposition and emotes, making sure they are displayed in italics. Runs automatically after each AI dialogue.", config={
"narrator": AgentActionConfig(type="bool", label="Fix narrator messages", description="Will attempt to fix exposition issues in narrator messages", value=True),
}),
"add_detail": AgentAction(enabled=False, label="Add detail", description="Will attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.")
"edit_dialogue": AgentAction(
enabled=False,
label="Edit dialogue",
description="Will attempt to improve the quality of dialogue based on the character and scene. Runs automatically after each AI dialogue.",
),
"fix_exposition": AgentAction(
enabled=True,
label="Fix exposition",
description="Will attempt to fix exposition and emotes, making sure they are displayed in italics. Runs automatically after each AI dialogue.",
config={
"narrator": AgentActionConfig(
type="bool",
label="Fix narrator messages",
description="Will attempt to fix exposition issues in narrator messages",
value=True,
),
},
),
"add_detail": AgentAction(
enabled=False,
label="Add detail",
description="Will attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.",
),
}
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return True
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.conversation.generated").connect(self.on_conversation_generated)
talemate.emit.async_signals.get("agent.narrator.generated").connect(self.on_narrator_generated)
async def on_conversation_generated(self, emission:ConversationAgentEmission):
talemate.emit.async_signals.get("agent.conversation.generated").connect(
self.on_conversation_generated
)
talemate.emit.async_signals.get("agent.narrator.generated").connect(
self.on_narrator_generated
)
async def on_conversation_generated(self, emission: ConversationAgentEmission):
"""
Called when a conversation is generated
"""
if not self.enabled:
return
log.info("editing conversation", emission=emission)
edited = []
for text in emission.generation:
edit = await self.add_detail(text, emission.character)
edit = await self.edit_conversation(edit, emission.character)
edit = await self.fix_exposition(edit, emission.character)
edit = await self.add_detail(
text,
emission.character
)
edit = await self.edit_conversation(
edit,
emission.character
)
edit = await self.fix_exposition(
edit,
emission.character
)
edited.append(edit)
emission.generation = edited
async def on_narrator_generated(self, emission:NarratorAgentEmission):
async def on_narrator_generated(self, emission: NarratorAgentEmission):
"""
Called when a narrator message is generated
"""
if not self.enabled:
return
log.info("editing narrator", emission=emission)
edited = []
for text in emission.generation:
edit = await self.fix_exposition_on_narrator(text)
edited.append(edit)
emission.generation = edited
@set_processing
async def edit_conversation(self, content:str, character:Character):
async def edit_conversation(self, content: str, character: Character):
"""
Edits a conversation
"""
if not self.actions["edit_dialogue"].enabled:
return content
response = await Prompt.request("editor.edit-dialogue", self.client, "edit_dialogue", vars={
"content": content,
"character": character,
"scene": self.scene,
"max_length": self.client.max_token_length
})
response = await Prompt.request(
"editor.edit-dialogue",
self.client,
"edit_dialogue",
vars={
"content": content,
"character": character,
"scene": self.scene,
"max_length": self.client.max_token_length,
},
)
response = response.split("[end]")[0]
response = util.replace_exposition_markers(response)
response = util.clean_dialogue(response, main_name=character.name)
response = util.clean_dialogue(response, main_name=character.name)
response = util.strip_partial_sentences(response)
return response
@set_processing
async def fix_exposition(self, content:str, character:Character):
async def fix_exposition(self, content: str, character: Character):
"""
Edits a text to make sure all narrative exposition and emotes is encased in *
"""
if not self.actions["fix_exposition"].enabled:
return content
if not character.is_player:
if '"' not in content and '*' not in content:
if '"' not in content and "*" not in content:
content = util.strip_partial_sentences(content)
character_prefix = f"{character.name}: "
message = content.split(character_prefix)[1]
content = f"{character_prefix}*{message.strip('*')}*"
return content
elif '"' in content:
# silly hack to clean up some LLMs that always start with a quote
# even though the immediate next thing is a narration (indicated by *)
content = content.replace(f"{character.name}: \"*", f"{character.name}: *")
content = util.clean_dialogue(content, main_name=character.name)
content = content.replace(
f'{character.name}: "*', f"{character.name}: *"
)
content = util.clean_dialogue(content, main_name=character.name)
content = util.strip_partial_sentences(content)
content = util.ensure_dialog_format(content, talking_character=character.name)
return content
@set_processing
async def fix_exposition_on_narrator(self, content:str):
async def fix_exposition_on_narrator(self, content: str):
if not self.actions["fix_exposition"].enabled:
return content
if not self.actions["fix_exposition"].config["narrator"].value:
return content
content = util.strip_partial_sentences(content)
if '"' not in content:
content = f"*{content.strip('*')}*"
else:
content = util.ensure_dialog_format(content)
return content
@set_processing
async def add_detail(self, content:str, character:Character):
async def add_detail(self, content: str, character: Character):
"""
Edits a text to increase its length and add extra detail and exposition
"""
if not self.actions["add_detail"].enabled:
return content
response = await Prompt.request("editor.add-detail", self.client, "edit_add_detail", vars={
"content": content,
"character": character,
"scene": self.scene,
"max_length": self.client.max_token_length
})
response = await Prompt.request(
"editor.add-detail",
self.client,
"edit_add_detail",
vars={
"content": content,
"character": character,
"scene": self.scene,
"max_length": self.client.max_token_length,
},
)
response = util.replace_exposition_markers(response)
response = util.clean_dialogue(response, main_name=character.name)
response = util.clean_dialogue(response, main_name=character.name)
response = util.strip_partial_sentences(response)
return response
return response

View File

@@ -1,19 +1,21 @@
from __future__ import annotations
import asyncio
import functools
import os
import shutil
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import structlog
from chromadb.config import Settings
import talemate.events as events
import talemate.util as util
from talemate.agents.base import set_processing
from talemate.config import load_config
from talemate.context import scene_is_loading
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.context import scene_is_loading
from talemate.config import load_config
from talemate.agents.base import set_processing
import structlog
import shutil
import functools
try:
import chromadb
@@ -30,17 +32,18 @@ if not chromadb:
from .base import Agent
class MemoryDocument(str):
def __new__(cls, text, meta, id, raw):
inst = super().__new__(cls, text)
inst.meta = meta
inst.id = id
inst.raw = raw
return inst
class MemoryAgent(Agent):
"""
An agent that can be used to maintain and access a memory of the world
@@ -52,10 +55,11 @@ class MemoryAgent(Agent):
@property
def readonly(self):
if scene_is_loading.get() and not getattr(self.scene, "_memory_never_persisted", False):
if scene_is_loading.get() and not getattr(
self.scene, "_memory_never_persisted", False
):
return True
return False
@property
@@ -72,9 +76,9 @@ class MemoryAgent(Agent):
self.memory_tracker = {}
self.config = load_config()
self._ready_to_add = False
handlers["config_saved"].connect(self.on_config_saved)
def on_config_saved(self, event):
openai_key = self.openai_api_key
self.config = load_config()
@@ -92,35 +96,68 @@ class MemoryAgent(Agent):
raise NotImplementedError()
@set_processing
async def add(self, text, character=None, uid=None, ts:str=None, **kwargs):
async def add(self, text, character=None, uid=None, ts: str = None, **kwargs):
if not text:
return
if self.readonly:
log.debug("memory agent", status="readonly")
return
while not self._ready_to_add:
await asyncio.sleep(0.1)
log.debug("memory agent add", text=text[:50], character=character, uid=uid, ts=ts, **kwargs)
log.debug(
"memory agent add",
text=text[:50],
character=character,
uid=uid,
ts=ts,
**kwargs,
)
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(None, functools.partial(self._add, text, character, uid=uid, ts=ts, **kwargs))
await loop.run_in_executor(
None,
functools.partial(self._add, text, character, uid=uid, ts=ts, **kwargs),
)
except AttributeError as e:
# not sure how this sometimes happens.
# chromadb model None
# race condition because we are forcing async context onto it?
log.error("memory agent", error="failed to add memory", details=e, text=text[:50], character=character, uid=uid, ts=ts, **kwargs)
log.error(
"memory agent",
error="failed to add memory",
details=e,
text=text[:50],
character=character,
uid=uid,
ts=ts,
**kwargs,
)
await asyncio.sleep(1.0)
try:
await loop.run_in_executor(None, functools.partial(self._add, text, character, uid=uid, ts=ts, **kwargs))
await loop.run_in_executor(
None,
functools.partial(
self._add, text, character, uid=uid, ts=ts, **kwargs
),
)
except Exception as e:
log.error("memory agent", error="failed to add memory (retried)", details=e, text=text[:50], character=character, uid=uid, ts=ts, **kwargs)
log.error(
"memory agent",
error="failed to add memory (retried)",
details=e,
text=text[:50],
character=character,
uid=uid,
ts=ts,
**kwargs,
)
def _add(self, text, character=None, ts:str=None, **kwargs):
def _add(self, text, character=None, ts: str = None, **kwargs):
raise NotImplementedError()
@set_processing
@@ -131,44 +168,46 @@ class MemoryAgent(Agent):
while not self._ready_to_add:
await asyncio.sleep(0.1)
log.debug("memory agent add many", len=len(objects))
log.debug("memory agent add many", len=len(objects))
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._add_many, objects)
def _add_many(self, objects: list[dict]):
"""
Add multiple objects to the memory
"""
raise NotImplementedError()
def _delete(self, meta:dict):
def _delete(self, meta: dict):
"""
Delete an object from the memory
"""
raise NotImplementedError()
@set_processing
async def delete(self, meta:dict):
async def delete(self, meta: dict):
"""
Delete an object from the memory
"""
if self.readonly:
log.debug("memory agent", status="readonly")
return
while not self._ready_to_add:
await asyncio.sleep(0.1)
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._delete, meta)
@set_processing
async def get(self, text, character=None, **query):
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, functools.partial(self._get, text, character, **query))
return await loop.run_in_executor(
None, functools.partial(self._get, text, character, **query)
)
def _get(self, text, character=None, **query):
raise NotImplementedError()
@@ -177,12 +216,14 @@ class MemoryAgent(Agent):
async def get_document(self, id):
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, self._get_document, id)
def _get_document(self, id):
raise NotImplementedError()
def on_archive_add(self, event: events.ArchiveEvent):
asyncio.ensure_future(self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history"))
asyncio.ensure_future(
self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history")
)
def on_character_state(self, event: events.CharacterStateEvent):
asyncio.ensure_future(
@@ -222,10 +263,10 @@ class MemoryAgent(Agent):
"""
memory_context = []
if not query:
return memory_context
for memory in await self.get(query):
if memory in memory_context:
continue
@@ -239,17 +280,26 @@ class MemoryAgent(Agent):
break
return memory_context
async def query(self, query:str, max_tokens:int=1000, filter:Callable=lambda x:True, **where):
async def query(
self,
query: str,
max_tokens: int = 1000,
filter: Callable = lambda x: True,
**where,
):
"""
Get the character memory context for a given character
"""
try:
return (await self.multi_query([query], max_tokens=max_tokens, filter=filter, **where))[0]
return (
await self.multi_query(
[query], max_tokens=max_tokens, filter=filter, **where
)
)[0]
except IndexError:
return None
async def multi_query(
self,
queries: list[str],
@@ -258,7 +308,7 @@ class MemoryAgent(Agent):
filter: Callable = lambda x: True,
formatter: Callable = lambda x: x,
limit: int = 10,
**where
**where,
):
"""
Get the character memory context for a given character
@@ -266,10 +316,9 @@ class MemoryAgent(Agent):
memory_context = []
for query in queries:
if not query:
continue
i = 0
for memory in await self.get(formatter(query), limit=limit, **where):
if memory in memory_context:
@@ -296,15 +345,13 @@ from .registry import register
@register(condition=lambda: chromadb is not None)
class ChromaDBMemoryAgent(MemoryAgent):
requires_llm_client = False
@property
def ready(self):
if self.embeddings == "openai" and not self.openai_api_key:
return False
if getattr(self, "db_client", None):
return True
return False
@@ -313,80 +360,84 @@ class ChromaDBMemoryAgent(MemoryAgent):
def status(self):
if self.ready:
return "active" if not getattr(self, "processing", False) else "busy"
if self.embeddings == "openai" and not self.openai_api_key:
return "error"
return "waiting"
@property
def agent_details(self):
if self.embeddings == "openai" and not self.openai_api_key:
return "No OpenAI API key set"
return f"ChromaDB: {self.embeddings}"
@property
def embeddings(self):
"""
Returns which embeddings to use
will read from TM_CHROMADB_EMBEDDINGS env variable and default to 'default' using
the default embeddings specified by chromadb.
other values are
- openai: use openai embeddings
- instructor: use instructor embeddings
for `openai`:
you will also need to provide an `OPENAI_API_KEY` env variable
for `instructor`:
you will also need to provide which instructor model to use with the `TM_INSTRUCTOR_MODEL` env variable, which defaults to hkunlp/instructor-xl
additionally you can provide the `TM_INSTRUCTOR_DEVICE` env variable to specify which device to use, which defaults to cpu
"""
embeddings = self.config.get("chromadb").get("embeddings")
assert embeddings in ["default", "openai", "instructor"], f"Unknown embeddings {embeddings}"
assert embeddings in [
"default",
"openai",
"instructor",
], f"Unknown embeddings {embeddings}"
return embeddings
@property
def USE_OPENAI(self):
return self.embeddings == "openai"
@property
def USE_INSTRUCTOR(self):
return self.embeddings == "instructor"
@property
def db_name(self):
return getattr(self, "collection_name", "<unnamed>")
@property
def openai_api_key(self):
return self.config.get("openai",{}).get("api_key")
return self.config.get("openai", {}).get("api_key")
def make_collection_name(self, scene):
if self.USE_OPENAI:
suffix = "-openai"
elif self.USE_INSTRUCTOR:
suffix = "-instructor"
model = self.config.get("chromadb").get("instructor_model", "hkunlp/instructor-xl")
model = self.config.get("chromadb").get(
"instructor_model", "hkunlp/instructor-xl"
)
if "xl" in model:
suffix += "-xl"
elif "large" in model:
suffix += "-large"
else:
suffix = ""
return f"{scene.memory_id}-tm{suffix}"
async def count(self):
@@ -399,9 +450,8 @@ class ChromaDBMemoryAgent(MemoryAgent):
await loop.run_in_executor(None, self._set_db)
def _set_db(self):
self._ready_to_add = False
if not getattr(self, "db_client", None):
log.info("chromadb agent", status="setting up db client to persistent db")
self.db_client = chromadb.PersistentClient(
@@ -409,49 +459,60 @@ class ChromaDBMemoryAgent(MemoryAgent):
)
openai_key = self.openai_api_key
self.collection_name = collection_name = self.make_collection_name(self.scene)
log.info("chromadb agent", status="setting up db", collection_name=collection_name)
log.info(
"chromadb agent", status="setting up db", collection_name=collection_name
)
if self.USE_OPENAI:
if not openai_key:
raise ValueError("You must provide an the openai ai key in the config if you want to use it for chromadb embeddings")
raise ValueError(
"You must provide an the openai ai key in the config if you want to use it for chromadb embeddings"
)
log.info(
"crhomadb", status="using openai", openai_key=openai_key[:5] + "..."
)
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
api_key = openai_key,
api_key=openai_key,
model_name="text-embedding-ada-002",
)
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=openai_ef
)
elif self.USE_INSTRUCTOR:
instructor_device = self.config.get("chromadb").get("instructor_device", "cpu")
instructor_model = self.config.get("chromadb").get("instructor_model", "hkunlp/instructor-xl")
log.info("chromadb", status="using instructor", model=instructor_model, device=instructor_device)
instructor_device = self.config.get("chromadb").get(
"instructor_device", "cpu"
)
instructor_model = self.config.get("chromadb").get(
"instructor_model", "hkunlp/instructor-xl"
)
log.info(
"chromadb",
status="using instructor",
model=instructor_model,
device=instructor_device,
)
# ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-mpnet-base-v2")
ef = embedding_functions.InstructorEmbeddingFunction(
model_name=instructor_model, device=instructor_device
)
log.info("chromadb", status="embedding function ready")
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=ef
)
log.info("chromadb", status="instructor db ready")
else:
log.info("chromadb", status="using default embeddings")
self.db = self.db_client.get_or_create_collection(collection_name)
self.scene._memory_never_persisted = self.db.count() == 0
log.info("chromadb agent", status="db ready")
self._ready_to_add = True
@@ -459,17 +520,21 @@ class ChromaDBMemoryAgent(MemoryAgent):
def clear_db(self):
if not self.db:
return
log.info("chromadb agent", status="clearing db", collection_name=self.collection_name)
log.info(
"chromadb agent", status="clearing db", collection_name=self.collection_name
)
self.db.delete(where={"source": "talemate"})
def drop_db(self):
if not self.db:
return
log.info("chromadb agent", status="dropping db", collection_name=self.collection_name)
log.info(
"chromadb agent", status="dropping db", collection_name=self.collection_name
)
try:
self.db_client.delete_collection(self.collection_name)
except ValueError as exc:
@@ -479,31 +544,43 @@ class ChromaDBMemoryAgent(MemoryAgent):
def close_db(self, scene):
if not self.db:
return
log.info("chromadb agent", status="closing db", collection_name=self.collection_name)
log.info(
"chromadb agent", status="closing db", collection_name=self.collection_name
)
if not scene.saved and not scene.saved_memory_session_id:
# scene was never saved so we can discard the memory
collection_name = self.make_collection_name(scene)
log.info("chromadb agent", status="discarding memory", collection_name=collection_name)
log.info(
"chromadb agent",
status="discarding memory",
collection_name=collection_name,
)
try:
self.db_client.delete_collection(collection_name)
except ValueError as exc:
log.error("chromadb agent", error="failed to delete collection", details=exc)
log.error(
"chromadb agent", error="failed to delete collection", details=exc
)
elif not scene.saved:
# scene was saved but memory was never persisted
# so we need to remove the memory from the db
self._remove_unsaved_memory()
self.db = None
def _add(self, text, character=None, uid=None, ts:str=None, **kwargs):
def _add(self, text, character=None, uid=None, ts: str = None, **kwargs):
metadatas = []
ids = []
scene = self.scene
if character:
meta = {"character": character.name, "source": "talemate", "session": scene.memory_session_id}
meta = {
"character": character.name,
"source": "talemate",
"session": scene.memory_session_id,
}
if ts:
meta["ts"] = ts
meta.update(kwargs)
@@ -513,7 +590,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
id = uid or f"{character.name}-{self.memory_tracker[character.name]}"
ids = [id]
else:
meta = {"character": "__narrator__", "source": "talemate", "session": scene.memory_session_id}
meta = {
"character": "__narrator__",
"source": "talemate",
"session": scene.memory_session_id,
}
if ts:
meta["ts"] = ts
meta.update(kwargs)
@@ -523,17 +604,16 @@ class ChromaDBMemoryAgent(MemoryAgent):
id = uid or f"__narrator__-{self.memory_tracker['__narrator__']}"
ids = [id]
#log.debug("chromadb agent add", text=text, meta=meta, id=id)
# log.debug("chromadb agent add", text=text, meta=meta, id=id)
self.db.upsert(documents=[text], metadatas=metadatas, ids=ids)
def _add_many(self, objects: list[dict]):
documents = []
metadatas = []
ids = []
scene = self.scene
if not objects:
return
@@ -552,52 +632,50 @@ class ChromaDBMemoryAgent(MemoryAgent):
ids.append(uid)
self.db.upsert(documents=documents, metadatas=metadatas, ids=ids)
def _delete(self, meta:dict):
def _delete(self, meta: dict):
if "ids" in meta:
log.debug("chromadb agent delete", ids=meta["ids"])
self.db.delete(ids=meta["ids"])
return
where = {"$and": [{k:v} for k,v in meta.items()]}
where = {"$and": [{k: v} for k, v in meta.items()]}
self.db.delete(where=where)
log.debug("chromadb agent delete", meta=meta, where=where)
def _get(self, text, character=None, limit:int=15, **kwargs):
def _get(self, text, character=None, limit: int = 15, **kwargs):
where = {}
# this doesn't work because chromadb currently doesn't match
# non existing fields with $ne (or so it seems)
# where.setdefault("$and", [{"pin_only": {"$ne": True}}])
where.setdefault("$and", [])
character_filtered = False
for k,v in kwargs.items():
for k, v in kwargs.items():
if k == "character":
character_filtered = True
where["$and"].append({k: v})
if character and not character_filtered:
where["$and"].append({"character": character.name})
if len(where["$and"]) == 1:
where = where["$and"][0]
elif not where["$and"]:
where = None
#log.debug("crhomadb agent get", text=text, where=where)
# log.debug("crhomadb agent get", text=text, where=where)
_results = self.db.query(query_texts=[text], where=where, n_results=limit)
#import json
#print(json.dumps(_results["ids"], indent=2))
#print(json.dumps(_results["distances"], indent=2))
# import json
# print(json.dumps(_results["ids"], indent=2))
# print(json.dumps(_results["distances"], indent=2))
results = []
max_distance = 1.5
if self.USE_INSTRUCTOR:
max_distance = 1
@@ -606,24 +684,24 @@ class ChromaDBMemoryAgent(MemoryAgent):
for i in range(len(_results["distances"][0])):
distance = _results["distances"][0][i]
doc = _results["documents"][0][i]
meta = _results["metadatas"][0][i]
ts = meta.get("ts")
# skip pin_only entries
if meta.get("pin_only", False):
continue
if distance < max_distance:
date_prefix = self.convert_ts_to_date_prefix(ts)
raw = doc
if date_prefix:
doc = f"{date_prefix}: {doc}"
doc = MemoryDocument(doc, meta, _results["ids"][0][i], raw)
results.append(doc)
else:
break
@@ -635,45 +713,55 @@ class ChromaDBMemoryAgent(MemoryAgent):
return results
def convert_ts_to_date_prefix(self, ts):
if not ts:
return None
try:
return util.iso8601_diff_to_human(ts, self.scene.ts)
except Exception as e:
log.error("chromadb agent", error="failed to get date prefix", details=e, ts=ts, scene_ts=self.scene.ts)
log.error(
"chromadb agent",
error="failed to get date prefix",
details=e,
ts=ts,
scene_ts=self.scene.ts,
)
return None
def _get_document(self, id) -> dict:
result = self.db.get(ids=[id] if isinstance(id, str) else id)
documents = {}
for idx, doc in enumerate(result["documents"]):
date_prefix = self.convert_ts_to_date_prefix(result["metadatas"][idx].get("ts"))
date_prefix = self.convert_ts_to_date_prefix(
result["metadatas"][idx].get("ts")
)
if date_prefix:
doc = f"{date_prefix}: {doc}"
documents[result["ids"][idx]] = MemoryDocument(doc, result["metadatas"][idx], result["ids"][idx], doc)
documents[result["ids"][idx]] = MemoryDocument(
doc, result["metadatas"][idx], result["ids"][idx], doc
)
return documents
@set_processing
async def remove_unsaved_memory(self):
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._remove_unsaved_memory)
def _remove_unsaved_memory(self):
scene = self.scene
if not scene.memory_session_id:
return
if scene.saved_memory_session_id == self.scene.memory_session_id:
return
log.info("chromadb agent", status="removing unsaved memory", session_id=scene.memory_session_id)
log.info(
"chromadb agent",
status="removing unsaved memory",
session_id=scene.memory_session_id,
)
self._delete({"session": scene.memory_session_id, "source": "talemate"})

View File

@@ -1,41 +1,44 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import dataclasses
import structlog
import random
import talemate.util as util
from talemate.emit import emit
import talemate.emit.async_signals
from talemate.prompts import Prompt
from talemate.agents.base import set_processing as _set_processing, Agent, AgentAction, AgentActionConfig, AgentEmission
from talemate.agents.world_state import TimePassageEmission
from talemate.scene_message import NarratorMessage
from talemate.events import GameLoopActorIterEvent
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import structlog
import talemate.client as client
import talemate.emit.async_signals
import talemate.util as util
from talemate.agents.base import Agent, AgentAction, AgentActionConfig, AgentEmission
from talemate.agents.base import set_processing as _set_processing
from talemate.agents.world_state import TimePassageEmission
from talemate.emit import emit
from talemate.events import GameLoopActorIterEvent
from talemate.prompts import Prompt
from talemate.scene_message import NarratorMessage
from .registry import register
if TYPE_CHECKING:
from talemate.tale_mate import Actor, Player, Character
from talemate.tale_mate import Actor, Character, Player
log = structlog.get_logger("talemate.agents.narrator")
@dataclasses.dataclass
class NarratorAgentEmission(AgentEmission):
generation: list[str] = dataclasses.field(default_factory=list)
talemate.emit.async_signals.register(
"agent.narrator.generated"
)
talemate.emit.async_signals.register("agent.narrator.generated")
def set_processing(fn):
"""
Custom decorator that emits the agent status as processing while the function
is running and then emits the result of the function as a NarratorAgentEmission
"""
@_set_processing
async def wrapper(self, *args, **kwargs):
response = await fn(self, *args, **kwargs)
@@ -45,68 +48,70 @@ def set_processing(fn):
)
await talemate.emit.async_signals.get("agent.narrator.generated").send(emission)
return emission.generation[0]
wrapper.__name__ = fn.__name__
return wrapper
@register()
class NarratorAgent(Agent):
"""
Handles narration of the story
"""
agent_type = "narrator"
verbose_name = "Narrator"
def __init__(
self,
client: client.TaleMateClient,
**kwargs,
):
self.client = client
# agent actions
self.actions = {
"generation_override": AgentAction(
enabled = True,
label = "Generation Override",
description = "Override generation parameters",
config = {
enabled=True,
label="Generation Override",
description="Override generation parameters",
config={
"instructions": AgentActionConfig(
type="text",
label="Instructions",
value="Never wax poetic.",
description="Extra instructions to give to the AI for narrative generation.",
),
}
},
),
"auto_break_repetition": AgentAction(
enabled = True,
label = "Auto Break Repetition",
description = "Will attempt to automatically break AI repetition.",
enabled=True,
label="Auto Break Repetition",
description="Will attempt to automatically break AI repetition.",
),
"narrate_time_passage": AgentAction(
enabled=True,
label="Narrate Time Passage",
enabled=True,
label="Narrate Time Passage",
description="Whenever you indicate passage of time, narrate right after",
config = {
config={
"ask_for_prompt": AgentActionConfig(
type="bool",
label="Guide time narration via prompt",
label="Guide time narration via prompt",
description="Ask the user for a prompt to generate the time passage narration",
value=True,
)
}
},
),
"narrate_dialogue": AgentAction(
enabled=False,
label="Narrate after Dialogue",
enabled=False,
label="Narrate after Dialogue",
description="Narrator will get a chance to narrate after every line of dialogue",
config = {
config={
"ai_dialog": AgentActionConfig(
type="number",
label="AI Dialogue",
label="AI Dialogue",
description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
value=0.0,
min=0.0,
@@ -115,7 +120,7 @@ class NarratorAgent(Agent):
),
"player_dialog": AgentActionConfig(
type="number",
label="Player Dialogue",
label="Player Dialogue",
description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
value=0.1,
min=0.0,
@@ -124,34 +129,32 @@ class NarratorAgent(Agent):
),
"generate_dialogue": AgentActionConfig(
type="bool",
label="Allow Dialogue in Narration",
label="Allow Dialogue in Narration",
description="Allow the narrator to generate dialogue in narration",
value=False,
),
}
},
),
}
@property
def extra_instructions(self):
if self.actions["generation_override"].enabled:
return self.actions["generation_override"].config["instructions"].value
return ""
def clean_result(self, result):
"""
Cleans the result of a narration
"""
result = result.strip().strip(":").strip()
if "#" in result:
result = result.split("#")[0]
character_names = [c.name for c in self.scene.get_characters()]
cleaned = []
for line in result.split("\n"):
for character_name in character_names:
@@ -160,71 +163,83 @@ class NarratorAgent(Agent):
cleaned.append(line)
result = "\n".join(cleaned)
#result = util.strip_partial_sentences(result)
# result = util.strip_partial_sentences(result)
return result
def connect(self, scene):
"""
Connect to signals
"""
super().connect(scene)
talemate.emit.async_signals.get("agent.world_state.time").connect(self.on_time_passage)
talemate.emit.async_signals.get("agent.world_state.time").connect(
self.on_time_passage
)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.on_dialog)
async def on_time_passage(self, event:TimePassageEmission):
async def on_time_passage(self, event: TimePassageEmission):
"""
Handles time passage narration, if enabled
"""
if not self.actions["narrate_time_passage"].enabled:
return
response = await self.narrate_time_passage(event.duration, event.human_duration, event.narrative)
narrator_message = NarratorMessage(response, source=f"narrate_time_passage:{event.duration};{event.narrative}")
response = await self.narrate_time_passage(
event.duration, event.human_duration, event.narrative
)
narrator_message = NarratorMessage(
response, source=f"narrate_time_passage:{event.duration};{event.narrative}"
)
emit("narrator", narrator_message)
self.scene.push_history(narrator_message)
async def on_dialog(self, event:GameLoopActorIterEvent):
async def on_dialog(self, event: GameLoopActorIterEvent):
"""
Handles dialogue narration, if enabled
"""
if not self.actions["narrate_dialogue"].enabled:
return
if event.game_loop.had_passive_narration:
log.debug("narrate on dialog", skip=True, had_passive_narration=event.game_loop.had_passive_narration)
log.debug(
"narrate on dialog",
skip=True,
had_passive_narration=event.game_loop.had_passive_narration,
)
return
narrate_on_ai_chance = self.actions["narrate_dialogue"].config["ai_dialog"].value
narrate_on_player_chance = self.actions["narrate_dialogue"].config["player_dialog"].value
narrate_on_ai_chance = (
self.actions["narrate_dialogue"].config["ai_dialog"].value
)
narrate_on_player_chance = (
self.actions["narrate_dialogue"].config["player_dialog"].value
)
narrate_on_ai = random.random() < narrate_on_ai_chance
narrate_on_player = random.random() < narrate_on_player_chance
log.debug(
"narrate on dialog",
narrate_on_ai=narrate_on_ai,
narrate_on_ai_chance=narrate_on_ai_chance,
"narrate on dialog",
narrate_on_ai=narrate_on_ai,
narrate_on_ai_chance=narrate_on_ai_chance,
narrate_on_player=narrate_on_player,
narrate_on_player_chance=narrate_on_player_chance,
)
if event.actor.character.is_player and not narrate_on_player:
return
if not event.actor.character.is_player and not narrate_on_ai:
return
response = await self.narrate_after_dialogue(event.actor.character)
narrator_message = NarratorMessage(response, source=f"narrate_dialogue:{event.actor.character.name}")
narrator_message = NarratorMessage(
response, source=f"narrate_dialogue:{event.actor.character.name}"
)
emit("narrator", narrator_message)
self.scene.push_history(narrator_message)
event.game_loop.had_passive_narration = True
@set_processing
@@ -237,22 +252,22 @@ class NarratorAgent(Agent):
"narrator.narrate-scene",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"extra_instructions": self.extra_instructions,
}
},
)
response = response.strip("*")
response = util.strip_partial_sentences(response)
response = f"*{response.strip('*')}*"
return response
@set_processing
async def progress_story(self, narrative_direction:str=None):
async def progress_story(self, narrative_direction: str = None):
"""
Narrate the scene
"""
@@ -260,18 +275,20 @@ class NarratorAgent(Agent):
scene = self.scene
pc = scene.get_player_character()
npcs = list(scene.get_npc_characters())
npc_names= ", ".join([npc.name for npc in npcs])
npc_names = ", ".join([npc.name for npc in npcs])
if narrative_direction is None:
narrative_direction = "Slightly move the current scene forward."
self.scene.log.info("narrative_direction", narrative_direction=narrative_direction)
self.scene.log.info(
"narrative_direction", narrative_direction=narrative_direction
)
response = await Prompt.request(
"narrator.narrate-progress",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"narrative_direction": narrative_direction,
@@ -279,7 +296,7 @@ class NarratorAgent(Agent):
"npcs": npcs,
"npc_names": npc_names,
"extra_instructions": self.extra_instructions,
}
},
)
self.scene.log.info("progress_story", response=response)
@@ -291,11 +308,13 @@ class NarratorAgent(Agent):
if response.count("*") % 2 != 0:
response = response.replace("*", "")
response = f"*{response}*"
return response
@set_processing
async def narrate_query(self, query:str, at_the_end:bool=False, as_narrative:bool=True):
async def narrate_query(
self, query: str, at_the_end: bool = False, as_narrative: bool = True
):
"""
Narrate a specific query
"""
@@ -303,21 +322,21 @@ class NarratorAgent(Agent):
"narrator.narrate-query",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"query": query,
"at_the_end": at_the_end,
"as_narrative": as_narrative,
"extra_instructions": self.extra_instructions,
}
},
)
log.info("narrate_query", response=response)
response = self.clean_result(response.strip())
log.info("narrate_query (after clean)", response=response)
if as_narrative:
response = f"*{response}*"
return response
@set_processing
@@ -330,12 +349,12 @@ class NarratorAgent(Agent):
"narrator.narrate-character",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"character": character,
"max_tokens": self.client.max_token_length,
"extra_instructions": self.extra_instructions,
}
},
)
response = self.clean_result(response.strip())
@@ -345,54 +364,55 @@ class NarratorAgent(Agent):
@set_processing
async def augment_context(self):
"""
Takes a context history generated via scene.context_history() and augments it with additional information
by asking and answering questions with help from the long term memory.
"""
memory = self.scene.get_helper("memory").agent
questions = await Prompt.request(
"narrator.context-questions",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"extra_instructions": self.extra_instructions,
}
},
)
self.scene.log.info("context_questions", questions=questions)
questions = [q for q in questions.split("\n") if q.strip()]
memory_context = await memory.multi_query(
questions, iterate=2, max_tokens=self.client.max_token_length - 1000
)
answers = await Prompt.request(
"narrator.context-answers",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"memory": memory_context,
"questions": questions,
"extra_instructions": self.extra_instructions,
}
},
)
self.scene.log.info("context_answers", answers=answers)
answers = [a for a in answers.split("\n") if a.strip()]
# return questions and answers
return list(zip(questions, answers))
@set_processing
async def narrate_time_passage(self, duration:str, time_passed:str, narrative:str):
async def narrate_time_passage(
self, duration: str, time_passed: str, narrative: str
):
"""
Narrate a specific character
"""
@@ -401,26 +421,25 @@ class NarratorAgent(Agent):
"narrator.narrate-time-passage",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"duration": duration,
"time_passed": time_passed,
"narrative": narrative,
"extra_instructions": self.extra_instructions,
}
},
)
log.info("narrate_time_passage", response=response)
response = self.clean_result(response.strip())
response = f"*{response}*"
return response
@set_processing
async def narrate_after_dialogue(self, character:Character):
async def narrate_after_dialogue(self, character: Character):
"""
Narrate after a line of dialogue
"""
@@ -429,22 +448,24 @@ class NarratorAgent(Agent):
"narrator.narrate-after-dialogue",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"character": character,
"last_line": str(self.scene.history[-1]),
"extra_instructions": self.extra_instructions,
}
},
)
log.info("narrate_after_dialogue", response=response)
response = self.clean_result(response.strip().strip("*"))
response = f"*{response}*"
allow_dialogue = self.actions["narrate_dialogue"].config["generate_dialogue"].value
allow_dialogue = (
self.actions["narrate_dialogue"].config["generate_dialogue"].value
)
if not allow_dialogue:
response = response.split('"')[0].strip()
response = response.replace("*", "")
@@ -452,9 +473,11 @@ class NarratorAgent(Agent):
response = f"*{response}*"
return response
@set_processing
async def narrate_character_entry(self, character:Character, direction:str=None):
async def narrate_character_entry(
self, character: Character, direction: str = None
):
"""
Narrate a character entering the scene
"""
@@ -463,22 +486,22 @@ class NarratorAgent(Agent):
"narrator.narrate-character-entry",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"character": character,
"direction": direction,
"extra_instructions": self.extra_instructions,
}
},
)
response = self.clean_result(response.strip().strip("*"))
response = f"*{response}*"
return response
@set_processing
async def narrate_character_exit(self, character:Character, direction:str=None):
async def narrate_character_exit(self, character: Character, direction: str = None):
"""
Narrate a character exiting the scene
"""
@@ -487,20 +510,20 @@ class NarratorAgent(Agent):
"narrator.narrate-character-exit",
self.client,
"narrate",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"character": character,
"direction": direction,
"extra_instructions": self.extra_instructions,
}
},
)
response = self.clean_result(response.strip().strip("*"))
response = f"*{response}*"
return response
async def action_to_narration(
self,
action_name: str,
@@ -509,25 +532,35 @@ class NarratorAgent(Agent):
):
# calls self[action_name] and returns the result as a NarratorMessage
# that is pushed to the history
fn = getattr(self, action_name)
narration = await fn(*args, **kwargs)
narrator_message = NarratorMessage(narration, source=f"{action_name}:{args[0] if args else ''}".rstrip(":"))
narrator_message = NarratorMessage(
narration, source=f"{action_name}:{args[0] if args else ''}".rstrip(":")
)
self.scene.push_history(narrator_message)
return narrator_message
# LLM client related methods. These are called during or after the client
def inject_prompt_paramters(self, prompt_param: dict, kind: str, agent_function_name: str):
log.debug("inject_prompt_paramters", prompt_param=prompt_param, kind=kind, agent_function_name=agent_function_name)
def inject_prompt_paramters(
self, prompt_param: dict, kind: str, agent_function_name: str
):
log.debug(
"inject_prompt_paramters",
prompt_param=prompt_param,
kind=kind,
agent_function_name=agent_function_name,
)
character_names = [f"\n{c.name}:" for c in self.scene.get_characters()]
if prompt_param.get("extra_stopping_strings") is None:
prompt_param["extra_stopping_strings"] = []
prompt_param["extra_stopping_strings"] += character_names
def allow_repetition_break(self, kind: str, agent_function_name: str, auto:bool=False):
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
if auto and not self.actions["auto_break_repetition"].enabled:
return False
return True

View File

@@ -1,26 +1,26 @@
from __future__ import annotations
import asyncio
import re
import time
import traceback
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import structlog
import talemate.data_objects as data_objects
import talemate.emit.async_signals
import talemate.util as util
from talemate.events import GameLoopEvent
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, TimePassageMessage
from talemate.events import GameLoopEvent
from .base import Agent, set_processing, AgentAction, AgentActionConfig
from .base import Agent, AgentAction, AgentActionConfig, set_processing
from .registry import register
import structlog
import time
import re
log = structlog.get_logger("talemate.agents.summarize")
@register()
class SummarizeAgent(Agent):
"""
@@ -36,7 +36,7 @@ class SummarizeAgent(Agent):
def __init__(self, client, **kwargs):
self.client = client
self.actions = {
"archive": AgentAction(
enabled=True,
@@ -67,30 +67,27 @@ class SummarizeAgent(Agent):
type="number",
label="Use preceeding summaries to strengthen context",
description="Number of entries",
note="Help the AI summarize by including the last few summaries as additional context. Some models may incorporate this context into the new summary directly, so if you find yourself with a bunch of similar history entries, try setting this to 0.",
note="Help the AI summarize by including the last few summaries as additional context. Some models may incorporate this context into the new summary directly, so if you find yourself with a bunch of similar history entries, try setting this to 0.",
value=3,
min=0,
max=10,
step=1,
),
}
},
)
}
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
async def on_game_loop(self, emission:GameLoopEvent):
async def on_game_loop(self, emission: GameLoopEvent):
"""
Called when a conversation is generated
"""
await self.build_archive(self.scene)
def clean_result(self, result):
if "#" in result:
result = result.split("#")[0]
@@ -104,10 +101,10 @@ class SummarizeAgent(Agent):
@set_processing
async def build_archive(self, scene):
end = None
if not self.actions["archive"].enabled:
return
if not scene.archived_history:
start = 0
recent_entry = None
@@ -118,14 +115,16 @@ class SummarizeAgent(Agent):
# meaning we are still at the beginning of the scene
start = 0
else:
start = recent_entry.get("end", 0)+1
start = recent_entry.get("end", 0) + 1
# if there is a recent entry we also collect the 3 most recentries
# as extra context
num_previous = self.actions["archive"].config["include_previous"].value
if recent_entry and num_previous > 0:
extra_context = "\n\n".join([entry["text"] for entry in scene.archived_history[-num_previous:]])
extra_context = "\n\n".join(
[entry["text"] for entry in scene.archived_history[-num_previous:]]
)
else:
extra_context = None
@@ -133,36 +132,42 @@ class SummarizeAgent(Agent):
dialogue_entries = []
ts = "PT0S"
time_passage_termination = False
token_threshold = self.actions["archive"].config["threshold"].value
log.debug("build_archive", start=start, recent_entry=recent_entry)
if recent_entry:
ts = recent_entry.get("ts", ts)
for i in range(start, len(scene.history)):
dialogue = scene.history[i]
#log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")
# log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")
if isinstance(dialogue, DirectorMessage):
if i == start:
start += 1
continue
if isinstance(dialogue, TimePassageMessage):
log.debug("build_archive", time_passage_message=dialogue)
if i == start:
ts = util.iso8601_add(ts, dialogue.ts)
log.debug("build_archive", time_passage_message=dialogue, start=start, i=i, ts=ts)
log.debug(
"build_archive",
time_passage_message=dialogue,
start=start,
i=i,
ts=ts,
)
start += 1
continue
log.debug("build_archive", time_passage_message_termination=dialogue)
time_passage_termination = True
end = i - 1
break
tokens += util.count_tokens(dialogue)
dialogue_entries.append(dialogue)
if tokens > token_threshold: #
@@ -172,39 +177,44 @@ class SummarizeAgent(Agent):
if end is None:
# nothing to archive yet
return
log.debug("build_archive", start=start, end=end, ts=ts, time_passage_termination=time_passage_termination)
log.debug(
"build_archive",
start=start,
end=end,
ts=ts,
time_passage_termination=time_passage_termination,
)
# in order to summarize coherently, we need to determine if there is a favorable
# cutoff point (e.g., the scene naturally ends or shifts meaninfully in the middle
# of the dialogue)
#
# One way to do this is to check if the last line is a TimePassageMessage, which
# indicates a scene change or a significant pause.
#
# indicates a scene change or a significant pause.
#
# If not, we can ask the AI to find a good point of
# termination.
if not time_passage_termination:
# No TimePassageMessage, so we need to ask the AI to find a good point of termination
terminating_line = await self.analyze_dialoge(dialogue_entries)
if terminating_line:
adjusted_dialogue = []
for line in dialogue_entries:
for line in dialogue_entries:
if str(line) in terminating_line:
break
adjusted_dialogue.append(line)
dialogue_entries = adjusted_dialogue
end = start + len(dialogue_entries)-1
end = start + len(dialogue_entries) - 1
if dialogue_entries:
summarized = await self.summarize(
"\n".join(map(str, dialogue_entries)), extra_context=extra_context
)
else:
# AI has likely identified the first line as a scene change, so we can't summarize
# just use the first line
@@ -218,15 +228,20 @@ class SummarizeAgent(Agent):
@set_processing
async def analyze_dialoge(self, dialogue):
response = await Prompt.request("summarizer.analyze-dialogue", self.client, "analyze_freeform", vars={
"dialogue": "\n".join(map(str, dialogue)),
"scene": self.scene,
"max_tokens": self.client.max_token_length,
})
response = await Prompt.request(
"summarizer.analyze-dialogue",
self.client,
"analyze_freeform",
vars={
"dialogue": "\n".join(map(str, dialogue)),
"scene": self.scene,
"max_tokens": self.client.max_token_length,
},
)
response = self.clean_result(response)
return response
@set_processing
async def summarize(
self,
@@ -239,33 +254,40 @@ class SummarizeAgent(Agent):
Summarize the given text
"""
response = await Prompt.request("summarizer.summarize-dialogue", self.client, "summarize", vars={
"dialogue": text,
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"summarization_method": self.actions["archive"].config["method"].value if method is None else method,
"extra_context": extra_context or "",
"extra_instructions": extra_instructions or "",
})
self.scene.log.info("summarize", dialogue_length=len(text), summarized_length=len(response))
response = await Prompt.request(
"summarizer.summarize-dialogue",
self.client,
"summarize",
vars={
"dialogue": text,
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"summarization_method": self.actions["archive"].config["method"].value
if method is None
else method,
"extra_context": extra_context or "",
"extra_instructions": extra_instructions or "",
},
)
self.scene.log.info(
"summarize", dialogue_length=len(text), summarized_length=len(response)
)
return self.clean_result(response)
async def build_stepped_archive_for_level(self, level:int):
async def build_stepped_archive_for_level(self, level: int):
"""
WIP - not yet used
This will iterate over existing archived_history entries
and stepped_archived_history entries and summarize based on time duration
indicated between the entries.
The lowest level of summarization (based on token threshold and any time passage)
happens in build_archive. This method is for summarizing furhter levels based on
long time pasages.
Level 0: small timestap summarize (summarizes all token summarizations when time advances +1 day)
Level 1: medium timestap summarize (summarizes all small timestep summarizations when time advances +1 week)
Level 2: large timestap summarize (summarizes all medium timestep summarizations when time advances +1 month)
@@ -273,7 +295,7 @@ class SummarizeAgent(Agent):
Level 4: massive timestap summarize (summarizes all huge timestep summarizations when time advances +10 years)
Level 5: epic timestap summarize (summarizes all massive timestep summarizations when time advances +100 years)
and so on (increasing by a factor of 10 each time)
```
@dataclass
class ArchiveEntry:
@@ -282,35 +304,34 @@ class SummarizeAgent(Agent):
end: int = None
ts: str = None
```
Like token summarization this will use ArchiveEntry and start and end will refer to the entries in the
lower level of summarization.
Ts is the iso8601 timestamp of the start of the summarized period.
"""
# select the list to use for the entries
if level == 0:
entries = self.scene.archived_history
else:
entries = self.scene.stepped_archived_history[level-1]
entries = self.scene.stepped_archived_history[level - 1]
# select the list to summarize new entries to
target = self.scene.stepped_archived_history[level]
if not target:
raise ValueError(f"Invalid level {level}")
# determine the start and end of the period to summarize
if not entries:
return
# determine the time threshold for this level
# first calculate all possible thresholds in iso8601 format, starting with 1 day
thresholds = [
"P1D",
@@ -318,61 +339,65 @@ class SummarizeAgent(Agent):
"P1M",
"P1Y",
]
# TODO: auto extend?
time_threshold_in_seconds = util.iso8601_to_seconds(thresholds[level])
if not time_threshold_in_seconds:
raise ValueError(f"Invalid level {level}")
# determine the most recent summarized entry time, and then find entries
# that are newer than that in the lower list
ts = target[-1].ts if target else entries[0].ts
# determine the most recent entry at the lower level, if its not newer or
# the difference is less than the threshold, then we don't need to summarize
recent_entry = entries[-1]
if util.iso8601_diff(recent_entry.ts, ts) < time_threshold_in_seconds:
return
log.debug("build_stepped_archive", level=level, ts=ts)
# if target is empty, start is 0
# otherwise start is the end of the last entry
start = 0 if not target else target[-1].end
# collect entries starting at start until the combined time duration
# exceeds the threshold
entries_to_summarize = []
for entry in entries[start:]:
entries_to_summarize.append(entry)
if util.iso8601_diff(entry.ts, ts) > time_threshold_in_seconds:
break
# summarize the entries
# we also collect N entries of previous summaries to use as context
num_previous = self.actions["archive"].config["include_previous"].value
if num_previous > 0:
extra_context = "\n\n".join([entry["text"] for entry in target[-num_previous:]])
extra_context = "\n\n".join(
[entry["text"] for entry in target[-num_previous:]]
)
else:
extra_context = None
summarized = await self.summarize(
"\n".join(map(str, entries_to_summarize)), extra_context=extra_context
)
# push summarized entry to target
ts = entries_to_summarize[-1].ts
target.append(data_objects.ArchiveEntry(summarized, start, len(entries_to_summarize)-1, ts=ts))
target.append(
data_objects.ArchiveEntry(
summarized, start, len(entries_to_summarize) - 1, ts=ts
)
)

View File

@@ -1,16 +1,19 @@
from __future__ import annotations
from typing import Union
import asyncio
import httpx
import base64
import functools
import io
import os
import pydantic
import nltk
import tempfile
import base64
import time
import uuid
import functools
from typing import Union
import httpx
import nltk
import pydantic
import structlog
from nltk.tokenize import sent_tokenize
import talemate.config as config
@@ -21,91 +24,84 @@ from talemate.emit.signals import handlers
from talemate.events import GameLoopNewMessageEvent
from talemate.scene_message import CharacterMessage, NarratorMessage
from .base import Agent, set_processing, AgentAction, AgentActionConfig
from .base import Agent, AgentAction, AgentActionConfig, set_processing
from .registry import register
import structlog
import time
try:
from TTS.api import TTS
except ImportError:
TTS = None
log = structlog.get_logger("talemate.agents.tts")#
log = structlog.get_logger("talemate.agents.tts") #
if not TTS:
# TTS installation is massive and requires a lot of dependencies
# so we don't want to require it unless the user wants to use it
log.info("TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api")
log.info(
"TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api"
)
def parse_chunks(text):
text = text.replace("...", "__ellipsis__")
chunks = sent_tokenize(text)
cleaned_chunks = []
for chunk in chunks:
chunk = chunk.replace("*","")
chunk = chunk.replace("*", "")
if not chunk:
continue
cleaned_chunks.append(chunk)
for i, chunk in enumerate(cleaned_chunks):
chunk = chunk.replace("__ellipsis__", "...")
cleaned_chunks[i] = chunk
return cleaned_chunks
def clean_quotes(chunk:str):
def clean_quotes(chunk: str):
# if there is an uneven number of quotes, remove the last one if its
# at the end of the chunk. If its in the middle, add a quote to the end
if chunk.count('"') % 2 == 1:
if chunk.endswith('"'):
chunk = chunk[:-1]
else:
chunk += '"'
return chunk
def rejoin_chunks(chunks:list[str], chunk_size:int=250):
return chunk
def rejoin_chunks(chunks: list[str], chunk_size: int = 250):
"""
Will combine chunks split by punctuation into a single chunk until
max chunk size is reached
"""
joined_chunks = []
current_chunk = ""
for chunk in chunks:
if len(current_chunk) + len(chunk) > chunk_size:
joined_chunks.append(clean_quotes(current_chunk))
current_chunk = ""
current_chunk += chunk
if current_chunk:
joined_chunks.append(clean_quotes(current_chunk))
return joined_chunks
class Voice(pydantic.BaseModel):
value:str
label:str
value: str
label: str
class VoiceLibrary(pydantic.BaseModel):
api: str
voices: list[Voice] = pydantic.Field(default_factory=list)
last_synced: float = None
@@ -113,31 +109,30 @@ class VoiceLibrary(pydantic.BaseModel):
@register()
class TTSAgent(Agent):
"""
Text to speech agent
"""
agent_type = "tts"
verbose_name = "Voice"
requires_llm_client = False
@classmethod
def config_options(cls, agent=None):
config_options = super().config_options(agent=agent)
if agent:
config_options["actions"]["_config"]["config"]["voice_id"]["choices"] = [
voice.model_dump() for voice in agent.list_voices_sync()
]
return config_options
def __init__(self, **kwargs):
self.is_enabled = False
nltk.download("punkt", quiet=True)
self.voices = {
"elevenlabs": VoiceLibrary(api="elevenlabs"),
"coqui": VoiceLibrary(api="coqui"),
@@ -147,8 +142,8 @@ class TTSAgent(Agent):
self.playback_done_event = asyncio.Event()
self.actions = {
"_config": AgentAction(
enabled=True,
label="Configure",
enabled=True,
label="Configure",
description="TTS agent configuration",
config={
"api": AgentActionConfig(
@@ -169,7 +164,7 @@ class TTSAgent(Agent):
value="default",
label="Narrator Voice",
description="Voice ID/Name to use for TTS",
choices=[]
choices=[],
),
"generate_for_player": AgentActionConfig(
type="bool",
@@ -194,55 +189,54 @@ class TTSAgent(Agent):
value=False,
label="Split generation",
description="Generate audio chunks for each sentence - will be much more responsive but may loose context to inform inflection",
)
}
),
},
),
}
self.actions["_config"].model_dump()
handlers["config_saved"].connect(self.on_config_saved)
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return False
@property
def not_ready_reason(self) -> str:
"""
Returns a string explaining why the agent is not ready
"""
if self.ready:
return ""
if self.api == "tts":
if not TTS:
return "TTS not installed"
elif self.requires_token and not self.token:
return "No API token"
elif not self.default_voice_id:
return "No voice selected"
@property
def agent_details(self):
suffix = ""
if not self.ready:
suffix = f" - {self.not_ready_reason}"
else:
suffix = f" - {self.voice_id_to_label(self.default_voice_id)}"
api = self.api
choices = self.actions["_config"].config["api"].choices
api_label = api
@@ -250,34 +244,33 @@ class TTSAgent(Agent):
if choice["value"] == api:
api_label = choice["label"]
break
return f"{api_label}{suffix}"
@property
def api(self):
return self.actions["_config"].config["api"].value
@property
def token(self):
api = self.api
return self.config.get(api,{}).get("api_key")
return self.config.get(api, {}).get("api_key")
@property
def default_voice_id(self):
return self.actions["_config"].config["voice_id"].value
@property
def requires_token(self):
return self.api != "tts"
@property
def ready(self):
if self.api == "tts":
if not TTS:
return False
return True
return (not self.requires_token or self.token) and self.default_voice_id
@property
@@ -299,106 +292,118 @@ class TTSAgent(Agent):
return 1024
elif self.api == "coqui":
return 250
return 250
def apply_config(self, *args, **kwargs):
try:
api = kwargs["actions"]["_config"]["config"]["api"]["value"]
except KeyError:
api = self.api
api_changed = api != self.api
log.debug("apply_config", api=api, api_changed=api != self.api, current_api=self.api)
log.debug(
"apply_config", api=api, api_changed=api != self.api, current_api=self.api
)
super().apply_config(*args, **kwargs)
if api_changed:
try:
self.actions["_config"].config["voice_id"].value = self.voices[api].voices[0].value
self.actions["_config"].config["voice_id"].value = (
self.voices[api].voices[0].value
)
except IndexError:
self.actions["_config"].config["voice_id"].value = ""
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop_new_message").connect(self.on_game_loop_new_message)
talemate.emit.async_signals.get("game_loop_new_message").connect(
self.on_game_loop_new_message
)
def on_config_saved(self, event):
config = event.data
self.config = config
instance.emit_agent_status(self.__class__, self)
async def on_game_loop_new_message(self, emission:GameLoopNewMessageEvent):
async def on_game_loop_new_message(self, emission: GameLoopNewMessageEvent):
"""
Called when a conversation is generated
"""
if not self.enabled or not self.ready:
return
if not isinstance(emission.message, (CharacterMessage, NarratorMessage)):
return
if isinstance(emission.message, NarratorMessage) and not self.actions["_config"].config["generate_for_narration"].value:
if (
isinstance(emission.message, NarratorMessage)
and not self.actions["_config"].config["generate_for_narration"].value
):
return
if isinstance(emission.message, CharacterMessage):
if emission.message.source == "player" and not self.actions["_config"].config["generate_for_player"].value:
if (
emission.message.source == "player"
and not self.actions["_config"].config["generate_for_player"].value
):
return
elif emission.message.source == "ai" and not self.actions["_config"].config["generate_for_npc"].value:
elif (
emission.message.source == "ai"
and not self.actions["_config"].config["generate_for_npc"].value
):
return
if isinstance(emission.message, CharacterMessage):
character_prefix = emission.message.split(":", 1)[0]
else:
character_prefix = ""
log.info("reactive tts", message=emission.message, character_prefix=character_prefix)
await self.generate(str(emission.message).replace(character_prefix+": ", ""))
log.info(
"reactive tts", message=emission.message, character_prefix=character_prefix
)
def voice(self, voice_id:str) -> Union[Voice, None]:
await self.generate(str(emission.message).replace(character_prefix + ": ", ""))
def voice(self, voice_id: str) -> Union[Voice, None]:
for voice in self.voices[self.api].voices:
if voice.value == voice_id:
return voice
return None
def voice_id_to_label(self, voice_id:str):
def voice_id_to_label(self, voice_id: str):
for voice in self.voices[self.api].voices:
if voice.value == voice_id:
return voice.label
return None
def list_voices_sync(self):
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.list_voices())
async def list_voices(self):
if self.requires_token and not self.token:
return []
library = self.voices[self.api]
# TODO: allow re-syncing voices
if library.last_synced:
return library.voices
list_fn = getattr(self, f"_list_voices_{self.api}")
log.info("Listing voices", api=self.api)
library.voices = await list_fn()
library.last_synced = time.time()
# if the current voice cannot be found, reset it
if not self.voice(self.default_voice_id):
self.actions["_config"].config["voice_id"].value = ""
# set loading to false
return library.voices
@@ -407,11 +412,10 @@ class TTSAgent(Agent):
if not self.enabled or not self.ready or not text:
return
self.playback_done_event.set()
generate_fn = getattr(self, f"_generate_{self.api}")
if self.actions["_config"].config["generate_chunks"].value:
chunks = parse_chunks(text)
chunks = rejoin_chunks(chunks)
@@ -427,59 +431,71 @@ class TTSAgent(Agent):
async def generate_chunks(self, generate_fn, chunks):
for chunk in chunks:
chunk = chunk.replace("*","").strip()
chunk = chunk.replace("*", "").strip()
log.info("Generating audio", api=self.api, chunk=chunk)
audio_data = await generate_fn(chunk)
self.play_audio(audio_data)
def play_audio(self, audio_data):
# play audio through the python audio player
#play(audio_data)
emit("audio_queue", data={"audio_data": base64.b64encode(audio_data).decode("utf-8")})
# play(audio_data)
emit(
"audio_queue",
data={"audio_data": base64.b64encode(audio_data).decode("utf-8")},
)
self.playback_done_event.set() # Signal that playback is finished
# LOCAL
async def _generate_tts(self, text: str) -> Union[bytes, None]:
if not TTS:
return
tts_config = self.config.get("tts",{})
tts_config = self.config.get("tts", {})
model = tts_config.get("model")
device = tts_config.get("device", "cpu")
log.debug("tts local", model=model, device=device)
if not hasattr(self, "tts_instance"):
self.tts_instance = TTS(model).to(device)
tts = self.tts_instance
loop = asyncio.get_event_loop()
voice = self.voice(self.default_voice_id)
with tempfile.TemporaryDirectory() as temp_dir:
file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")
await loop.run_in_executor(None, functools.partial(tts.tts_to_file, text=text, speaker_wav=voice.value, language="en", file_path=file_path))
#tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)
await loop.run_in_executor(
None,
functools.partial(
tts.tts_to_file,
text=text,
speaker_wav=voice.value,
language="en",
file_path=file_path,
),
)
# tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)
with open(file_path, "rb") as f:
return f.read()
async def _list_voices_tts(self) -> dict[str, str]:
return [Voice(**voice) for voice in self.config.get("tts",{}).get("voices",[])]
return [
Voice(**voice) for voice in self.config.get("tts", {}).get("voices", [])
]
# ELEVENLABS
async def _generate_elevenlabs(self, text: str, chunk_size: int = 1024) -> Union[bytes, None]:
async def _generate_elevenlabs(
self, text: str, chunk_size: int = 1024
) -> Union[bytes, None]:
api_key = self.token
if not api_key:
return
@@ -493,11 +509,8 @@ class TTSAgent(Agent):
}
data = {
"text": text,
"model_id": self.config.get("elevenlabs",{}).get("model"),
"voice_settings": {
"stability": 0.5,
"similarity_boost": 0.5
}
"model_id": self.config.get("elevenlabs", {}).get("model"),
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
}
response = await client.post(url, json=data, headers=headers, timeout=300)
@@ -514,27 +527,33 @@ class TTSAgent(Agent):
log.error(f"Error generating audio: {response.text}")
async def _list_voices_elevenlabs(self) -> dict[str, str]:
url_voices = "https://api.elevenlabs.io/v1/voices"
voices = []
async with httpx.AsyncClient() as client:
headers = {
"Accept": "application/json",
"xi-api-key": self.token,
}
response = await client.get(url_voices, headers=headers, params={"per_page":1000})
response = await client.get(
url_voices, headers=headers, params={"per_page": 1000}
)
speakers = response.json()["voices"]
voices.extend([Voice(value=speaker["voice_id"], label=speaker["name"]) for speaker in speakers])
voices.extend(
[
Voice(value=speaker["voice_id"], label=speaker["name"])
for speaker in speakers
]
)
# sort by name
voices.sort(key=lambda x: x.label)
return voices
return voices
# COQUI STUDIO
async def _generate_coqui(self, text: str) -> Union[bytes, None]:
api_key = self.token
if not api_key:
@@ -545,12 +564,12 @@ class TTSAgent(Agent):
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
"Authorization": f"Bearer {api_key}",
}
data = {
"voice_id": self.default_voice_id,
"text": text,
"language": "en" # Assuming English language for simplicity; this could be parameterized
"language": "en", # Assuming English language for simplicity; this could be parameterized
}
# Make the POST request to Coqui API
@@ -558,7 +577,7 @@ class TTSAgent(Agent):
if response.status_code in [200, 201]:
# Parse the JSON response to get the audio URL
response_data = response.json()
audio_url = response_data.get('audio_url')
audio_url = response_data.get("audio_url")
if audio_url:
# Make a GET request to download the audio file
audio_response = await client.get(audio_url)
@@ -572,7 +591,7 @@ class TTSAgent(Agent):
log.error("No audio URL in response")
else:
log.error(f"Error generating audio: {response.text}")
async def _cleanup_coqui(self, sample_id: str):
api_key = self.token
if not api_key or not sample_id:
@@ -580,9 +599,7 @@ class TTSAgent(Agent):
async with httpx.AsyncClient() as client:
url = f"https://app.coqui.ai/api/v2/samples/xtts/{sample_id}"
headers = {
"Authorization": f"Bearer {api_key}"
}
headers = {"Authorization": f"Bearer {api_key}"}
# Make the DELETE request to Coqui API
response = await client.delete(url, headers=headers)
@@ -590,28 +607,41 @@ class TTSAgent(Agent):
if response.status_code == 204:
log.info(f"Successfully deleted sample with ID: {sample_id}")
else:
log.error(f"Error deleting sample with ID: {sample_id}: {response.text}")
log.error(
f"Error deleting sample with ID: {sample_id}: {response.text}"
)
async def _list_voices_coqui(self) -> dict[str, str]:
url_speakers = "https://app.coqui.ai/api/v2/speakers"
url_custom_voices = "https://app.coqui.ai/api/v2/voices"
voices = []
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"Bearer {self.token}"
}
response = await client.get(url_speakers, headers=headers, params={"per_page":1000})
headers = {"Authorization": f"Bearer {self.token}"}
response = await client.get(
url_speakers, headers=headers, params={"per_page": 1000}
)
speakers = response.json()["result"]
voices.extend([Voice(value=speaker["id"], label=speaker["name"]) for speaker in speakers])
response = await client.get(url_custom_voices, headers=headers, params={"per_page":1000})
voices.extend(
[
Voice(value=speaker["id"], label=speaker["name"])
for speaker in speakers
]
)
response = await client.get(
url_custom_voices, headers=headers, params={"per_page": 1000}
)
custom_voices = response.json()["result"]
voices.extend([Voice(value=voice["id"], label=voice["name"]) for voice in custom_voices])
voices.extend(
[
Voice(value=voice["id"], label=voice["name"])
for voice in custom_voices
]
)
# sort by name
voices.sort(key=lambda x: x.label)
return voices
return voices

View File

@@ -1,46 +1,54 @@
from __future__ import annotations
import dataclasses
import json
import time
import uuid
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import isodate
import structlog
import talemate.emit.async_signals
import talemate.util as util
from talemate.world_state import InsertionMode
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, TimePassageMessage, ReinforcementMessage
from talemate.emit import emit
from talemate.events import GameLoopEvent
from talemate.instance import get_agent
from talemate.prompts import Prompt
from talemate.scene_message import (
DirectorMessage,
ReinforcementMessage,
TimePassageMessage,
)
from talemate.world_state import InsertionMode
from .base import Agent, set_processing, AgentAction, AgentActionConfig, AgentEmission
from .base import Agent, AgentAction, AgentActionConfig, AgentEmission, set_processing
from .registry import register
import structlog
import isodate
import time
log = structlog.get_logger("talemate.agents.world_state")
talemate.emit.async_signals.register("agent.world_state.time")
@dataclasses.dataclass
class WorldStateAgentEmission(AgentEmission):
"""
Emission class for world state agent
"""
pass
@dataclasses.dataclass
class TimePassageEmission(WorldStateAgentEmission):
"""
Emission class for time passage
"""
duration: str
narrative: str
human_duration: str = None
@register()
class WorldStateAgent(Agent):
@@ -55,26 +63,57 @@ class WorldStateAgent(Agent):
self.client = client
self.is_enabled = True
self.actions = {
"update_world_state": AgentAction(enabled=True, label="Update world state", description="Will attempt to update the world state based on the current scene. Runs automatically every N turns.", config={
"turns": AgentActionConfig(type="number", label="Turns", description="Number of turns to wait before updating the world state.", value=5, min=1, max=100, step=1)
}),
"update_reinforcements": AgentAction(enabled=True, label="Update state reinforcements", description="Will attempt to update any due state reinforcements.", config={}),
"check_pin_conditions": AgentAction(enabled=True, label="Update conditional context pins", description="Will evaluate context pins conditions and toggle those pins accordingly. Runs automatically every N turns.", config={
"turns": AgentActionConfig(type="number", label="Turns", description="Number of turns to wait before checking conditions.", value=2, min=1, max=100, step=1)
}),
"update_world_state": AgentAction(
enabled=True,
label="Update world state",
description="Will attempt to update the world state based on the current scene. Runs automatically every N turns.",
config={
"turns": AgentActionConfig(
type="number",
label="Turns",
description="Number of turns to wait before updating the world state.",
value=5,
min=1,
max=100,
step=1,
)
},
),
"update_reinforcements": AgentAction(
enabled=True,
label="Update state reinforcements",
description="Will attempt to update any due state reinforcements.",
config={},
),
"check_pin_conditions": AgentAction(
enabled=True,
label="Update conditional context pins",
description="Will evaluate context pins conditions and toggle those pins accordingly. Runs automatically every N turns.",
config={
"turns": AgentActionConfig(
type="number",
label="Turns",
description="Number of turns to wait before checking conditions.",
value=2,
min=1,
max=100,
step=1,
)
},
),
}
self.next_update = 0
self.next_pin_check = 0
@property
def enabled(self):
return self.is_enabled
@property
def has_toggle(self):
return True
@property
def experimental(self):
return True
@@ -83,110 +122,121 @@ class WorldStateAgent(Agent):
super().connect(scene)
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
async def advance_time(self, duration:str, narrative:str=None):
async def advance_time(self, duration: str, narrative: str = None):
"""
Emit a time passage message
"""
isodate.parse_duration(duration)
human_duration = util.iso8601_duration_to_human(duration, suffix=" later")
message = TimePassageMessage(ts=duration, message=human_duration)
log.debug("world_state.advance_time", message=message)
self.scene.push_history(message)
self.scene.emit_status()
emit("time", message)
await talemate.emit.async_signals.get("agent.world_state.time").send(
TimePassageEmission(agent=self, duration=duration, narrative=narrative, human_duration=human_duration)
)
async def on_game_loop(self, emission:GameLoopEvent):
emit("time", message)
await talemate.emit.async_signals.get("agent.world_state.time").send(
TimePassageEmission(
agent=self,
duration=duration,
narrative=narrative,
human_duration=human_duration,
)
)
async def on_game_loop(self, emission: GameLoopEvent):
"""
Called when a conversation is generated
"""
if not self.enabled:
return
await self.update_world_state()
await self.auto_update_reinforcments()
await self.auto_check_pin_conditions()
async def auto_update_reinforcments(self):
if not self.enabled:
return
if not self.actions["update_reinforcements"].enabled:
return
await self.update_reinforcements()
async def auto_check_pin_conditions(self):
if not self.enabled:
return
if not self.actions["check_pin_conditions"].enabled:
return
if self.next_pin_check % self.actions["check_pin_conditions"].config["turns"].value != 0 or self.next_pin_check == 0:
if (
self.next_pin_check
% self.actions["check_pin_conditions"].config["turns"].value
!= 0
or self.next_pin_check == 0
):
self.next_pin_check += 1
return
self.next_pin_check = 0
await self.check_pin_conditions()
async def update_world_state(self):
if not self.enabled:
return
if not self.actions["update_world_state"].enabled:
return
log.debug("update_world_state", next_update=self.next_update, turns=self.actions["update_world_state"].config["turns"].value)
log.debug(
"update_world_state",
next_update=self.next_update,
turns=self.actions["update_world_state"].config["turns"].value,
)
scene = self.scene
if self.next_update % self.actions["update_world_state"].config["turns"].value != 0 or self.next_update == 0:
if (
self.next_update % self.actions["update_world_state"].config["turns"].value
!= 0
or self.next_update == 0
):
self.next_update += 1
return
self.next_update = 0
await scene.world_state.request_update()
@set_processing
async def request_world_state(self):
t1 = time.time()
_, world_state = await Prompt.request(
"world_state.request-world-state-v2",
self.client,
"analyze_long",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"object_type": "character",
"object_type_plural": "characters",
}
},
)
self.scene.log.debug("request_world_state", response=world_state, time=time.time() - t1)
self.scene.log.debug(
"request_world_state", response=world_state, time=time.time() - t1
)
return world_state
@set_processing
async def request_world_state_inline(self):
"""
EXPERIMENTAL, Overall the one shot request seems about as coherent as the inline request, but the inline request is is about twice as slow and would need to run on every dialogue line.
"""
@@ -199,14 +249,18 @@ class WorldStateAgent(Agent):
"world_state.request-world-state-inline-items",
self.client,
"analyze_long",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
}
},
)
self.scene.log.debug("request_world_state_inline", marked_items=marked_items_response, time=time.time() - t1)
self.scene.log.debug(
"request_world_state_inline",
marked_items=marked_items_response,
time=time.time() - t1,
)
return marked_items_response
@set_processing
@@ -214,99 +268,107 @@ class WorldStateAgent(Agent):
self,
text: str,
):
response = await Prompt.request(
"world_state.analyze-time-passage",
self.client,
"analyze_freeform_short",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
}
},
)
duration = response.split("\n")[0].split(" ")[0].strip()
if not duration.startswith("P"):
duration = "P"+duration
duration = "P" + duration
return duration
@set_processing
async def analyze_text_and_extract_context(
self,
text: str,
goal: str,
):
response = await Prompt.request(
"world_state.analyze-text-and-extract-context",
self.client,
"analyze_freeform",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
"goal": goal,
}
},
)
log.debug("analyze_text_and_extract_context", goal=goal, text=text, response=response)
log.debug(
"analyze_text_and_extract_context", goal=goal, text=text, response=response
)
return response
@set_processing
async def analyze_text_and_extract_context_via_queries(
self,
text: str,
goal: str,
) -> list[str]:
response = await Prompt.request(
"world_state.analyze-text-and-generate-rag-queries",
self.client,
"analyze_freeform",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
"goal": goal,
}
},
)
queries = response.split("\n")
memory_agent = get_agent("memory")
context = await memory_agent.multi_query(queries, iterate=3)
log.debug("analyze_text_and_extract_context_via_queries", goal=goal, text=text, queries=queries, context=context)
log.debug(
"analyze_text_and_extract_context_via_queries",
goal=goal,
text=text,
queries=queries,
context=context,
)
return context
@set_processing
async def analyze_and_follow_instruction(
self,
text: str,
instruction: str,
):
response = await Prompt.request(
"world_state.analyze-text-and-follow-instruction",
self.client,
"analyze_freeform",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
"instruction": instruction,
}
},
)
log.debug("analyze_and_follow_instruction", instruction=instruction, text=text, response=response)
log.debug(
"analyze_and_follow_instruction",
instruction=instruction,
text=text,
response=response,
)
return response
@set_processing
@@ -315,50 +377,52 @@ class WorldStateAgent(Agent):
text: str,
query: str,
):
response = await Prompt.request(
"world_state.analyze-text-and-answer-question",
self.client,
"analyze_freeform",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
"query": query,
}
},
)
log.debug("analyze_text_and_answer_question", query=query, text=text, response=response)
log.debug(
"analyze_text_and_answer_question",
query=query,
text=text,
response=response,
)
return response
@set_processing
async def identify_characters(
self,
text: str = None,
):
"""
Attempts to identify characters in the given text.
"""
_, data = await Prompt.request(
"world_state.identify-characters",
self.client,
"analyze",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
}
},
)
log.debug("identify_characters", text=text, data=data)
return data
def _parse_character_sheet(self, response):
data = {}
for line in response.split("\n"):
if not line.strip():
@@ -367,128 +431,131 @@ class WorldStateAgent(Agent):
break
name, value = line.split(":", 1)
data[name.strip()] = value.strip()
return data
@set_processing
async def extract_character_sheet(
self,
name:str,
text:str = None,
name: str,
text: str = None,
):
"""
Attempts to extract a character sheet from the given text.
"""
response = await Prompt.request(
"world_state.extract-character-sheet",
self.client,
"create",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"text": text,
"name": name,
}
},
)
# loop through each line in response and if it contains a : then extract
# the left side as an attribute name and the right side as the value
#
# break as soon as a non-empty line is found that doesn't contain a :
return self._parse_character_sheet(response)
@set_processing
async def match_character_names(self, names:list[str]):
async def match_character_names(self, names: list[str]):
"""
Attempts to match character names.
"""
_, response = await Prompt.request(
"world_state.match-character-names",
self.client,
"analyze_long",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"names": names,
}
},
)
log.debug("match_character_names", names=names, response=response)
return response
@set_processing
async def update_reinforcements(self, force:bool=False):
async def update_reinforcements(self, force: bool = False):
"""
Queries due worldstate re-inforcements
"""
for reinforcement in self.scene.world_state.reinforce:
if reinforcement.due <= 0 or force:
await self.update_reinforcement(reinforcement.question, reinforcement.character)
await self.update_reinforcement(
reinforcement.question, reinforcement.character
)
else:
reinforcement.due -= 1
@set_processing
async def update_reinforcement(self, question:str, character:str=None, reset:bool=False):
async def update_reinforcement(
self, question: str, character: str = None, reset: bool = False
):
"""
Queries a single re-inforcement
"""
message = None
idx, reinforcement = await self.scene.world_state.find_reinforcement(question, character)
idx, reinforcement = await self.scene.world_state.find_reinforcement(
question, character
)
if not reinforcement:
return
source = f"{reinforcement.question}:{reinforcement.character if reinforcement.character else ''}"
if reset and reinforcement.insert == "sequential":
self.scene.pop_history(typ="reinforcement", source=source, all=True)
answer = await Prompt.request(
"world_state.update-reinforcements",
self.client,
"analyze_freeform",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"question": reinforcement.question,
"instructions": reinforcement.instructions or "",
"character": self.scene.get_character(reinforcement.character) if reinforcement.character else None,
"character": self.scene.get_character(reinforcement.character)
if reinforcement.character
else None,
"answer": (reinforcement.answer if not reset else None) or "",
"reinforcement": reinforcement,
}
},
)
reinforcement.answer = answer
reinforcement.due = reinforcement.interval
# remove any recent previous reinforcement message with same question
# to avoid overloading the near history with reinforcement messages
if not reset:
self.scene.pop_history(typ="reinforcement", source=source, max_iterations=10)
self.scene.pop_history(
typ="reinforcement", source=source, max_iterations=10
)
if reinforcement.insert == "sequential":
# insert the reinforcement message at the current position
message = ReinforcementMessage(message=answer, source=source)
log.debug("update_reinforcement", message=message, reset=reset)
self.scene.push_history(message)
# if reinforcement has a character name set, update the character detail
if reinforcement.character:
character = self.scene.get_character(reinforcement.character)
await character.set_detail(reinforcement.question, answer)
else:
# set world entry
await self.scene.world_state_manager.save_world_entry(
@@ -496,20 +563,19 @@ class WorldStateAgent(Agent):
reinforcement.as_context_line,
{},
)
self.scene.world_state.emit()
return message
return message
@set_processing
async def check_pin_conditions(
self,
):
"""
Checks if any context pin conditions
"""
pins_with_condition = {
entry_id: {
"condition": pin.condition,
@@ -518,41 +584,47 @@ class WorldStateAgent(Agent):
for entry_id, pin in self.scene.world_state.pins.items()
if pin.condition
}
if not pins_with_condition:
return
first_entry_id = list(pins_with_condition.keys())[0]
_, answers = await Prompt.request(
"world_state.check-pin-conditions",
self.client,
"analyze",
vars = {
vars={
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"previous_states": json.dumps(pins_with_condition,indent=2),
"coercion": {first_entry_id:{ "condition": "" }},
}
"previous_states": json.dumps(pins_with_condition, indent=2),
"coercion": {first_entry_id: {"condition": ""}},
},
)
world_state = self.scene.world_state
state_change = False
state_change = False
for entry_id, answer in answers.items():
if entry_id not in world_state.pins:
log.warning("check_pin_conditions", entry_id=entry_id, answer=answer, msg="entry_id not found in world_state.pins (LLM failed to produce a clean response)")
log.warning(
"check_pin_conditions",
entry_id=entry_id,
answer=answer,
msg="entry_id not found in world_state.pins (LLM failed to produce a clean response)",
)
continue
log.info("check_pin_conditions", entry_id=entry_id, answer=answer)
state = answer.get("state")
if state is True or (isinstance(state, str) and state.lower() in ["true", "yes", "y"]):
if state is True or (
isinstance(state, str) and state.lower() in ["true", "yes", "y"]
):
prev_state = world_state.pins[entry_id].condition_state
world_state.pins[entry_id].condition_state = True
world_state.pins[entry_id].active = True
if prev_state != world_state.pins[entry_id].condition_state:
state_change = True
else:
@@ -560,49 +632,50 @@ class WorldStateAgent(Agent):
world_state.pins[entry_id].condition_state = False
world_state.pins[entry_id].active = False
state_change = True
if state_change:
await self.scene.load_active_pins()
self.scene.emit_status()
@set_processing
async def summarize_and_pin(self, message_id:int, num_messages:int=3) -> str:
async def summarize_and_pin(self, message_id: int, num_messages: int = 3) -> str:
"""
Will take a message index and then walk back N messages
summarizing the scene and pinning it to the context.
"""
creator = get_agent("creator")
summarizer = get_agent("summarizer")
message_index = self.scene.message_index(message_id)
text = self.scene.snapshot(lines=num_messages, start=message_index)
extra_context = self.scene.snapshot(lines=50, start=message_index-num_messages)
extra_context = self.scene.snapshot(
lines=50, start=message_index - num_messages
)
summary = await summarizer.summarize(
text,
text,
extra_context=extra_context,
method="short",
extra_instructions="Pay particularly close attention to decisions, agreements or promises made.",
)
entry_id = util.clean_id(await creator.generate_title(summary))
ts = self.scene.ts
log.debug(
"summarize_and_pin",
message_id=message_id,
message_index=message_index,
num_messages=num_messages,
num_messages=num_messages,
summary=summary,
entry_id=entry_id,
ts=ts,
)
await self.scene.world_state_manager.save_world_entry(
entry_id,
summary,
@@ -610,49 +683,49 @@ class WorldStateAgent(Agent):
"ts": ts,
},
)
await self.scene.world_state_manager.set_pin(
entry_id,
active=True,
)
await self.scene.load_active_pins()
self.scene.emit_status()
@set_processing
async def is_character_present(self, character:str) -> bool:
async def is_character_present(self, character: str) -> bool:
"""
Check if a character is present in the scene
Arguments:
- `character`: The character to check.
"""
if len(self.scene.history) < 10:
text = self.scene.intro+"\n\n"+self.scene.snapshot(lines=50)
else:
text = self.scene.snapshot(lines=50)
is_present = await self.analyze_text_and_answer_question(
text=text,
query=f"Is {character} present AND active in the current scene? Answert with 'yes' or 'no'.",
)
return is_present.lower().startswith("y")
@set_processing
async def is_character_leaving(self, character:str) -> bool:
"""
Check if a character is leaving the scene
Arguments:
- `character`: The character to check.
"""
if len(self.scene.history) < 10:
text = self.scene.intro+"\n\n"+self.scene.snapshot(lines=50)
text = self.scene.intro + "\n\n" + self.scene.snapshot(lines=50)
else:
text = self.scene.snapshot(lines=50)
is_present = await self.analyze_text_and_answer_question(
text=text,
query=f"Is {character} present AND active in the current scene? Answert with 'yes' or 'no'.",
)
return is_present.lower().startswith("y")
@set_processing
async def is_character_leaving(self, character: str) -> bool:
"""
Check if a character is leaving the scene
Arguments:
- `character`: The character to check.
"""
if len(self.scene.history) < 10:
text = self.scene.intro + "\n\n" + self.scene.snapshot(lines=50)
else:
text = self.scene.snapshot(lines=50)
@@ -660,5 +733,5 @@ class WorldStateAgent(Agent):
text=text,
query=f"Is {character} leaving the current scene? Answert with 'yes' or 'no'.",
)
return is_leaving.lower().startswith("y")
return is_leaving.lower().startswith("y")

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import dataclasses
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from talemate import Scene
import structlog
__all__ = ["AutomatedAction", "register", "initialize_for_scene"]
@@ -13,50 +14,64 @@ log = structlog.get_logger("talemate.automated_action")
AUTOMATED_ACTIONS = {}
def initialize_for_scene(scene:Scene):
def initialize_for_scene(scene: Scene):
for uid, config in AUTOMATED_ACTIONS.items():
scene.automated_actions[uid] = config.cls(
scene,
uid=uid,
frequency=config.frequency,
call_initially=config.call_initially,
enabled=config.enabled
enabled=config.enabled,
)
@dataclasses.dataclass
class AutomatedActionConfig:
uid:str
cls:AutomatedAction
frequency:int=5
call_initially:bool=False
enabled:bool=True
uid: str
cls: AutomatedAction
frequency: int = 5
call_initially: bool = False
enabled: bool = True
class register:
def __init__(self, uid:str, frequency:int=5, call_initially:bool=False, enabled:bool=True):
def __init__(
self,
uid: str,
frequency: int = 5,
call_initially: bool = False,
enabled: bool = True,
):
self.uid = uid
self.frequency = frequency
self.call_initially = call_initially
self.enabled = enabled
def __call__(self, action:AutomatedAction):
def __call__(self, action: AutomatedAction):
AUTOMATED_ACTIONS[self.uid] = AutomatedActionConfig(
self.uid,
action,
frequency=self.frequency,
call_initially=self.call_initially,
enabled=self.enabled
self.uid,
action,
frequency=self.frequency,
call_initially=self.call_initially,
enabled=self.enabled,
)
return action
class AutomatedAction:
"""
An action that will be executed every n turns
"""
def __init__(self, scene:Scene, frequency:int=5, call_initially:bool=False, uid:str=None, enabled:bool=True):
def __init__(
self,
scene: Scene,
frequency: int = 5,
call_initially: bool = False,
uid: str = None,
enabled: bool = True,
):
self.scene = scene
self.enabled = enabled
self.frequency = frequency
@@ -64,14 +79,19 @@ class AutomatedAction:
self.uid = uid
if call_initially:
self.turns = frequency
async def __call__(self):
log.debug("automated_action", uid=self.uid, enabled=self.enabled, frequency=self.frequency, turns=self.turns)
log.debug(
"automated_action",
uid=self.uid,
enabled=self.enabled,
frequency=self.frequency,
turns=self.turns,
)
if not self.enabled:
return False
if self.turns % self.frequency == 0:
result = await self.action()
log.debug("automated_action", result=result)
@@ -79,10 +99,9 @@ class AutomatedAction:
# action could not be performed at this turn, we will try again next turn
return False
self.turns += 1
async def action(self) -> Any:
"""
Override this method to implement your action.
"""
raise NotImplementedError()
raise NotImplementedError()

View File

@@ -1,32 +1,34 @@
from typing import Union, TYPE_CHECKING
from typing import TYPE_CHECKING, Union
from talemate.instance import get_agent
if TYPE_CHECKING:
from talemate.tale_mate import Scene, Character, Actor
from talemate.tale_mate import Actor, Character, Scene
__all__ = [
"deactivate_character",
"activate_character",
]
async def deactivate_character(scene:"Scene", character:Union[str, "Character"]):
async def deactivate_character(scene: "Scene", character: Union[str, "Character"]):
"""
Deactivates a character
Arguments:
- `scene`: The scene to deactivate the character from
- `character`: The character to deactivate. Can be a string (the character's name) or a Character object
"""
if isinstance(character, str):
character = scene.get_character(character)
if character.is_player:
# can't deactivate the player
return False
if character.name in scene.inactive_characters:
# already deactivated
return False
@@ -34,24 +36,24 @@ async def deactivate_character(scene:"Scene", character:Union[str, "Character"])
await scene.remove_actor(character.actor)
scene.inactive_characters[character.name] = character
async def activate_character(scene:"Scene", character:Union[str, "Character"]):
async def activate_character(scene: "Scene", character: Union[str, "Character"]):
"""
Activates a character
Arguments:
- `scene`: The scene to activate the character in
- `character`: The character to activate. Can be a string (the character's name) or a Character object
"""
if isinstance(character, str):
character = scene.get_character(character)
if character.name not in scene.inactive_characters:
# already activated
return False
actor = scene.Actor(character, get_agent("conversation"))
await scene.add_actor(actor)
del scene.inactive_characters[character.name]

View File

@@ -2,15 +2,13 @@ import argparse
import asyncio
import glob
import os
import structlog
import structlog
from dotenv import load_dotenv
import talemate.instance as instance
from talemate import Actor, Character, Helper, Player, Scene
from talemate.agents import (
ConversationAgent,
)
from talemate.agents import ConversationAgent
from talemate.client import OpenAIClient, TextGeneratorWebuiClient
from talemate.emit.console import Console
from talemate.load import (
@@ -129,7 +127,6 @@ async def run_console_session(parser, args):
default_client = None
if "textgenwebui" in clients.values() or args.client == "textgenwebui":
# Init the TextGeneratorWebuiClient with ConversationAgent and create an actor
textgenwebui_api_url = args.textgenwebui_url
@@ -145,7 +142,6 @@ async def run_console_session(parser, args):
clients[client_name] = text_generator_webui_client
if "openai" in clients.values() or args.client == "openai":
openai_client = OpenAIClient()
for client_name, client_typ in clients.items():

View File

@@ -1,7 +1,8 @@
import os
import talemate.client.runpod
from talemate.client.lmstudio import LMStudioClient
from talemate.client.openai import OpenAIClient
from talemate.client.openai_compat import OpenAICompatibleClient
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
from talemate.client.textgenwebui import TextGeneratorWebuiClient
from talemate.client.lmstudio import LMStudioClient
from talemate.client.openai_compat import OpenAICompatibleClient
import talemate.client.runpod

View File

@@ -1,26 +1,26 @@
"""
A unified client base, based on the openai API
"""
import logging
import random
import time
import pydantic
from typing import Callable, Union
import pydantic
import structlog
import logging
from openai import AsyncOpenAI, PermissionDeniedError
from talemate.emit import emit
import talemate.instance as instance
import talemate.client.presets as presets
import talemate.client.system_prompts as system_prompts
import talemate.instance as instance
import talemate.util as util
from talemate.agents.context import active_agent
from talemate.client.context import client_context_attribute
from talemate.client.model_prompts import model_prompt
from talemate.agents.context import active_agent
from talemate.emit import emit
# Set up logging level for httpx to WARNING to suppress debug logs.
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
REMOTE_SERVICES = [
# TODO: runpod.py should add this to the list
@@ -29,22 +29,24 @@ REMOTE_SERVICES = [
STOPPING_STRINGS = ["<|im_end|>", "</s>"]
class ErrorAction(pydantic.BaseModel):
title:str
action_name:str
icon:str = "mdi-error"
arguments:list = []
title: str
action_name: str
icon: str = "mdi-error"
arguments: list = []
class Defaults(pydantic.BaseModel):
api_url:str = "http://localhost:5000"
max_token_length:int = 4096
api_url: str = "http://localhost:5000"
max_token_length: int = 4096
class ClientBase:
api_url: str
model_name: str
api_key: str = None
name:str = None
name: str = None
enabled: bool = True
current_status: str = None
max_token_length: int = 4096
@@ -54,19 +56,19 @@ class ClientBase:
auto_break_repetition_enabled: bool = True
client_type = "base"
class Meta(pydantic.BaseModel):
experimental:Union[None,str] = None
defaults:Defaults = Defaults()
title:str = "Client"
name_prefix:str = "Client"
experimental: Union[None, str] = None
defaults: Defaults = Defaults()
title: str = "Client"
name_prefix: str = "Client"
enable_api_auth: bool = False
requires_prompt_template: bool = True
def __init__(
self,
api_url: str = None,
name = None,
name=None,
**kwargs,
):
self.api_url = api_url
@@ -75,87 +77,82 @@ class ClientBase:
if "max_token_length" in kwargs:
self.max_token_length = kwargs["max_token_length"]
self.set_client(max_token_length=self.max_token_length)
def __str__(self):
return f"{self.client_type}Client[{self.api_url}][{self.model_name or ''}]"
@property
def experimental(self):
return False
def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")
def prompt_template(self, sys_msg, prompt):
"""
Applies the appropriate prompt template for the model.
"""
if not self.model_name:
self.log.warning("prompt template not applied", reason="no model loaded")
return f"{sys_msg}\n{prompt}"
return model_prompt(self.model_name, sys_msg, prompt)[0]
return model_prompt(self.model_name, sys_msg, prompt)[0]
def prompt_template_example(self):
if not getattr(self, "model_name", None):
return None, None
return model_prompt(self.model_name, "sysmsg", "prompt<|BOT|>{LLM coercion}")
def reconfigure(self, **kwargs):
"""
Reconfigures the client.
Keyword Arguments:
- api_url: the API URL to use
- max_token_length: the max token length to use
- enabled: whether the client is enabled
"""
if "api_url" in kwargs:
self.api_url = kwargs["api_url"]
if "max_token_length" in kwargs:
if kwargs.get("max_token_length"):
self.max_token_length = kwargs["max_token_length"]
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
def toggle_disabled_if_remote(self):
"""
If the client is targeting a remote recognized service, this
will disable the client.
"""
for service in REMOTE_SERVICES:
if service in self.api_url:
if self.enabled:
self.log.warn("remote service unreachable, disabling client", client=self.name)
self.log.warn(
"remote service unreachable, disabling client", client=self.name
)
self.enabled = False
return True
return False
def get_system_message(self, kind: str) -> str:
"""
Returns the appropriate system message for the given kind of generation
Arguments:
- kind: the kind of generation
"""
# TODO: make extensible
if "narrate" in kind:
return system_prompts.NARRATOR
if "story" in kind:
@@ -180,16 +177,14 @@ class ClientBase:
return system_prompts.ANALYST
if "summarize" in kind:
return system_prompts.SUMMARIZE
return system_prompts.BASIC
def emit_status(self, processing: bool = None):
"""
Sets and emits the client status.
"""
if processing is not None:
self.processing = processing
@@ -205,12 +200,12 @@ class ClientBase:
else:
model_name = "No model loaded"
status = "warning"
status_change = status != self.current_status
self.current_status = status
prompt_template_example, prompt_template_file = self.prompt_template_example()
emit(
"client_status",
message=self.client_type,
@@ -220,24 +215,25 @@ class ClientBase:
data={
"api_key": self.api_key,
"prompt_template_example": prompt_template_example,
"has_prompt_template": (prompt_template_file and prompt_template_file != "default.jinja2"),
"has_prompt_template": (
prompt_template_file and prompt_template_file != "default.jinja2"
),
"template_file": prompt_template_file,
"meta": self.Meta().model_dump(),
"error_action": None,
}
},
)
if status_change:
instance.emit_agent_status_by_client(self)
async def get_model_name(self):
models = await self.client.models.list()
try:
return models.data[0].id
except IndexError:
return None
async def status(self):
"""
Send a request to the API to retrieve the loaded AI model name.
@@ -246,12 +242,12 @@ class ClientBase:
"""
if self.processing:
return
if not self.enabled:
self.connected = False
self.emit_status()
return
try:
self.model_name = await self.get_model_name()
except Exception as e:
@@ -261,62 +257,66 @@ class ClientBase:
self.toggle_disabled_if_remote()
self.emit_status()
return
self.connected = True
if not self.model_name or self.model_name == "None":
self.log.warning("client model not loaded", client=self)
self.emit_status()
return
self.emit_status()
def generate_prompt_parameters(self, kind:str):
def generate_prompt_parameters(self, kind: str):
parameters = {}
self.tune_prompt_parameters(
presets.configure(parameters, kind, self.max_token_length),
kind
presets.configure(parameters, kind, self.max_token_length), kind
)
return parameters
def tune_prompt_parameters(self, parameters:dict, kind:str):
def tune_prompt_parameters(self, parameters: dict, kind: str):
parameters["stream"] = False
if client_context_attribute("nuke_repetition") > 0.0 and self.jiggle_enabled_for(kind):
self.jiggle_randomness(parameters, offset=client_context_attribute("nuke_repetition"))
if client_context_attribute(
"nuke_repetition"
) > 0.0 and self.jiggle_enabled_for(kind):
self.jiggle_randomness(
parameters, offset=client_context_attribute("nuke_repetition")
)
fn_tune_kind = getattr(self, f"tune_prompt_parameters_{kind}", None)
if fn_tune_kind:
fn_tune_kind(parameters)
agent_context = active_agent.get()
if agent_context.agent:
agent_context.agent.inject_prompt_paramters(parameters, kind, agent_context.action)
def tune_prompt_parameters_conversation(self, parameters:dict):
agent_context.agent.inject_prompt_paramters(
parameters, kind, agent_context.action
)
def tune_prompt_parameters_conversation(self, parameters: dict):
conversation_context = client_context_attribute("conversation")
parameters["max_tokens"] = conversation_context.get("length", 96)
dialog_stopping_strings = [
f"{character}:" for character in conversation_context["other_characters"]
]
if "extra_stopping_strings" in parameters:
parameters["extra_stopping_strings"] += dialog_stopping_strings
else:
parameters["extra_stopping_strings"] = dialog_stopping_strings
async def generate(self, prompt:str, parameters:dict, kind:str):
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
try:
response = await self.client.completions.create(prompt=prompt.strip(" "), **parameters)
response = await self.client.completions.create(
prompt=prompt.strip(" "), **parameters
)
return response.get("choices", [{}])[0].get("text", "")
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
@@ -324,85 +324,97 @@ class ClientBase:
return ""
except Exception as e:
self.log.error("generate error", e=e)
emit("status", message="Error during generation (check logs)", status="error")
emit(
"status", message="Error during generation (check logs)", status="error"
)
return ""
async def send_prompt(
self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x, retries:int=2
self,
prompt: str,
kind: str = "conversation",
finalize: Callable = lambda x: x,
retries: int = 2,
) -> str:
"""
Send a prompt to the AI and return its response.
:param prompt: The text prompt to send.
:return: The AI's response text.
"""
try:
self.emit_status(processing=True)
await self.status()
prompt_param = self.generate_prompt_parameters(kind)
finalized_prompt = self.prompt_template(self.get_system_message(kind), prompt).strip(" ")
finalized_prompt = self.prompt_template(
self.get_system_message(kind), prompt
).strip(" ")
prompt_param = finalize(prompt_param)
token_length = self.count_tokens(finalized_prompt)
time_start = time.time()
extra_stopping_strings = prompt_param.pop("extra_stopping_strings", [])
self.log.debug("send_prompt", token_length=token_length, max_token_length=self.max_token_length, parameters=prompt_param)
response = await self.generate(
self.repetition_adjustment(finalized_prompt),
prompt_param,
kind
self.log.debug(
"send_prompt",
token_length=token_length,
max_token_length=self.max_token_length,
parameters=prompt_param,
)
response, finalized_prompt = await self.auto_break_repetition(finalized_prompt, prompt_param, response, kind, retries)
response = await self.generate(
self.repetition_adjustment(finalized_prompt), prompt_param, kind
)
response, finalized_prompt = await self.auto_break_repetition(
finalized_prompt, prompt_param, response, kind, retries
)
time_end = time.time()
# stopping strings sometimes get appended to the end of the response anyways
# split the response by the first stopping string and take the first part
for stopping_string in STOPPING_STRINGS + extra_stopping_strings:
if stopping_string in response:
response = response.split(stopping_string)[0]
break
emit("prompt_sent", data={
"kind": kind,
"prompt": finalized_prompt,
"response": response,
"prompt_tokens": token_length,
"response_tokens": self.count_tokens(response),
"time": time_end - time_start,
})
emit(
"prompt_sent",
data={
"kind": kind,
"prompt": finalized_prompt,
"response": response,
"prompt_tokens": token_length,
"response_tokens": self.count_tokens(response),
"time": time_end - time_start,
},
)
return response
finally:
self.emit_status(processing=False)
async def auto_break_repetition(
self,
finalized_prompt:str,
prompt_param:dict,
response:str,
kind:str,
retries:int,
pad_max_tokens:int=32,
self,
finalized_prompt: str,
prompt_param: dict,
response: str,
kind: str,
retries: int,
pad_max_tokens: int = 32,
) -> str:
"""
If repetition breaking is enabled, this will retry the prompt if its
response is too similar to other messages in the prompt
This requires the agent to have the allow_repetition_break method
and the jiggle_enabled_for method and the client to have the
auto_break_repetition_enabled attribute set to True
Arguments:
- finalized_prompt: the prompt that was sent
@@ -411,47 +423,46 @@ class ClientBase:
- kind: the kind of generation
- retries: the number of retries left
- pad_max_tokens: increase response max_tokens by this amount per iteration
Returns:
- the response
"""
if not self.auto_break_repetition_enabled:
return response, finalized_prompt
agent_context = active_agent.get()
if self.jiggle_enabled_for(kind, auto=True):
# check if the response is a repetition
# using the default similarity threshold of 98, meaning it needs
# to be really similar to be considered a repetition
is_repetition, similarity_score, matched_line = util.similarity_score(
response,
finalized_prompt.split("\n"),
similarity_threshold=80
response, finalized_prompt.split("\n"), similarity_threshold=80
)
if not is_repetition:
# not a repetition, return the response
self.log.debug("send_prompt no similarity", similarity_score=similarity_score)
finalized_prompt = self.repetition_adjustment(finalized_prompt, is_repetitive=False)
return response, finalized_prompt
while is_repetition and retries > 0:
# it's a repetition, retry the prompt with adjusted parameters
self.log.warn(
"send_prompt similarity retry",
agent=agent_context.agent.agent_type,
similarity_score=similarity_score,
retries=retries
self.log.debug(
"send_prompt no similarity", similarity_score=similarity_score
)
finalized_prompt = self.repetition_adjustment(
finalized_prompt, is_repetitive=False
)
return response, finalized_prompt
while is_repetition and retries > 0:
# it's a repetition, retry the prompt with adjusted parameters
self.log.warn(
"send_prompt similarity retry",
agent=agent_context.agent.agent_type,
similarity_score=similarity_score,
retries=retries,
)
# first we apply the client's randomness jiggle which will adjust
# parameters like temperature and repetition_penalty, depending
# on the client
@@ -459,90 +470,94 @@ class ClientBase:
# this is a cumulative adjustment, so it will add to the previous
# iteration's adjustment, this also means retries should be kept low
# otherwise it will get out of hand and start generating nonsense
self.jiggle_randomness(prompt_param, offset=0.5)
# then we pad the max_tokens by the pad_max_tokens amount
prompt_param["max_tokens"] += pad_max_tokens
# send the prompt again
# we use the repetition_adjustment method to further encourage
# the AI to break the repetition on its own as well.
finalized_prompt = self.repetition_adjustment(finalized_prompt, is_repetitive=True)
response = retried_response = await self.generate(
finalized_prompt,
prompt_param,
kind
finalized_prompt = self.repetition_adjustment(
finalized_prompt, is_repetitive=True
)
self.log.debug("send_prompt dedupe sentences", response=response, matched_line=matched_line)
response = retried_response = await self.generate(
finalized_prompt, prompt_param, kind
)
self.log.debug(
"send_prompt dedupe sentences",
response=response,
matched_line=matched_line,
)
# a lot of the times the response will now contain the repetition + something new
# so we dedupe the response to remove the repetition on sentences level
response = util.dedupe_sentences(response, matched_line, similarity_threshold=85, debug=True)
self.log.debug("send_prompt dedupe sentences (after)", response=response)
response = util.dedupe_sentences(
response, matched_line, similarity_threshold=85, debug=True
)
self.log.debug(
"send_prompt dedupe sentences (after)", response=response
)
# deduping may have removed the entire response, so we check for that
if not util.strip_partial_sentences(response).strip():
# if the response is empty, we set the response to the original
# and try again next loop
response = retried_response
# check if the response is a repetition again
is_repetition, similarity_score, matched_line = util.similarity_score(
response,
finalized_prompt.split("\n"),
similarity_threshold=80
response, finalized_prompt.split("\n"), similarity_threshold=80
)
retries -= 1
return response, finalized_prompt
def count_tokens(self, content:str):
def count_tokens(self, content: str):
return util.count_tokens(content)
def jiggle_randomness(self, prompt_config:dict, offset:float=0.3) -> dict:
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
"""
adjusts temperature and repetition_penalty
by random values using the base value as a center
"""
temp = prompt_config["temperature"]
min_offset = offset * 0.3
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
def jiggle_enabled_for(self, kind:str, auto:bool=False) -> bool:
def jiggle_enabled_for(self, kind: str, auto: bool = False) -> bool:
agent_context = active_agent.get()
agent = agent_context.agent
if not agent:
return False
return agent.allow_repetition_break(kind, agent_context.action, auto=auto)
def repetition_adjustment(self, prompt:str, is_repetitive:bool=False):
def repetition_adjustment(self, prompt: str, is_repetitive: bool = False):
"""
Breaks the prompt into lines and checkse each line for a match with
[$REPETITION|{repetition_adjustment}].
On match and if is_repetitive is True, the line is removed from the prompt and
replaced with the repetition_adjustment.
On match and if is_repetitive is False, the line is removed from the prompt.
On match and if is_repetitive is False, the line is removed from the prompt.
"""
lines = prompt.split("\n")
new_lines = []
for line in lines:
if line.startswith("[$REPETITION|"):
if is_repetitive:
@@ -551,5 +566,5 @@ class ClientBase:
new_lines.append("")
else:
new_lines.append(line)
return "\n".join(new_lines)
return "\n".join(new_lines)

View File

@@ -1,6 +1,7 @@
import pydantic
from enum import Enum
import pydantic
__all__ = [
"ClientType",
"ClientBootstrap",
@@ -10,8 +11,10 @@ __all__ = [
LISTS = {}
class ClientType(str, Enum):
"""Client type enum."""
textgen = "textgenwebui"
automatic1111 = "automatic1111"
@@ -20,43 +23,42 @@ class ClientBootstrap(pydantic.BaseModel):
"""Client bootstrap model."""
# client type, currently supports "textgen" and "automatic1111"
client_type: ClientType
# unique client identifier
uid: str
# connection name
name: str
# connection information for the client
# REST api url
api_url: str
# service name (for example runpod)
service_name: str
class register_list:
def __init__(self, service_name:str):
def __init__(self, service_name: str):
self.service_name = service_name
def __call__(self, func):
LISTS[self.service_name] = func
return func
async def list_all(exclude_urls: list[str] = list()):
"""
Return a list of client bootstrap objects.
"""
for service_name, func in LISTS.items():
async for item in func():
if item.api_url not in exclude_urls:
yield item.dict()
yield item.dict()

View File

@@ -3,19 +3,20 @@ Context managers for various client-side operations.
"""
from contextvars import ContextVar
from pydantic import BaseModel, Field
from copy import deepcopy
import structlog
from pydantic import BaseModel, Field
__all__ = [
'context_data',
'client_context_attribute',
'ContextModel',
"context_data",
"client_context_attribute",
"ContextModel",
]
log = structlog.get_logger()
def model_to_dict_without_defaults(model_instance):
model_dict = model_instance.dict()
for field_name, field in model_instance.__class__.__fields__.items():
@@ -23,20 +24,25 @@ def model_to_dict_without_defaults(model_instance):
del model_dict[field_name]
return model_dict
class ConversationContext(BaseModel):
talking_character: str = None
other_characters: list[str] = Field(default_factory=list)
class ContextModel(BaseModel):
"""
Pydantic model for the context data.
"""
nuke_repetition: float = Field(0.0, ge=0.0, le=3.0)
conversation: ConversationContext = Field(default_factory=ConversationContext)
length: int = 96
# Define the context variable as an empty dictionary
context_data = ContextVar('context_data', default=ContextModel().model_dump())
context_data = ContextVar("context_data", default=ContextModel().model_dump())
def client_context_attribute(name, default=None):
"""
@@ -47,6 +53,7 @@ def client_context_attribute(name, default=None):
# Return the value of the key if it exists, otherwise return the default value
return data.get(name, default)
def set_client_context_attribute(name, value):
"""
Set the value of the context variable `context_data` for the given key.
@@ -55,7 +62,8 @@ def set_client_context_attribute(name, value):
data = context_data.get()
# Set the value of the key
data[name] = value
def set_conversation_context_attribute(name, value):
"""
Set the value of the context variable `context_data.conversation` for the given key.
@@ -65,6 +73,7 @@ def set_conversation_context_attribute(name, value):
# Set the value of the key
data["conversation"][name] = value
class ClientContext:
"""
A context manager to set values to the context variable `context_data`.
@@ -82,10 +91,10 @@ class ClientContext:
Set the key-value pairs to the context variable `context_data` when entering the context.
"""
# Get the current context data
data = deepcopy(context_data.get()) if context_data.get() else {}
data.update(self.values)
# Update the context data
self.token = context_data.set(data)
@@ -93,5 +102,5 @@ class ClientContext:
"""
Reset the context variable `context_data` to its previous values when exiting the context.
"""
context_data.reset(self.token)

View File

@@ -1,16 +1,16 @@
import asyncio
import random
import json
import logging
import random
from abc import ABC, abstractmethod
from typing import Callable, Union
import requests
import talemate.client.system_prompts as system_prompts
import talemate.util as util
from talemate.client.registry import register
import talemate.client.system_prompts as system_prompts
from talemate.client.textgenwebui import RESTTaleMateClient
from talemate.emit import Emission, emit
# NOT IMPLEMENTED AT THIS POINT
# NOT IMPLEMENTED AT THIS POINT

View File

@@ -1,65 +1,62 @@
import pydantic
from openai import AsyncOpenAI
from talemate.client.base import ClientBase
from talemate.client.registry import register
from openai import AsyncOpenAI
class Defaults(pydantic.BaseModel):
api_url:str = "http://localhost:1234"
api_url: str = "http://localhost:1234"
max_token_length: int = 4096
@register()
class LMStudioClient(ClientBase):
client_type = "lmstudio"
conversation_retries = 5
class Meta(ClientBase.Meta):
name_prefix:str = "LMStudio"
title:str = "LMStudio"
defaults:Defaults = Defaults()
name_prefix: str = "LMStudio"
title: str = "LMStudio"
defaults: Defaults = Defaults()
def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key="sk-1111")
def tune_prompt_parameters(self, parameters:dict, kind:str):
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def get_model_name(self):
model_name = await super().get_model_name()
# model name comes back as a file path, so we need to extract the model name
# the path could be windows or linux so it needs to handle both backslash and forward slash
if model_name:
model_name = model_name.replace("\\", "/").split("/")[-1]
return model_name
async def generate(self, prompt:str, parameters:dict, kind:str):
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
human_message = {'role': 'user', 'content': prompt.strip()}
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
human_message = {"role": "user", "content": prompt.strip()}
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
try:
response = await self.client.chat.completions.create(
model=self.model_name, messages=[human_message], **parameters
)
return response.choices[0].message.content
except Exception as e:
self.log.error("generate error", e=e)
return ""
return ""

View File

@@ -1,17 +1,23 @@
from jinja2 import Environment, FileSystemLoader
import os
import structlog
import shutil
import huggingface_hub
import tempfile
import huggingface_hub
import structlog
from jinja2 import Environment, FileSystemLoader
__all__ = ["model_prompt"]
BASE_TEMPLATE_PATH = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "templates", "llm-prompt"
os.path.dirname(os.path.abspath(__file__)),
"..",
"..",
"..",
"templates",
"llm-prompt",
)
# holds the default templates
# holds the default templates
STD_TEMPLATE_PATH = os.path.join(BASE_TEMPLATE_PATH, "std")
# llm prompt templates provided by talemate
@@ -22,62 +28,73 @@ USER_TEMPLATE_PATH = os.path.join(BASE_TEMPLATE_PATH, "user")
TEMPLATE_IDENTIFIERS = []
def register_template_identifier(cls):
TEMPLATE_IDENTIFIERS.append(cls)
return cls
log = structlog.get_logger("talemate.model_prompts")
class ModelPrompt:
"""
Will attempt to load an LLM prompt template based on the model name
If the model name is not found, it will default to the 'default' template
"""
template_map = {}
@property
def env(self):
if not hasattr(self, "_env"):
log.info("modal prompt", base_template_path=BASE_TEMPLATE_PATH)
self._env = Environment(loader=FileSystemLoader([
USER_TEMPLATE_PATH,
TALEMATE_TEMPLATE_PATH,
]))
self._env = Environment(
loader=FileSystemLoader(
[
USER_TEMPLATE_PATH,
TALEMATE_TEMPLATE_PATH,
]
)
)
return self._env
@property
def std_templates(self) -> list[str]:
env = Environment(loader=FileSystemLoader(STD_TEMPLATE_PATH))
return sorted(env.list_templates())
def __call__(self, model_name:str, system_message:str, prompt:str):
def __call__(self, model_name: str, system_message: str, prompt: str):
template, template_file = self.get_template(model_name)
if not template:
template_file = "default.jinja2"
template = self.env.get_template(template_file)
if "<|BOT|>" in prompt:
user_message, coercion_message = prompt.split("<|BOT|>", 1)
else:
user_message = prompt
coercion_message = ""
return template.render({
"system_message": system_message,
"prompt": prompt,
"user_message": user_message,
"coercion_message": coercion_message,
"set_response" : self.set_response
}), template_file
def set_response(self, prompt:str, response_str:str):
return (
template.render(
{
"system_message": system_message,
"prompt": prompt,
"user_message": user_message,
"coercion_message": coercion_message,
"set_response": self.set_response,
}
),
template_file,
)
def set_response(self, prompt: str, response_str: str):
prompt = prompt.strip("\n").strip()
if "<|BOT|>" in prompt:
if "\n<|BOT|>" in prompt:
prompt = prompt.replace("\n<|BOT|>", response_str)
@@ -85,17 +102,17 @@ class ModelPrompt:
prompt = prompt.replace("<|BOT|>", response_str)
else:
prompt = prompt.rstrip("\n") + response_str
return prompt
def get_template(self, model_name:str):
def get_template(self, model_name: str):
"""
Will attempt to load an LLM prompt template - this supports
partial filename matching on the template file name.
"""
matches = []
# Iterate over all templates in the loader's directory
for template_name in self.env.list_templates():
# strip extension
@@ -103,56 +120,58 @@ class ModelPrompt:
# Check if the model name is in the template filename
if template_name_match.lower() in model_name.lower():
matches.append(template_name)
# If there are no matches, return None
if not matches:
return None, None
# If there is only one match, return it
if len(matches) == 1:
return self.env.get_template(matches[0]), matches[0]
# If there are multiple matches, return the one with the longest name
sorted_matches = sorted(matches, key=lambda x: len(x), reverse=True)
return self.env.get_template(sorted_matches[0]), sorted_matches[0]
def create_user_override(self, template_name:str, model_name:str):
def create_user_override(self, template_name: str, model_name: str):
"""
Will copy STD_TEMPLATE_PATH/template_name to USER_TEMPLATE_PATH/model_name.jinja2
"""
template_name = template_name.split(".jinja2")[0]
shutil.copyfile(
os.path.join(STD_TEMPLATE_PATH, template_name + ".jinja2"),
os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2")
os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2"),
)
return os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2")
def query_hf_for_prompt_template_suggestion(self, model_name:str):
def query_hf_for_prompt_template_suggestion(self, model_name: str):
print("query_hf_for_prompt_template_suggestion", model_name)
api = huggingface_hub.HfApi()
try:
author, model_name = model_name.split("_", 1)
except ValueError:
return None
models = list(api.list_models(
filter=huggingface_hub.ModelFilter(model_name=model_name, author=author)
))
models = list(
api.list_models(
filter=huggingface_hub.ModelFilter(model_name=model_name, author=author)
)
)
if not models:
return None
model = models[0]
repo_id = f"{author}/{model_name}"
with tempfile.TemporaryDirectory() as tmpdir:
readme_path = huggingface_hub.hf_hub_download(repo_id=repo_id, filename="README.md", cache_dir=tmpdir)
readme_path = huggingface_hub.hf_hub_download(
repo_id=repo_id, filename="README.md", cache_dir=tmpdir
)
if not readme_path:
return None
with open(readme_path) as f:
@@ -163,24 +182,27 @@ class ModelPrompt:
return f"{identifier.template_str}.jinja2"
model_prompt = ModelPrompt()
class TemplateIdentifier:
def __call__(self, content:str):
def __call__(self, content: str):
return False
@register_template_identifier
class Llama2Identifier(TemplateIdentifier):
template_str = "Llama2"
def __call__(self, content:str):
def __call__(self, content: str):
return "[INST]" in content and "[/INST]" in content
@register_template_identifier
class ChatMLIdentifier(TemplateIdentifier):
template_str = "ChatML"
def __call__(self, content:str):
def __call__(self, content: str):
"""
<|im_start|>system
{{ system_message }}<|im_end|>
@@ -189,7 +211,7 @@ class ChatMLIdentifier(TemplateIdentifier):
<|im_start|>assistant
{{ coercion_message }}
"""
return (
"<|im_start|>system" in content
and "<|im_end|>" in content
@@ -197,20 +219,24 @@ class ChatMLIdentifier(TemplateIdentifier):
and "<|im_start|>assistant" in content
)
@register_template_identifier
class InstructionInputResponseIdentifier(TemplateIdentifier):
template_str = "InstructionInputResponse"
def __call__(self, content:str):
def __call__(self, content: str):
return (
"### Instruction:" in content
and "### Input:" in content
and "### Response:" in content
)
@register_template_identifier
class AlpacaIdentifier(TemplateIdentifier):
template_str = "Alpaca"
def __call__(self, content:str):
def __call__(self, content: str):
"""
{{ system_message }}
@@ -220,20 +246,19 @@ class AlpacaIdentifier(TemplateIdentifier):
### Response:
{{ coercion_message }}
"""
return (
"### Instruction:" in content
and "### Response:" in content
)
return "### Instruction:" in content and "### Response:" in content
@register_template_identifier
class OpenChatIdentifier(TemplateIdentifier):
template_str = "OpenChat"
def __call__(self, content:str):
def __call__(self, content: str):
"""
GPT4 Correct System: {{ system_message }}<|end_of_turn|>GPT4 Correct User: {{ user_message }}<|end_of_turn|>GPT4 Correct Assistant: {{ coercion_message }}
"""
return (
"<|end_of_turn|>" in content
and "GPT4 Correct System:" in content
@@ -241,54 +266,51 @@ class OpenChatIdentifier(TemplateIdentifier):
and "GPT4 Correct Assistant:" in content
)
@register_template_identifier
class VicunaIdentifier(TemplateIdentifier):
template_str = "Vicuna"
def __call__(self, content:str):
def __call__(self, content: str):
"""
SYSTEM: {{ system_message }}
USER: {{ user_message }}
ASSISTANT: {{ coercion_message }}
"""
return (
"SYSTEM:" in content
and "USER:" in content
and "ASSISTANT:" in content
)
return "SYSTEM:" in content and "USER:" in content and "ASSISTANT:" in content
@register_template_identifier
class USER_ASSISTANTIdentifier(TemplateIdentifier):
template_str = "USER_ASSISTANT"
def __call__(self, content:str):
def __call__(self, content: str):
"""
USER: {{ system_message }} {{ user_message }} ASSISTANT: {{ coercion_message }}
"""
return (
"USER:" in content
and "ASSISTANT:" in content
)
return "USER:" in content and "ASSISTANT:" in content
@register_template_identifier
class UserAssistantIdentifier(TemplateIdentifier):
template_str = "UserAssistant"
def __call__(self, content:str):
def __call__(self, content: str):
"""
User: {{ system_message }} {{ user_message }}
Assistant: {{ coercion_message }}
"""
return (
"User:" in content
and "Assistant:" in content
)
return "User:" in content and "Assistant:" in content
@register_template_identifier
class ZephyrIdentifier(TemplateIdentifier):
template_str = "Zephyr"
def __call__(self, content:str):
def __call__(self, content: str):
"""
<|system|>
{{ system_message }}</s>
@@ -297,9 +319,9 @@ class ZephyrIdentifier(TemplateIdentifier):
<|assistant|>
{{ coercion_message }}
"""
return (
"<|system|>" in content
and "<|user|>" in content
and "<|assistant|>" in content
)
)

View File

@@ -1,21 +1,23 @@
import json
import pydantic
import structlog
import tiktoken
from openai import AsyncOpenAI, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction
from talemate.client.registry import register
from talemate.config import load_config
from talemate.emit import emit
from talemate.emit.signals import handlers
from talemate.config import load_config
import structlog
import tiktoken
__all__ = [
"OpenAIClient",
]
log = structlog.get_logger("talemate")
def num_tokens_from_messages(messages:list[dict], model:str="gpt-3.5-turbo-0613"):
def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0613"):
"""Return the number of tokens used by a list of messages."""
try:
encoding = tiktoken.encoding_for_model(model)
@@ -66,9 +68,11 @@ def num_tokens_from_messages(messages:list[dict], model:str="gpt-3.5-turbo-0613"
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens
class Defaults(pydantic.BaseModel):
max_token_length:int = 16384
model:str = "gpt-4-turbo-preview"
max_token_length: int = 16384
model: str = "gpt-4-turbo-preview"
@register()
class OpenAIClient(ClientBase):
@@ -79,13 +83,13 @@ class OpenAIClient(ClientBase):
client_type = "openai"
conversation_retries = 0
auto_break_repetition_enabled = False
class Meta(ClientBase.Meta):
name_prefix:str = "OpenAI"
title:str = "OpenAI"
manual_model:bool = True
manual_model_choices:list[str] = [
"gpt-3.5-turbo",
name_prefix: str = "OpenAI"
title: str = "OpenAI"
manual_model: bool = True
manual_model_choices: list[str] = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-1106-preview",
@@ -93,21 +97,19 @@ class OpenAIClient(ClientBase):
"gpt-4-turbo-preview",
]
requires_prompt_template: bool = False
defaults:Defaults = Defaults()
defaults: Defaults = Defaults()
def __init__(self, model="gpt-4-turbo-preview", **kwargs):
self.model_name = model
self.api_key_status = None
self.config = load_config()
super().__init__(**kwargs)
handlers["config_saved"].connect(self.on_config_saved)
@property
def openai_api_key(self):
return self.config.get("openai",{}).get("api_key")
return self.config.get("openai", {}).get("api_key")
def emit_status(self, processing: bool = None):
error_action = None
@@ -127,13 +129,13 @@ class OpenAIClient(ClientBase):
arguments=[
"application",
"openai_api",
]
],
)
if not self.model_name:
status = "error"
model_name = "No model loaded"
self.current_status = status
emit(
@@ -145,25 +147,24 @@ class OpenAIClient(ClientBase):
data={
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
}
},
)
def set_client(self, max_token_length:int=None):
def set_client(self, max_token_length: int = None):
if not self.openai_api_key:
self.client = AsyncOpenAI(api_key="sk-1111")
log.error("No OpenAI API key set")
if self.api_key_status:
self.api_key_status = False
emit('request_client_status')
emit('request_agent_status')
emit("request_client_status")
emit("request_agent_status")
return
if not self.model_name:
self.model_name = "gpt-3.5-turbo-16k"
model = self.model_name
self.client = AsyncOpenAI(api_key=self.openai_api_key)
if model == "gpt-3.5-turbo":
self.max_token_length = min(max_token_length or 4096, 4096)
@@ -175,16 +176,20 @@ class OpenAIClient(ClientBase):
self.max_token_length = min(max_token_length or 128000, 128000)
else:
self.max_token_length = max_token_length or 2048
if not self.api_key_status:
if self.api_key_status is False:
emit('request_client_status')
emit('request_agent_status')
emit("request_client_status")
emit("request_agent_status")
self.api_key_status = True
log.info("openai set client", max_token_length=self.max_token_length, provided_max_token_length=max_token_length, model=model)
log.info(
"openai set client",
max_token_length=self.max_token_length,
provided_max_token_length=max_token_length,
model=model,
)
def reconfigure(self, **kwargs):
if kwargs.get("model"):
self.model_name = kwargs["model"]
@@ -203,39 +208,37 @@ class OpenAIClient(ClientBase):
async def status(self):
self.emit_status()
def prompt_template(self, system_message:str, prompt:str):
def prompt_template(self, system_message: str, prompt: str):
# only gpt-4-1106-preview supports json_object response coersion
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nContinue this response: ")
prompt = prompt.replace("<|BOT|>", "\nContinue this response: ")
else:
prompt = prompt.replace("<|BOT|>", "")
return prompt
def tune_prompt_parameters(self, parameters:dict, kind:str):
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def generate(self, prompt:str, parameters:dict, kind:str):
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
if not self.openai_api_key:
raise Exception("No OpenAI API key set")
# only gpt-4-* supports enforcing json object
supports_json_object = self.model_name.startswith("gpt-4-")
right = None
@@ -246,26 +249,28 @@ class OpenAIClient(ClientBase):
parameters["response_format"] = {"type": "json_object"}
except (IndexError, ValueError):
pass
human_message = {'role': 'user', 'content': prompt.strip()}
system_message = {'role': 'system', 'content': self.get_system_message(kind)}
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
human_message = {"role": "user", "content": prompt.strip()}
system_message = {"role": "system", "content": self.get_system_message(kind)}
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
try:
response = await self.client.chat.completions.create(
model=self.model_name, messages=[system_message, human_message], **parameters
model=self.model_name,
messages=[system_message, human_message],
**parameters,
)
response = response.choices[0].message.content
if right and response.startswith(right):
response = response[len(right):].strip()
response = response[len(right) :].strip()
return response
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
emit("status", message="OpenAI API: Permission Denied", status="error")
return ""
except Exception as e:
raise
raise

View File

@@ -1,32 +1,33 @@
import pydantic
from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError
from talemate.client.base import ClientBase
from talemate.client.registry import register
from openai import AsyncOpenAI, PermissionDeniedError, NotFoundError
from talemate.emit import emit
EXPERIMENTAL_DESCRIPTION = """Use this client if you want to connect to a service implementing an OpenAI-compatible API. Success is going to depend on the level of compatibility. Use the actual OpenAI client if you want to connect to OpenAI's API."""
class Defaults(pydantic.BaseModel):
api_url:str = "http://localhost:5000"
api_key:str = ""
max_token_length:int = 4096
model:str = ""
api_url: str = "http://localhost:5000"
api_key: str = ""
max_token_length: int = 4096
model: str = ""
@register()
class OpenAICompatibleClient(ClientBase):
client_type = "openai_compat"
conversation_retries = 5
class Meta(ClientBase.Meta):
title:str = "OpenAI Compatible API"
name_prefix:str = "OpenAI Compatible API"
experimental:str = EXPERIMENTAL_DESCRIPTION
enable_api_auth:bool = True
manual_model:bool = True
defaults:Defaults = Defaults()
title: str = "OpenAI Compatible API"
name_prefix: str = "OpenAI Compatible API"
experimental: str = EXPERIMENTAL_DESCRIPTION
enable_api_auth: bool = True
manual_model: bool = True
defaults: Defaults = Defaults()
def __init__(self, model=None, **kwargs):
self.model_name = model
super().__init__(**kwargs)
@@ -37,23 +38,23 @@ class OpenAICompatibleClient(ClientBase):
def set_client(self, **kwargs):
self.api_key = kwargs.get("api_key")
self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key=self.api_key)
self.model_name = kwargs.get("model") or kwargs.get("model_name") or self.model_name
def tune_prompt_parameters(self, parameters:dict, kind:str):
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key=self.api_key)
self.model_name = (
kwargs.get("model") or kwargs.get("model_name") or self.model_name
)
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def get_model_name(self):
try:
model_name = await super().get_model_name()
except NotFoundError as e:
@@ -65,29 +66,28 @@ class OpenAICompatibleClient(ClientBase):
# model name may be a file path, so we need to extract the model name
# the path could be windows or linux so it needs to handle both backslash and forward slash
is_filepath = "/" in model_name
is_filepath_windows = "\\" in model_name
if is_filepath or is_filepath_windows:
model_name = model_name.replace("\\", "/").split("/")[-1]
return model_name
async def generate(self, prompt:str, parameters:dict, kind:str):
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
human_message = {'role': 'user', 'content': prompt.strip()}
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
human_message = {"role": "user", "content": prompt.strip()}
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
try:
response = await self.client.chat.completions.create(
model=self.model_name, messages=[human_message], **parameters
)
return response.choices[0].message.content
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
@@ -95,7 +95,9 @@ class OpenAICompatibleClient(ClientBase):
return ""
except Exception as e:
self.log.error("generate error", e=e)
emit("status", message="Error during generation (check logs)", status="error")
emit(
"status", message="Error during generation (check logs)", status="error"
)
return ""
def reconfigure(self, **kwargs):
@@ -107,5 +109,5 @@ class OpenAICompatibleClient(ClientBase):
self.max_token_length = kwargs["max_token_length"]
if "api_key" in kwargs:
self.api_auth = kwargs["api_key"]
self.set_client(**kwargs)
self.set_client(**kwargs)

View File

@@ -28,18 +28,18 @@ PRESET_TALEMATE_CREATOR = {
}
PRESET_LLAMA_PRECISE = {
'temperature': 0.7,
'top_p': 0.1,
'top_k': 40,
'repetition_penalty': 1.18,
"temperature": 0.7,
"top_p": 0.1,
"top_k": 40,
"repetition_penalty": 1.18,
}
PRESET_DIVINE_INTELLECT = {
'temperature': 1.31,
'top_p': 0.14,
'top_k': 49,
"temperature": 1.31,
"top_p": 0.14,
"top_k": 49,
"repetition_penalty_range": 1024,
'repetition_penalty': 1.17,
"repetition_penalty": 1.17,
}
PRESET_SIMPLE_1 = {
@@ -49,7 +49,8 @@ PRESET_SIMPLE_1 = {
"repetition_penalty": 1.15,
}
def configure(config:dict, kind:str, total_budget:int):
def configure(config: dict, kind: str, total_budget: int):
"""
Sets the config based on the kind of text to generate.
"""
@@ -57,19 +58,22 @@ def configure(config:dict, kind:str, total_budget:int):
set_max_tokens(config, kind, total_budget)
return config
def set_max_tokens(config:dict, kind:str, total_budget:int):
def set_max_tokens(config: dict, kind: str, total_budget: int):
"""
Sets the max_tokens in the config based on the kind of text to generate.
"""
config["max_tokens"] = max_tokens_for_kind(kind, total_budget)
return config
def set_preset(config:dict, kind:str):
def set_preset(config: dict, kind: str):
"""
Sets the preset in the config based on the kind of text to generate.
"""
config.update(preset_for_kind(kind))
def preset_for_kind(kind: str):
if kind == "conversation":
return PRESET_TALEMATE_CONVERSATION
@@ -104,9 +108,13 @@ def preset_for_kind(kind: str):
elif kind == "director":
return PRESET_SIMPLE_1
elif kind == "director_short":
return PRESET_SIMPLE_1 # Assuming short direction uses the same preset as simple
return (
PRESET_SIMPLE_1 # Assuming short direction uses the same preset as simple
)
elif kind == "director_yesno":
return PRESET_SIMPLE_1 # Assuming yes/no direction uses the same preset as simple
return (
PRESET_SIMPLE_1 # Assuming yes/no direction uses the same preset as simple
)
elif kind == "edit_dialogue":
return PRESET_DIVINE_INTELLECT
elif kind == "edit_add_detail":
@@ -116,6 +124,7 @@ def preset_for_kind(kind: str):
else:
return PRESET_SIMPLE_1 # Default preset if none of the kinds match
def max_tokens_for_kind(kind: str, total_budget: int):
if kind == "conversation":
return 75 # Example value, adjust as needed
@@ -142,15 +151,23 @@ def max_tokens_for_kind(kind: str, total_budget: int):
elif kind == "story":
return 300 # Example value, adjust as needed
elif kind == "create":
return min(1024, int(total_budget * 0.35)) # Example calculation, adjust as needed
return min(
1024, int(total_budget * 0.35)
) # Example calculation, adjust as needed
elif kind == "create_concise":
return min(400, int(total_budget * 0.25)) # Example calculation, adjust as needed
return min(
400, int(total_budget * 0.25)
) # Example calculation, adjust as needed
elif kind == "create_precise":
return min(400, int(total_budget * 0.25)) # Example calculation, adjust as needed
return min(
400, int(total_budget * 0.25)
) # Example calculation, adjust as needed
elif kind == "create_short":
return 25
elif kind == "director":
return min(192, int(total_budget * 0.25)) # Example calculation, adjust as needed
return min(
192, int(total_budget * 0.25)
) # Example calculation, adjust as needed
elif kind == "director_short":
return 25 # Example value, adjust as needed
elif kind == "director_yesno":
@@ -162,4 +179,4 @@ def max_tokens_for_kind(kind: str, total_budget: int):
elif kind == "edit_fix_exposition":
return 1024 # Example value, adjust as needed
else:
return 150 # Default value if none of the kinds match
return 150 # Default value if none of the kinds match

View File

@@ -3,17 +3,17 @@ Retrieve pod information from the server which can then be used to bootstrap tal
connection for the pod. This is a simple wrapper around the runpod module.
"""
import asyncio
import json
import os
import dotenv
import runpod
import os
import json
import asyncio
from .bootstrap import ClientBootstrap, ClientType, register_list
import structlog
from talemate.config import load_config
import structlog
from .bootstrap import ClientBootstrap, ClientType, register_list
log = structlog.get_logger("talemate.client.runpod")
@@ -21,73 +21,75 @@ dotenv.load_dotenv()
runpod.api_key = load_config().get("runpod", {}).get("api_key", "")
def is_textgen_pod(pod):
name = pod["name"].lower()
if "textgen" in name or "thebloke llms" in name:
return True
return False
async def _async_get_pods():
"""
asyncio wrapper around get_pods.
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, runpod.get_pods)
async def get_textgen_pods():
"""
Return a list of text generation pods.
"""
if not runpod.api_key:
return
for pod in await _async_get_pods():
if not pod["desiredStatus"] == "RUNNING":
continue
if is_textgen_pod(pod):
yield pod
async def get_automatic1111_pods():
"""
Return a list of automatic1111 pods.
"""
if not runpod.api_key:
return
for pod in await _async_get_pods():
if not pod["desiredStatus"] == "RUNNING":
continue
if "automatic1111" in pod["name"].lower():
yield pod
def _client_bootstrap(client_type: ClientType, pod):
"""
Return a client bootstrap object for the given client type and pod.
"""
id = pod["id"]
if client_type == ClientType.textgen:
api_url = f"https://{id}-5000.proxy.runpod.net"
elif client_type == ClientType.automatic1111:
api_url = f"https://{id}-5000.proxy.runpod.net"
return ClientBootstrap(
client_type=client_type,
uid=pod["id"],
name=pod["name"],
api_url=api_url,
service_name="runpod"
service_name="runpod",
)
@register_list("runpod")
async def client_bootstrap_list():
@@ -97,13 +99,13 @@ async def client_bootstrap_list():
textgen_pods = []
async for pod in get_textgen_pods():
textgen_pods.append(pod)
automatic1111_pods = []
async for pod in get_automatic1111_pods():
automatic1111_pods.append(pod)
for pod in textgen_pods:
yield _client_bootstrap(ClientType.textgen, pod)
for pod in automatic1111_pods:
yield _client_bootstrap(ClientType.automatic1111, pod)
yield _client_bootstrap(ClientType.automatic1111, pod)

View File

@@ -18,4 +18,4 @@ EDITOR = str(Prompt.get("editor.system"))
WORLD_STATE = str(Prompt.get("world_state.system-analyst"))
SUMMARIZE = str(Prompt.get("summarizer.system"))
SUMMARIZE = str(Prompt.get("summarizer.system"))

View File

@@ -1,77 +1,94 @@
from talemate.client.base import ClientBase, STOPPING_STRINGS
from talemate.client.registry import register
from openai import AsyncOpenAI
import httpx
import random
import httpx
import structlog
from openai import AsyncOpenAI
from talemate.client.base import STOPPING_STRINGS, ClientBase
from talemate.client.registry import register
log = structlog.get_logger("talemate.client.textgenwebui")
@register()
class TextGeneratorWebuiClient(ClientBase):
client_type = "textgenwebui"
class Meta(ClientBase.Meta):
name_prefix:str = "TextGenWebUI"
title:str = "Text-Generation-WebUI (ooba)"
def tune_prompt_parameters(self, parameters:dict, kind:str):
name_prefix: str = "TextGenWebUI"
title: str = "Text-Generation-WebUI (ooba)"
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
parameters["stopping_strings"] = STOPPING_STRINGS + parameters.get("extra_stopping_strings", [])
parameters["stopping_strings"] = STOPPING_STRINGS + parameters.get(
"extra_stopping_strings", []
)
# is this needed?
parameters["max_new_tokens"] = parameters["max_tokens"]
parameters["stop"] = parameters["stopping_strings"]
# Half temperature on -Yi- models
if self.model_name and "-yi-" in self.model_name.lower() and parameters["temperature"] > 0.1:
if (
self.model_name
and "-yi-" in self.model_name.lower()
and parameters["temperature"] > 0.1
):
parameters["temperature"] = parameters["temperature"] / 2
log.debug("halfing temperature for -yi- model", temperature=parameters["temperature"])
log.debug(
"halfing temperature for -yi- model",
temperature=parameters["temperature"],
)
def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key="sk-1111")
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
async def get_model_name(self):
async with httpx.AsyncClient() as client:
response = await client.get(f"{self.api_url}/v1/internal/model/info", timeout=2)
response = await client.get(
f"{self.api_url}/v1/internal/model/info", timeout=2
)
if response.status_code == 404:
raise Exception("Could not find model info (wrong api version?)")
response_data = response.json()
model_name = response_data.get("model_name")
if model_name == "None":
model_name = None
return model_name
async def generate(self, prompt:str, parameters:dict, kind:str):
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
headers = {}
headers["Content-Type"] = "application/json"
parameters["prompt"] = prompt.strip(" ")
async with httpx.AsyncClient() as client:
response = await client.post(f"{self.api_url}/v1/completions", json=parameters, timeout=None, headers=headers)
response = await client.post(
f"{self.api_url}/v1/completions",
json=parameters,
timeout=None,
headers=headers,
)
response_data = response.json()
return response_data["choices"][0]["text"]
def jiggle_randomness(self, prompt_config:dict, offset:float=0.3) -> dict:
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
"""
adjusts temperature and repetition_penalty
by random values using the base value as a center
"""
temp = prompt_config["temperature"]
rep_pen = prompt_config["repetition_penalty"]
min_offset = offset * 0.3
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
prompt_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)
prompt_config["repetition_penalty"] = random.uniform(
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
)

View File

@@ -1,32 +1,33 @@
import copy
import random
def jiggle_randomness(prompt_config:dict, offset:float=0.3) -> dict:
def jiggle_randomness(prompt_config: dict, offset: float = 0.3) -> dict:
"""
adjusts temperature and repetition_penalty
by random values using the base value as a center
"""
temp = prompt_config["temperature"]
rep_pen = prompt_config["repetition_penalty"]
copied_config = copy.deepcopy(prompt_config)
min_offset = offset * 0.3
copied_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
copied_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)
copied_config["repetition_penalty"] = random.uniform(
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
)
return copied_config
def jiggle_enabled_for(kind:str):
def jiggle_enabled_for(kind: str):
if kind in ["conversation", "story"]:
return True
if kind.startswith("narrate"):
return True
return False
return False

View File

@@ -12,17 +12,17 @@ from .cmd_memget import CmdMemget
from .cmd_memset import CmdMemset
from .cmd_narrate import *
from .cmd_rebuild_archive import CmdRebuildArchive
from .cmd_remove_character import CmdRemoveCharacter
from .cmd_rename import CmdRename
from .cmd_rerun import *
from .cmd_reset import CmdReset
from .cmd_rm import CmdRm
from .cmd_remove_character import CmdRemoveCharacter
from .cmd_run_helios_test import CmdHeliosTest
from .cmd_save import CmdSave
from .cmd_save_as import CmdSaveAs
from .cmd_save_characters import CmdSaveCharacters
from .cmd_setenv import CmdSetEnvironmentToScene, CmdSetEnvironmentToCreative
from .cmd_setenv import CmdSetEnvironmentToCreative, CmdSetEnvironmentToScene
from .cmd_time_util import *
from .cmd_tts import *
from .cmd_world_state import *
from .cmd_run_helios_test import CmdHeliosTest
from .manager import Manager
from .manager import Manager

View File

@@ -41,7 +41,7 @@ class TalemateCommand(Emitter, ABC):
raise NotImplementedError(
"TalemateCommand.run() must be implemented by subclass"
)
@property
def verbose_name(self):
if self.label:
@@ -50,6 +50,6 @@ class TalemateCommand(Emitter, ABC):
def command_start(self):
emit("command_status", self.verbose_name, status="started")
def command_end(self):
emit("command_status", self.verbose_name, status="ended")
emit("command_status", self.verbose_name, status="ended")

View File

@@ -1,9 +1,9 @@
import structlog
from talemate.character import activate_character, deactivate_character
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import wait_for_input, emit
from talemate.character import deactivate_character, activate_character
from talemate.emit import emit, wait_for_input
from talemate.instance import get_agent
log = structlog.get_logger("talemate.cmd.characters")
@@ -13,6 +13,7 @@ __all__ = [
"CmdActivateCharacter",
]
@register
class CmdDeactivateCharacter(TalemateCommand):
"""
@@ -22,61 +23,77 @@ class CmdDeactivateCharacter(TalemateCommand):
name = "character_deactivate"
description = "Will deactivate a character"
aliases = ["char_d"]
label = "Character exit"
async def run(self):
narrator = get_agent("narrator")
world_state = get_agent("world_state")
characters = list([character.name for character in self.scene.get_npc_characters()])
characters = list(
[character.name for character in self.scene.get_npc_characters()]
)
if not characters:
emit("status", message="No characters found", status="error")
return True
if self.args:
character_name = self.args[0]
else:
character_name = await wait_for_input("Which character do you want to deactivate?", data={
"input_type": "select",
"choices": characters,
})
character_name = await wait_for_input(
"Which character do you want to deactivate?",
data={
"input_type": "select",
"choices": characters,
},
)
if not character_name:
emit("status", message="No character selected", status="error")
return True
never_narrate = len(self.args) > 1 and self.args[1] == "no"
if not never_narrate:
is_present = await world_state.is_character_present(character_name)
is_leaving = await world_state.is_character_leaving(character_name)
log.debug("deactivate_character", character_name=character_name, is_present=is_present, is_leaving=is_leaving, never_narrate=never_narrate)
log.debug(
"deactivate_character",
character_name=character_name,
is_present=is_present,
is_leaving=is_leaving,
never_narrate=never_narrate,
)
else:
is_present = False
is_leaving = True
log.debug("deactivate_character", character_name=character_name, never_narrate=never_narrate)
log.debug(
"deactivate_character",
character_name=character_name,
never_narrate=never_narrate,
)
if is_present and not is_leaving and not never_narrate:
direction = await wait_for_input(f"How does {character_name} exit the scene? (leave blank for AI to decide)")
direction = await wait_for_input(
f"How does {character_name} exit the scene? (leave blank for AI to decide)"
)
message = await narrator.action_to_narration(
"narrate_character_exit",
self.scene.get_character(character_name),
direction = direction,
direction=direction,
)
self.narrator_message(message)
await deactivate_character(self.scene, character_name)
await deactivate_character(self.scene, character_name)
emit("status", message=f"Deactivated {character_name}", status="success")
self.scene.emit_status()
self.scene.world_state.emit()
self.scene.world_state.emit()
return True
@register
class CmdActivateCharacter(TalemateCommand):
"""
@@ -86,57 +103,70 @@ class CmdActivateCharacter(TalemateCommand):
name = "character_activate"
description = "Will activate a character"
aliases = ["char_a"]
label = "Character enter"
async def run(self):
world_state = get_agent("world_state")
narrator = get_agent("narrator")
characters = list(self.scene.inactive_characters.keys())
if not characters:
emit("status", message="No characters found", status="error")
return True
if self.args:
character_name = self.args[0]
if character_name not in characters:
emit("status", message="Character not found", status="error")
return True
else:
character_name = await wait_for_input("Which character do you want to activate?", data={
"input_type": "select",
"choices": characters,
})
character_name = await wait_for_input(
"Which character do you want to activate?",
data={
"input_type": "select",
"choices": characters,
},
)
if not character_name:
emit("status", message="No character selected", status="error")
return True
never_narrate = len(self.args) > 1 and self.args[1] == "no"
if not never_narrate:
is_present = await world_state.is_character_present(character_name)
log.debug("activate_character", character_name=character_name, is_present=is_present, never_narrate=never_narrate)
log.debug(
"activate_character",
character_name=character_name,
is_present=is_present,
never_narrate=never_narrate,
)
else:
is_present = True
log.debug("activate_character", character_name=character_name, never_narrate=never_narrate)
log.debug(
"activate_character",
character_name=character_name,
never_narrate=never_narrate,
)
await activate_character(self.scene, character_name)
if not is_present and not never_narrate:
direction = await wait_for_input(f"How does {character_name} enter the scene? (leave blank for AI to decide)")
direction = await wait_for_input(
f"How does {character_name} enter the scene? (leave blank for AI to decide)"
)
message = await narrator.action_to_narration(
"narrate_character_entry",
self.scene.get_character(character_name),
direction = direction,
direction=direction,
)
self.narrator_message(message)
emit("status", message=f"Activated {character_name}", status="success")
self.scene.emit_status()
self.scene.world_state.emit()
return True
return True

View File

@@ -12,6 +12,7 @@ __all__ = [
"CmdRunAutomatic",
]
@register
class CmdDebugOn(TalemateCommand):
"""
@@ -26,6 +27,7 @@ class CmdDebugOn(TalemateCommand):
logging.getLogger().setLevel(logging.DEBUG)
await asyncio.sleep(0)
@register
class CmdDebugOff(TalemateCommand):
"""
@@ -46,66 +48,64 @@ class CmdPromptChangeSectioning(TalemateCommand):
"""
Command class for the '_prompt_change_sectioning' command
"""
name = "_prompt_change_sectioning"
description = "Change the sectioning handler for the prompt system"
aliases = []
async def run(self):
if not self.args:
self.emit("system", "You must specify a sectioning handler")
return
handler_name = self.args[0]
set_default_sectioning_handler(handler_name)
self.emit("system", f"Sectioning handler set to {handler_name}")
await asyncio.sleep(0)
@register
class CmdRunAutomatic(TalemateCommand):
"""
Command class for the 'run_automatic' command
"""
name = "run_automatic"
description = "Will make the player character AI controlled for n turns"
aliases = ["auto"]
async def run(self):
if self.args:
turns = int(self.args[0])
else:
turns = 10
self.emit("system", f"Making player character AI controlled for {turns} turns")
self.scene.get_player_character().actor.ai_controlled = turns
@register
class CmdLongTermMemoryStats(TalemateCommand):
"""
Command class for the 'long_term_memory_stats' command
"""
name = "long_term_memory_stats"
description = "Show stats for the long term memory"
aliases = ["ltm_stats"]
async def run(self):
memory = self.scene.get_helper("memory").agent
count = await memory.count()
db_name = memory.db_name
self.emit("system", f"Long term memory for {self.scene.name} has {count} entries in the {db_name} database")
self.emit(
"system",
f"Long term memory for {self.scene.name} has {count} entries in the {db_name} database",
)
@register
@@ -113,35 +113,34 @@ class CmdLongTermMemoryReset(TalemateCommand):
"""
Command class for the 'long_term_memory_reset' command
"""
name = "long_term_memory_reset"
description = "Reset the long term memory"
aliases = ["ltm_reset"]
async def run(self):
await self.scene.commit_to_memory()
self.emit("system", f"Long term memory for {self.scene.name} has been reset")
@register
class CmdSetContentContext(TalemateCommand):
"""
Command class for the 'set_content_context' command
"""
name = "set_content_context"
description = "Set the content context for the scene"
aliases = ["set_context"]
async def run(self):
if not self.args:
self.emit("system", "You must specify a context")
return
context = self.args[0]
self.scene.context = context
self.emit("system", f"Content context set to {context}")
self.emit("system", f"Content context set to {context}")

View File

@@ -1,9 +1,10 @@
import asyncio
import random
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.scene_message import DirectorMessage
from talemate.emit import wait_for_input
from talemate.scene_message import DirectorMessage
__all__ = [
"CmdAIDialogue",
@@ -11,6 +12,7 @@ __all__ = [
"CmdAIDialogueDirected",
]
@register
class CmdAIDialogue(TalemateCommand):
"""
@@ -23,15 +25,14 @@ class CmdAIDialogue(TalemateCommand):
async def run(self):
conversation_agent = self.scene.get_helper("conversation").agent
actor = None
# if there is only one npc in the scene, use that
if len(self.scene.npc_character_names) == 1:
actor = list(self.scene.get_npc_characters())[0].actor
else:
if conversation_agent.actions["natural_flow"].enabled:
await conversation_agent.apply_natural_flow(force=True, npcs_only=True)
character_name = self.scene.next_actor
@@ -41,83 +42,83 @@ class CmdAIDialogue(TalemateCommand):
else:
# randomly select an actor
actor = random.choice(list(self.scene.get_npc_characters())).actor
if not actor:
return
messages = await actor.talk()
self.scene.process_npc_dialogue(actor, messages)
@register
class CmdAIDialogueSelective(TalemateCommand):
"""
Command class for the 'ai_dialogue_selective' command
Will allow the player to select which npc dialogue will be generated
for
"""
name = "ai_dialogue_selective"
description = "Generate dialogue for an AI selected actor"
aliases = ["dlg_selective"]
async def run(self):
npc_name = self.args[0]
character = self.scene.get_character(npc_name)
if not character:
self.emit("system_message", message=f"Character not found: {npc_name}")
return
actor = character.actor
messages = await actor.talk()
self.scene.process_npc_dialogue(actor, messages)
@register
class CmdAIDialogueDirected(TalemateCommand):
"""
Command class for the 'ai_dialogue_directed' command
Will allow the player to select which npc dialogue will be generated
for
"""
name = "ai_dialogue_directed"
description = "Generate dialogue for an AI selected actor"
aliases = ["dlg_directed"]
async def run(self):
npc_name = self.args[0]
character = self.scene.get_character(npc_name)
if not character:
self.emit("system_message", message=f"Character not found: {npc_name}")
return
prefix = f"Director instructs {character.name}: \"To progress the scene, i want you to"
direction = await wait_for_input(prefix+"... (enter your instructions)")
direction = f"{prefix} {direction}\""
prefix = f'Director instructs {character.name}: "To progress the scene, i want you to'
direction = await wait_for_input(prefix + "... (enter your instructions)")
direction = f'{prefix} {direction}"'
director_message = DirectorMessage(direction, source=character.name)
self.emit("director", director_message, character=character)
self.scene.push_history(director_message)
actor = character.actor
messages = await actor.talk()
self.scene.process_npc_dialogue(actor, messages)
self.scene.process_npc_dialogue(actor, messages)

View File

@@ -1,8 +1,8 @@
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import wait_for_input, emit
from talemate.util import colored_text, wrap_text
from talemate.emit import emit, wait_for_input
from talemate.scene_message import DirectorMessage
from talemate.util import colored_text, wrap_text
@register
@@ -21,9 +21,9 @@ class CmdDirectorDirect(TalemateCommand):
if not director:
self.system_message("No director found")
return True
npc_count = self.scene.num_npc_characters()
if npc_count == 1:
character = list(self.scene.get_npc_characters())[0]
elif npc_count > 1:
@@ -36,17 +36,20 @@ class CmdDirectorDirect(TalemateCommand):
if not character:
self.system_message(f"Character not found: {name}")
return True
goal = await wait_for_input(f"Enter a new goal for the director to direct {character.name}")
goal = await wait_for_input(
f"Enter a new goal for the director to direct {character.name}"
)
if not goal.strip():
self.system_message("No goal specified")
return True
director.agent.actions["direct"].config["prompt"].value = goal
await director.agent.direct_character(character, goal)
@register
class CmdDirectorDirectWithOverride(CmdDirectorDirect):
"""
@@ -54,7 +57,9 @@ class CmdDirectorDirectWithOverride(CmdDirectorDirect):
"""
name = "director_with_goal"
description = "Calls a director to give directionts to a character (with goal specified)"
description = (
"Calls a director to give directionts to a character (with goal specified)"
)
aliases = ["direct_g"]
async def run(self):

View File

@@ -1,6 +1,7 @@
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
@register
class CmdMemget(TalemateCommand):
"""
@@ -16,4 +17,4 @@ class CmdMemget(TalemateCommand):
memories = self.scene.get_helper("memory").agent.get(query)
for memory in memories:
self.emit("narrator", memory["text"])
self.emit("narrator", memory["text"])

View File

@@ -2,9 +2,9 @@ import asyncio
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.util import colored_text, wrap_text
from talemate.scene_message import NarratorMessage
from talemate.emit import wait_for_input
from talemate.scene_message import NarratorMessage
from talemate.util import colored_text, wrap_text
__all__ = [
"CmdNarrate",
@@ -14,6 +14,7 @@ __all__ = [
"CmdNarrateC",
]
@register
class CmdNarrate(TalemateCommand):
"""
@@ -33,7 +34,7 @@ class CmdNarrate(TalemateCommand):
narration = await narrator.agent.narrate_scene()
message = NarratorMessage(narration, source="narrate_scene")
self.narrator_message(message)
self.scene.push_history(message)
@@ -58,17 +59,22 @@ class CmdNarrateQ(TalemateCommand):
if self.args:
query = self.args[0]
at_the_end = (self.args[1].lower() == "true") if len(self.args) > 1 else False
at_the_end = (
(self.args[1].lower() == "true") if len(self.args) > 1 else False
)
else:
query = await wait_for_input("Enter query: ")
at_the_end = False
narration = await narrator.agent.narrate_query(query, at_the_end=at_the_end)
message = NarratorMessage(narration, source=f"narrate_query:{query.replace(':', '-')}")
message = NarratorMessage(
narration, source=f"narrate_query:{query.replace(':', '-')}"
)
self.narrator_message(message)
self.scene.push_history(message)
@register
class CmdNarrateProgress(TalemateCommand):
"""
@@ -89,10 +95,11 @@ class CmdNarrateProgress(TalemateCommand):
narration = await narrator.agent.progress_story()
message = NarratorMessage(narration, source="progress_story")
self.narrator_message(message)
self.scene.push_history(message)
@register
class CmdNarrateProgressDirected(TalemateCommand):
"""
@@ -105,16 +112,17 @@ class CmdNarrateProgressDirected(TalemateCommand):
async def run(self):
narrator = self.scene.get_helper("narrator")
direction = await wait_for_input("Enter direction for the narrator: ")
narration = await narrator.agent.progress_story(narrative_direction=direction)
message = NarratorMessage(narration, source=f"progress_story:{direction}")
self.narrator_message(message)
self.scene.push_history(message)
@register
class CmdNarrateC(TalemateCommand):
"""
@@ -149,7 +157,8 @@ class CmdNarrateC(TalemateCommand):
self.narrator_message(message)
self.scene.push_history(message)
@register
class CmdNarrateDialogue(TalemateCommand):
"""
@@ -165,23 +174,25 @@ class CmdNarrateDialogue(TalemateCommand):
narrator = self.scene.get_helper("narrator")
character_messages = self.scene.collect_messages("character", max_iterations=5)
if not character_messages:
self.system_message("No recent dialogue message found")
return True
character_message = character_messages[0]
character_name = character_message.character_name
character = self.scene.get_character(character_name)
if not character:
self.system_message(f"Character not found: {character_name}")
return True
narration = await narrator.agent.narrate_after_dialogue(character)
message = NarratorMessage(narration, source=f"narrate_dialogue:{character.name}")
message = NarratorMessage(
narration, source=f"narrate_dialogue:{character.name}"
)
self.narrator_message(message)
self.scene.push_history(message)
self.scene.push_history(message)

View File

@@ -20,7 +20,7 @@ class CmdRebuildArchive(TalemateCommand):
if not summarizer:
self.system_message("No summarizer found")
return True
# clear out archived history, but keep pre-established history
self.scene.archived_history = [
ah for ah in self.scene.archived_history if ah.get("end") is None

View File

@@ -14,38 +14,37 @@ class CmdRemoveCharacter(TalemateCommand):
aliases = ["rmc"]
async def run(self):
characters = list([character.name for character in self.scene.get_characters()])
if not characters:
self.system_message("No characters found")
return True
if self.args:
character_name = self.args[0]
else:
character_name = await wait_for_input("Which character do you want to remove?", data={
"input_type": "select",
"choices": characters,
})
character_name = await wait_for_input(
"Which character do you want to remove?",
data={
"input_type": "select",
"choices": characters,
},
)
if not character_name:
self.system_message("No character selected")
return True
character = self.scene.get_character(character_name)
if not character:
self.system_message(f"Character {character_name} not found")
return True
await self.scene.remove_actor(character.actor)
self.system_message(f"Removed {character.name} from scene")
self.scene.emit_status()
return True

View File

@@ -2,7 +2,6 @@ import asyncio
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import wait_for_input
@@ -23,20 +22,23 @@ class CmdRename(TalemateCommand):
character_name = self.args[0]
else:
character_names = self.scene.character_names
character_name = await wait_for_input("Which character do you want to rename?", data={
"input_type": "select",
"choices": character_names,
})
character_name = await wait_for_input(
"Which character do you want to rename?",
data={
"input_type": "select",
"choices": character_names,
},
)
character = self.scene.get_character(character_name)
if not character:
self.system_message(f"Character {character_name} not found")
return True
name = await wait_for_input("Enter new name: ")
character.rename(name)
await asyncio.sleep(0)
return True

View File

@@ -1,8 +1,7 @@
from talemate.client.context import ClientContext
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.client.context import ClientContext
from talemate.context import RerunContext
from talemate.emit import wait_for_input
__all__ = [
@@ -10,6 +9,7 @@ __all__ = [
"CmdRerunWithDirection",
]
@register
class CmdRerun(TalemateCommand):
"""
@@ -24,8 +24,8 @@ class CmdRerun(TalemateCommand):
nuke_repetition = self.args[0] if self.args else 0.0
with ClientContext(nuke_repetition=nuke_repetition):
await self.scene.rerun()
@register
class CmdRerunWithDirection(TalemateCommand):
"""
@@ -35,25 +35,25 @@ class CmdRerunWithDirection(TalemateCommand):
name = "rerun_directed"
description = "Rerun the scene with a direction"
aliases = ["rrd"]
label = "Directed Rerun"
async def run(self):
nuke_repetition = self.args[0] if self.args else 0.0
method = self.args[1] if len(self.args) > 1 else "replace"
if method not in ["replace", "edit"]:
raise ValueError(f"Unknown method: {method}. Valid methods are 'replace' and 'edit'.")
raise ValueError(
f"Unknown method: {method}. Valid methods are 'replace' and 'edit'."
)
if method == "replace":
hint = ""
else:
hint = " (subtle change to previous generation)"
direction = await wait_for_input(f"Instructions for regeneration{hint}: ")
with RerunContext(self.scene, direction=direction, method=method):
with ClientContext(direction=direction, nuke_repetition=nuke_repetition):
await self.scene.rerun()
await self.scene.rerun()

View File

@@ -1,7 +1,6 @@
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import wait_for_input, wait_for_input_yesno, emit
from talemate.emit import emit, wait_for_input, wait_for_input_yesno
from talemate.exceptions import ResetScene
@@ -16,13 +15,12 @@ class CmdReset(TalemateCommand):
aliases = [""]
async def run(self):
reset = await wait_for_input_yesno("Reset the scene?")
if reset.lower() not in ["yes", "y"]:
self.system_message("Reset cancelled")
return True
self.scene.reset()
raise ResetScene()

View File

@@ -1,7 +1,6 @@
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import wait_for_input, wait_for_input_yesno, emit
from talemate.emit import emit, wait_for_input, wait_for_input_yesno
from talemate.exceptions import ResetScene
@@ -14,26 +13,25 @@ class CmdHeliosTest(TalemateCommand):
name = "helios_test"
description = "Runs the helios test"
aliases = [""]
analyst_script = [
"Good morning helios, how are you today? Are you ready to run some tests?",
]
async def run(self):
if self.scene.name != "Helios Test Arena":
emit("system", "You are not in the Helios Test Arena")
self.scene.reset()
self.scene
player = self.scene.get_player_character()
player.actor.muted = 10
analyst = self.scene.get_character("The analyst")
actor = analyst.actor
actor.script = self.analyst_script
raise ResetScene()

View File

@@ -2,9 +2,8 @@ import asyncio
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.exceptions import RestartSceneLoop
from talemate.emit import emit
from talemate.exceptions import RestartSceneLoop
@register
@@ -19,21 +18,20 @@ class CmdSetEnvironmentToScene(TalemateCommand):
async def run(self):
await asyncio.sleep(0)
player_character = self.scene.get_player_character()
if not player_character:
self.system_message("No player character found")
return True
self.scene.set_environment("scene")
emit("status", message="Switched to gameplay", status="info")
raise RestartSceneLoop()
@register
class CmdSetEnvironmentToCreative(TalemateCommand):
"""
@@ -45,8 +43,7 @@ class CmdSetEnvironmentToCreative(TalemateCommand):
aliases = [""]
async def run(self):
await asyncio.sleep(0)
self.scene.set_environment("creative")
raise RestartSceneLoop()

View File

@@ -5,16 +5,18 @@ Commands to manage scene timescale
import asyncio
import logging
import isodate
import talemate.instance as instance
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import wait_for_input
import talemate.instance as instance
import isodate
__all__ = [
"CmdAdvanceTime",
]
@register
class CmdAdvanceTime(TalemateCommand):
"""
@@ -29,20 +31,23 @@ class CmdAdvanceTime(TalemateCommand):
if not self.args:
self.emit("system", "You must specify an amount of time to advance")
return
narrator = instance.get_agent("narrator")
narration_prompt = None
# if narrator has narrate_time_passage action enabled ask the user
# for a prompt to guide the narration
if narrator.actions["narrate_time_passage"].enabled and narrator.actions["narrate_time_passage"].config["ask_for_prompt"].value:
narration_prompt = await wait_for_input("Enter a prompt to guide the time passage narration (or leave blank): ")
if (
narrator.actions["narrate_time_passage"].enabled
and narrator.actions["narrate_time_passage"].config["ask_for_prompt"].value
):
narration_prompt = await wait_for_input(
"Enter a prompt to guide the time passage narration (or leave blank): "
)
if not narration_prompt.strip():
narration_prompt = None
world_state = instance.get_agent("world_state")
await world_state.advance_time(self.args[0], narration_prompt)
await world_state.advance_time(self.args[0], narration_prompt)

View File

@@ -3,13 +3,14 @@ import logging
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.prompts.base import set_default_sectioning_handler
from talemate.instance import get_agent
from talemate.prompts.base import set_default_sectioning_handler
__all__ = [
"CmdTestTTS",
]
@register
class CmdTestTTS(TalemateCommand):
"""
@@ -22,12 +23,10 @@ class CmdTestTTS(TalemateCommand):
async def run(self):
tts_agent = get_agent("tts")
try:
last_message = str(self.scene.history[-1])
except IndexError:
last_message = "Welcome to talemate!"
await tts_agent.generate(last_message)

View File

@@ -1,13 +1,14 @@
import random
import structlog
from talemate.commands.base import TalemateCommand
from talemate.scene_message import NarratorMessage
from talemate.commands.manager import register
from talemate.emit import wait_for_input, emit
from talemate.instance import get_agent
import talemate.instance as instance
from talemate.status import set_loading, LoadingStatus
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import emit, wait_for_input
from talemate.instance import get_agent
from talemate.scene_message import NarratorMessage
from talemate.status import LoadingStatus, set_loading
log = structlog.get_logger("talemate.cmd.world_state")
@@ -22,6 +23,7 @@ __all__ = [
"CmdSummarizeAndPin",
]
@register
class CmdWorldState(TalemateCommand):
"""
@@ -33,282 +35,328 @@ class CmdWorldState(TalemateCommand):
aliases = ["ws"]
async def run(self):
inline = self.args[0] == "inline" if self.args else False
reset = self.args[0] == "reset" if self.args else False
if inline:
await self.scene.world_state.request_update_inline()
return True
if reset:
self.scene.world_state.reset()
await self.scene.world_state.request_update()
@register
class CmdPersistCharacter(TalemateCommand):
"""
Will attempt to create an actual character from a currently non
tracked character in the scene, by name.
Once persisted this character can then participate in the scene.
"""
name = "persist_character"
description = "Persist a character by name"
aliases = ["pc"]
@set_loading("Generating character...", set_busy=False)
async def run(self):
from talemate.tale_mate import Character, Actor
from talemate.tale_mate import Actor, Character
scene = self.scene
world_state = instance.get_agent("world_state")
creator = instance.get_agent("creator")
narrator = instance.get_agent("narrator")
loading_status = LoadingStatus(3)
if not len(self.args):
characters = await world_state.identify_characters()
available_names = [character["name"] for character in characters.get("characters") if not scene.get_character(character["name"])]
available_names = [
character["name"]
for character in characters.get("characters")
if not scene.get_character(character["name"])
]
if not len(available_names):
raise ValueError("No characters available to persist.")
name = await wait_for_input("Which character would you like to persist?", data={
"input_type": "select",
"choices": available_names,
"multi_select": False,
})
name = await wait_for_input(
"Which character would you like to persist?",
data={
"input_type": "select",
"choices": available_names,
"multi_select": False,
},
)
else:
name = self.args[0]
extra_instructions = None
if name == "prompt":
name = await wait_for_input("What is the name of the character?")
description = await wait_for_input(f"Brief description for {name} (or leave blank):")
description = await wait_for_input(
f"Brief description for {name} (or leave blank):"
)
if description.strip():
extra_instructions = f"Name: {name}\nBrief Description: {description}"
never_narrate = len(self.args) > 1 and self.args[1] == "no"
if not never_narrate:
is_present = await world_state.is_character_present(name)
log.debug("persist_character", name=name, is_present=is_present, never_narrate=never_narrate)
log.debug(
"persist_character",
name=name,
is_present=is_present,
never_narrate=never_narrate,
)
else:
is_present = False
log.debug("persist_character", name=name, never_narrate=never_narrate)
character = Character(name=name)
character.color = random.choice(['#F08080', '#FFD700', '#90EE90', '#ADD8E6', '#DDA0DD', '#FFB6C1', '#FAFAD2', '#D3D3D3', '#B0E0E6', '#FFDEAD'])
character.color = random.choice(
[
"#F08080",
"#FFD700",
"#90EE90",
"#ADD8E6",
"#DDA0DD",
"#FFB6C1",
"#FAFAD2",
"#D3D3D3",
"#B0E0E6",
"#FFDEAD",
]
)
loading_status("Generating character attributes...")
attributes = await world_state.extract_character_sheet(name=name, text=extra_instructions)
attributes = await world_state.extract_character_sheet(
name=name, text=extra_instructions
)
scene.log.debug("persist_character", attributes=attributes)
character.base_attributes = attributes
loading_status("Generating character description...")
description = await creator.determine_character_description(character)
character.description = description
scene.log.debug("persist_character", description=description)
actor = Actor(character=character, agent=instance.get_agent("conversation"))
await scene.add_actor(actor)
emit("status", message=f"Added character {name} to the scene.", status="success")
emit(
"status", message=f"Added character {name} to the scene.", status="success"
)
# write narrative for the character entering the scene
if not is_present and not never_narrate:
loading_status("Narrating character entrance...")
entry_narration = await narrator.narrate_character_entry(character, direction=extra_instructions)
message = NarratorMessage(entry_narration, source=f"narrate_character_entry:{character.name}")
entry_narration = await narrator.narrate_character_entry(
character, direction=extra_instructions
)
message = NarratorMessage(
entry_narration, source=f"narrate_character_entry:{character.name}"
)
self.narrator_message(message)
self.scene.push_history(message)
scene.emit_status()
scene.world_state.emit()
@register
class CmdAddReinforcement(TalemateCommand):
"""
Will attempt to create an actual character from a currently non
tracked character in the scene, by name.
Once persisted this character can then participate in the scene.
"""
name = "add_reinforcement"
description = "Add a reinforcement to the world state"
aliases = ["ws_ar"]
async def run(self):
scene = self.scene
world_state = scene.world_state
if not len(self.args):
question = await wait_for_input("Ask reinforcement question")
else:
question = self.args[0]
await world_state.add_reinforcement(question)
@register
class CmdRemoveReinforcement(TalemateCommand):
"""
Will attempt to create an actual character from a currently non
tracked character in the scene, by name.
Once persisted this character can then participate in the scene.
"""
name = "remove_reinforcement"
description = "Remove a reinforcement from the world state"
aliases = ["ws_rr"]
async def run(self):
scene = self.scene
world_state = scene.world_state
if not len(self.args):
question = await wait_for_input("Ask reinforcement question")
else:
question = self.args[0]
idx, reinforcement = await world_state.find_reinforcement(question)
if idx is None:
raise ValueError(f"Reinforcement {question} not found.")
await world_state.remove_reinforcement(idx)
@register
class CmdUpdateReinforcements(TalemateCommand):
"""
Will attempt to create an actual character from a currently non
tracked character in the scene, by name.
Once persisted this character can then participate in the scene.
"""
name = "update_reinforcements"
description = "Update the reinforcements in the world state"
aliases = ["ws_ur"]
async def run(self):
scene = self.scene
world_state = get_agent("world_state")
await world_state.update_reinforcements(force=True)
@register
class CmdCheckPinConditions(TalemateCommand):
"""
Will attempt to create an actual character from a currently non
tracked character in the scene, by name.
Once persisted this character can then participate in the scene.
"""
name = "check_pin_conditions"
description = "Check the pin conditions in the world state"
aliases = ["ws_cpc"]
async def run(self):
world_state = get_agent("world_state")
await world_state.check_pin_conditions()
@register
class CmdApplyWorldStateTemplate(TalemateCommand):
"""
Will apply a world state template setting up
automatic state tracking.
"""
name = "apply_world_state_template"
description = "Apply a world state template, creating an auto state reinforcement."
aliases = ["ws_awst"]
label = "Add state"
async def run(self):
scene = self.scene
if not len(self.args):
raise ValueError("No template name provided.")
template_name = self.args[0]
template_type = self.args[1] if len(self.args) > 1 else None
character_name = self.args[2] if len(self.args) > 2 else None
templates = await self.scene.world_state_manager.get_templates()
try:
template = getattr(templates,template_type)[template_name]
template = getattr(templates, template_type)[template_name]
except KeyError:
raise ValueError(f"Template {template_name} not found.")
reinforcement = await scene.world_state_manager.apply_template_state_reinforcement(
template, character_name=character_name, run_immediately=True
reinforcement = (
await scene.world_state_manager.apply_template_state_reinforcement(
template, character_name=character_name, run_immediately=True
)
)
response_data = {
"template_name": template_name,
"template_type": template_type,
"reinforcement": reinforcement.model_dump() if reinforcement else None,
"character_name": character_name,
}
if reinforcement is None:
emit("status", message="State already tracked.", status="info", data=response_data)
emit(
"status",
message="State already tracked.",
status="info",
data=response_data,
)
else:
emit("status", message="Auto state added.", status="success", data=response_data)
emit(
"status",
message="Auto state added.",
status="success",
data=response_data,
)
@register
class CmdSummarizeAndPin(TalemateCommand):
"""
Will take a message index and then walk back N messages
summarizing the scene and pinning it to the context.
"""
name = "summarize_and_pin"
label = "Summarize and pin"
description = "Summarize a snapshot of the scene and pin it to the world state"
aliases = ["ws_sap"]
async def run(self):
scene = self.scene
world_state = get_agent("world_state")
if not self.scene.history:
raise ValueError("No history to summarize.")
message_id = int(self.args[0]) if len(self.args) else scene.history[-1].id
num_messages = int(self.args[1]) if len(self.args) > 1 else 5
await world_state.summarize_and_pin(message_id, num_messages=num_messages)
await world_state.summarize_and_pin(message_id, num_messages=num_messages)

View File

@@ -1,8 +1,10 @@
from talemate.emit import Emitter, AbortCommand
import structlog
from talemate.emit import AbortCommand, Emitter
log = structlog.get_logger("talemate.commands.manager")
class Manager(Emitter):
"""
TaleMateCommand class to handle user command
@@ -38,7 +40,7 @@ class Manager(Emitter):
cmd_args = ""
if not self.is_command(cmd):
return False
if ":" in cmd:
# split command name and args which are separated by a colon
cmd_name, cmd_args = cmd[1:].split(":", 1)
@@ -46,7 +48,7 @@ class Manager(Emitter):
else:
cmd_name = cmd[1:]
cmd_args = []
for command_cls in self.command_classes:
if command_cls.is_command(cmd_name):
command = command_cls(self, *cmd_args)

View File

@@ -1,11 +1,11 @@
import yaml
import datetime
import os
from typing import TYPE_CHECKING, ClassVar, Dict, Optional, Union
import pydantic
import structlog
import os
import datetime
import yaml
from pydantic import BaseModel, Field
from typing import Optional, Dict, Union, ClassVar, TYPE_CHECKING
from talemate.emit import emit
from talemate.scene_assets import Asset
@@ -15,40 +15,44 @@ if TYPE_CHECKING:
log = structlog.get_logger("talemate.config")
class Client(BaseModel):
type: str
name: str
model: Union[str,None] = None
api_url: Union[str,None] = None
api_key: Union[str,None] = None
max_token_length: Union[int,None] = None
model: Union[str, None] = None
api_url: Union[str, None] = None
api_key: Union[str, None] = None
max_token_length: int = 4096
class Config:
extra = "ignore"
class AgentActionConfig(BaseModel):
value: Union[int, float, str, bool, None] = None
class AgentAction(BaseModel):
enabled: bool = True
config: Union[dict[str, AgentActionConfig], None] = None
class Agent(BaseModel):
name: Union[str,None] = None
client: Union[str,None] = None
name: Union[str, None] = None
client: Union[str, None] = None
actions: Union[dict[str, AgentAction], None] = None
enabled: bool = True
class Config:
extra = "ignore"
# change serialization so actions and enabled are only
# serialized if they are not None
def model_dump(self, **kwargs):
return super().model_dump(exclude_none=True)
class GamePlayerCharacter(BaseModel):
name: str = ""
color: str = "#3362bb"
@@ -58,10 +62,12 @@ class GamePlayerCharacter(BaseModel):
class Config:
extra = "ignore"
class General(BaseModel):
auto_save: bool = True
auto_progress: bool = True
class StateReinforcementTemplate(BaseModel):
name: str
query: str
@@ -72,52 +78,68 @@ class StateReinforcementTemplate(BaseModel):
interval: int = 10
auto_create: bool = False
favorite: bool = False
type:ClassVar = "state_reinforcement"
type: ClassVar = "state_reinforcement"
class WorldStateTemplates(BaseModel):
state_reinforcement: dict[str, StateReinforcementTemplate] = pydantic.Field(default_factory=dict)
state_reinforcement: dict[str, StateReinforcementTemplate] = pydantic.Field(
default_factory=dict
)
class WorldState(BaseModel):
templates: WorldStateTemplates = WorldStateTemplates()
templates: WorldStateTemplates = WorldStateTemplates()
class Game(BaseModel):
default_player_character: GamePlayerCharacter = GamePlayerCharacter()
general: General = General()
world_state: WorldState = WorldState()
class Config:
extra = "ignore"
class CreatorConfig(BaseModel):
content_context: list[str] = ["a fun and engaging slice of life story aimed at an adult audience."]
content_context: list[str] = [
"a fun and engaging slice of life story aimed at an adult audience."
]
class OpenAIConfig(BaseModel):
api_key: Union[str,None]=None
api_key: Union[str, None] = None
class RunPodConfig(BaseModel):
api_key: Union[str,None]=None
api_key: Union[str, None] = None
class ElevenLabsConfig(BaseModel):
api_key: Union[str,None]=None
api_key: Union[str, None] = None
model: str = "eleven_turbo_v2"
class CoquiConfig(BaseModel):
api_key: Union[str,None]=None
api_key: Union[str, None] = None
class TTSVoiceSamples(BaseModel):
label:str
value:str
label: str
value: str
class TTSConfig(BaseModel):
device:str = "cuda"
model:str = "tts_models/multilingual/multi-dataset/xtts_v2"
device: str = "cuda"
model: str = "tts_models/multilingual/multi-dataset/xtts_v2"
voices: list[TTSVoiceSamples] = pydantic.Field(default_factory=list)
class ChromaDB(BaseModel):
instructor_device: str="cpu"
instructor_model: str="default"
embeddings: str="default"
instructor_device: str = "cpu"
instructor_model: str = "default"
embeddings: str = "default"
class RecentScene(BaseModel):
name: str
@@ -125,43 +147,48 @@ class RecentScene(BaseModel):
filename: str
date: str
cover_image: Union[Asset, None] = None
class RecentScenes(BaseModel):
scenes: list[RecentScene] = pydantic.Field(default_factory=list)
max_entries: int = 10
def push(self, scene:"Scene"):
def push(self, scene: "Scene"):
"""
adds a scene to the recent scenes list
"""
# if scene has not been saved, don't add it
if not scene.full_path:
return
now = datetime.datetime.now()
# remove any existing entries for this scene
self.scenes = [s for s in self.scenes if s.path != scene.full_path]
# add the new entry
self.scenes.insert(0,
self.scenes.insert(
0,
RecentScene(
name=scene.name,
path=scene.full_path,
name=scene.name,
path=scene.full_path,
filename=scene.filename,
date=now.isoformat(),
cover_image=scene.assets.assets[scene.assets.cover_image] if scene.assets.cover_image else None
))
date=now.isoformat(),
cover_image=scene.assets.assets[scene.assets.cover_image]
if scene.assets.cover_image
else None,
),
)
# trim the list to max_entries
self.scenes = self.scenes[:self.max_entries]
self.scenes = self.scenes[: self.max_entries]
def clean(self):
"""
removes any entries that no longer exist
"""
self.scenes = [s for s in self.scenes if os.path.exists(s.path)]
@@ -170,46 +197,50 @@ class Config(BaseModel):
game: Game
agents: Dict[str, Agent] = {}
creator: CreatorConfig = CreatorConfig()
openai: OpenAIConfig = OpenAIConfig()
runpod: RunPodConfig = RunPodConfig()
chromadb: ChromaDB = ChromaDB()
elevenlabs: ElevenLabsConfig = ElevenLabsConfig()
coqui: CoquiConfig = CoquiConfig()
tts: TTSConfig = TTSConfig()
recent_scenes: RecentScenes = RecentScenes()
class Config:
extra = "ignore"
def save(self, file_path: str = "./config.yaml"):
save_config(self, file_path)
class SceneConfig(BaseModel):
automated_actions: dict[str, bool]
class SceneAssetUpload(BaseModel):
scene_cover_image:bool
character_cover_image:str = None
content:str = None
def load_config(file_path: str = "./config.yaml", as_model:bool=False) -> Union[dict, Config]:
class SceneAssetUpload(BaseModel):
scene_cover_image: bool
character_cover_image: str = None
content: str = None
def load_config(
file_path: str = "./config.yaml", as_model: bool = False
) -> Union[dict, Config]:
"""
Load the config file from the given path.
Should cache the config and only reload if the file modification time
has changed since the last load
"""
with open(file_path, "r") as file:
config_data = yaml.safe_load(file)
@@ -225,13 +256,14 @@ def load_config(file_path: str = "./config.yaml", as_model:bool=False) -> Union[
return config.model_dump()
def save_config(config, file_path: str = "./config.yaml"):
"""
Save the config file to the given path.
"""
log.debug("Saving config", file_path=file_path)
# If config is a Config instance, convert it to a dictionary
if isinstance(config, Config):
config = config.model_dump(exclude_none=True)
@@ -245,5 +277,5 @@ def save_config(config, file_path: str = "./config.yaml"):
with open(file_path, "w") as file:
yaml.dump(config, file)
emit("config_saved", data=config)
emit("config_saved", data=config)

View File

@@ -1,4 +1,5 @@
from contextvars import ContextVar
import structlog
__all__ = [
@@ -13,29 +14,34 @@ log = structlog.get_logger(__name__)
scene_is_loading = ContextVar("scene_is_loading", default=None)
rerun_context = ContextVar("rerun_context", default=None)
class SceneIsLoading:
def __init__(self, scene):
self.scene = scene
def __enter__(self):
self.token = scene_is_loading.set(self.scene)
def __exit__(self, *args):
scene_is_loading.reset(self.token)
class RerunContext:
def __init__(self, scene, direction=None, method="replace", message:str = None):
def __init__(self, scene, direction=None, method="replace", message: str = None):
self.scene = scene
self.direction = direction
self.method = method
self.message = message
log.debug("RerunContext", scene=scene, direction=direction, method=method, message=message)
log.debug(
"RerunContext",
scene=scene,
direction=direction,
method=method,
message=message,
)
def __enter__(self):
self.token = rerun_context.set(self)
def __exit__(self, *args):
rerun_context.reset(self.token)
rerun_context.reset(self.token)

View File

@@ -4,9 +4,10 @@ __all__ = [
"ArchiveEntry",
]
@dataclass
class ArchiveEntry:
text: str
start: int = None
end: int = None
ts: str = None
ts: str = None

View File

@@ -1,57 +1,56 @@
handlers = {
}
handlers = {}
class AsyncSignal:
def __init__(self, name):
self.receivers = []
self.name = name
def connect(self, handler):
if handler in self.receivers:
return
self.receivers.append(handler)
def disconnect(self, handler):
self.receivers.remove(handler)
async def send(self, emission):
for receiver in self.receivers:
await receiver(emission)
def _register(name:str):
def _register(name: str):
"""
Registers a signal handler
Arguments:
name (str): The name of the signal
handler (signal): The signal handler
"""
if name in handlers:
raise ValueError(f"Signal {name} already registered")
handlers[name] = AsyncSignal(name)
return handlers[name]
def register(*names):
"""
Registers many signal handlers
Arguments:
*names (str): The names of the signals
"""
for name in names:
_register(name)
def get(name:str):
def get(name: str):
"""
Gets a signal handler
Arguments:
name (str): The name of the signal handler
"""
return handlers.get(name)
return handlers.get(name)

View File

@@ -2,13 +2,14 @@ from __future__ import annotations
import asyncio
import dataclasses
import structlog
from typing import TYPE_CHECKING, Any
from .signals import handlers
import structlog
from talemate.scene_message import SceneMessage
from .signals import handlers
if TYPE_CHECKING:
from talemate.tale_mate import Character, Scene
@@ -21,6 +22,7 @@ __all__ = [
log = structlog.get_logger("talemate.emit.base")
class AbortCommand(IOError):
pass
@@ -39,12 +41,15 @@ class Emission:
def emit(
typ: str, message: str = None, character: Character = None, scene: Scene = None, **kwargs
typ: str,
message: str = None,
character: Character = None,
scene: Scene = None,
**kwargs,
):
if typ not in handlers:
raise ValueError(f"Unknown message type: {typ}")
if isinstance(message, SceneMessage):
kwargs["id"] = message.id
message_object = message
@@ -53,7 +58,14 @@ def emit(
message_object = None
handlers[typ].send(
Emission(typ=typ, message=message, character=character, scene=scene, message_object=message_object, **kwargs)
Emission(
typ=typ,
message=message,
character=character,
scene=scene,
message_object=message_object,
**kwargs,
)
)
@@ -80,7 +92,6 @@ async def wait_for_input(
def input_receiver(emission: Emission):
input_received["message"] = emission.message
handlers["receive_input"].connect(input_receiver)
handlers["request_input"].send(
@@ -97,7 +108,7 @@ async def wait_for_input(
await asyncio.sleep(0.1)
handlers["receive_input"].disconnect(input_receiver)
if input_received["message"] == "!abort":
raise AbortCommand()
@@ -145,4 +156,4 @@ class Emitter:
self.emit("character", message, character=character)
def player_message(self, message: str, character: Character):
self.emit("player", message, character=character)
self.emit("player", message, character=character)

View File

@@ -18,7 +18,7 @@ ClientStatus = signal("client_status")
RequestClientStatus = signal("request_client_status")
AgentStatus = signal("agent_status")
RequestAgentStatus = signal("request_agent_status")
ClientBootstraps = signal("client_bootstraps")
ClientBootstraps = signal("client_bootstraps")
PromptSent = signal("prompt_sent")
RemoveMessage = signal("remove_message")

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from talemate.tale_mate import Scene, Actor, SceneMessage
from talemate.tale_mate import Actor, Scene, SceneMessage
__all__ = [
"Event",
@@ -35,27 +35,33 @@ class CharacterStateEvent(Event):
state: str
character_name: str
@dataclass
class SceneStateEvent(Event):
pass
@dataclass
class GameLoopBase(Event):
pass
@dataclass
class GameLoopEvent(GameLoopBase):
had_passive_narration: bool = False
@dataclass
class GameLoopStartEvent(GameLoopBase):
pass
@dataclass
class GameLoopActorIterEvent(GameLoopBase):
actor: Actor
game_loop: GameLoopEvent
@dataclass
class GameLoopNewMessageEvent(GameLoopBase):
message: SceneMessage
message: SceneMessage

View File

@@ -1,6 +1,7 @@
class TalemateError(Exception):
pass
class TalemateInterrupt(Exception):
"""
Exception to interrupt the game loop
@@ -8,6 +9,7 @@ class TalemateInterrupt(Exception):
pass
class ExitScene(TalemateInterrupt):
"""
Exception to exit the scene
@@ -15,18 +17,20 @@ class ExitScene(TalemateInterrupt):
pass
class RestartSceneLoop(TalemateInterrupt):
"""
Exception to switch the scene loop
"""
pass
class ResetScene(TalemateInterrupt):
"""
Exception to reset the scene
"""
pass
@@ -34,7 +38,7 @@ class RenderPromptError(TalemateError):
"""
Exception to raise when there is an error rendering a prompt
"""
pass
@@ -42,11 +46,10 @@ class LLMAccuracyError(TalemateError):
"""
Exception to raise when the LLM response is not processable
"""
def __init__(self, message:str, model_name:str=None):
def __init__(self, message: str, model_name: str = None):
if model_name:
message = f"{model_name} - {message}"
super().__init__(message)
self.model_name = model_name
self.model_name = model_name

View File

@@ -1,5 +1,5 @@
import os
import fnmatch
import os
from talemate.config import load_config
@@ -27,7 +27,7 @@ def _list_files_and_directories(root: str, path: str) -> list:
:return: List of files and directories in the given root directory.
"""
# Define the file patterns to match
patterns = ['characters/*.png', 'characters/*.webp', '*/*.json']
patterns = ["characters/*.png", "characters/*.webp", "*/*.json"]
items = []
@@ -42,4 +42,4 @@ def _list_files_and_directories(root: str, path: str) -> list:
items.append(os.path.join(dirpath, filename))
break
return items
return items

View File

@@ -1,56 +1,62 @@
import asyncio
import os
from typing import TYPE_CHECKING, Any
import nest_asyncio
import pydantic
import structlog
import asyncio
import nest_asyncio
from talemate.prompts.base import Prompt, PrependTemplateDirectories
from talemate.instance import get_agent
from talemate.agents.director import DirectorAgent
from talemate.agents.memory import MemoryAgent
from talemate.instance import get_agent
from talemate.prompts.base import PrependTemplateDirectories, Prompt
if TYPE_CHECKING:
from talemate.tale_mate import Scene
log = structlog.get_logger("game_state")
class Goal(pydantic.BaseModel):
description: str
id: int
status: bool = False
class Instructions(pydantic.BaseModel):
character: dict[str, str] = pydantic.Field(default_factory=dict)
class Ops(pydantic.BaseModel):
run_on_start: bool = False
class GameState(pydantic.BaseModel):
ops: Ops = Ops()
variables: dict[str,Any] = pydantic.Field(default_factory=dict)
variables: dict[str, Any] = pydantic.Field(default_factory=dict)
goals: list[Goal] = pydantic.Field(default_factory=list)
instructions: Instructions = pydantic.Field(default_factory=Instructions)
@property
def director(self) -> DirectorAgent:
return get_agent('director')
return get_agent("director")
@property
def memory(self) -> MemoryAgent:
return get_agent('memory')
return get_agent("memory")
@property
def scene(self) -> 'Scene':
def scene(self) -> "Scene":
return self.director.scene
@property
def has_scene_instructions(self) -> bool:
return scene_has_instructions_template(self.scene)
@property
def game_won(self) -> bool:
return self.variables.get("__game_won__") == True
@property
def scene_instructions(self) -> str:
scene = self.scene
@@ -59,43 +65,52 @@ class GameState(pydantic.BaseModel):
game_state = self
if scene_has_instructions_template(self.scene):
with PrependTemplateDirectories([scene.template_dir]):
prompt = Prompt.get('instructions', {
'scene': scene,
'max_tokens': client.max_token_length,
'game_state': game_state
})
prompt = Prompt.get(
"instructions",
{
"scene": scene,
"max_tokens": client.max_token_length,
"game_state": game_state,
},
)
prompt.client = client
instructions = prompt.render().strip()
log.info("Initialized game state instructions", scene=scene, instructions=instructions)
log.info(
"Initialized game state instructions",
scene=scene,
instructions=instructions,
)
return instructions
def init(self, scene: 'Scene') -> 'GameState':
def init(self, scene: "Scene") -> "GameState":
return self
def set_var(self, key: str, value: Any, commit: bool = False):
self.variables[key] = value
if commit:
loop = asyncio.get_event_loop()
loop.run_until_complete(self.memory.add(value, uid=f"game_state.{key}"))
def has_var(self, key: str) -> bool:
return key in self.variables
def get_var(self, key: str) -> Any:
return self.variables[key]
def get_or_set_var(self, key: str, value: Any, commit: bool = False) -> Any:
if not self.has_var(key):
self.set_var(key, value, commit=commit)
return self.get_var(key)
def scene_has_game_template(scene: 'Scene') -> bool:
def scene_has_game_template(scene: "Scene") -> bool:
"""Returns True if the scene has a game template."""
game_template_path = os.path.join(scene.template_dir, 'game.jinja2')
game_template_path = os.path.join(scene.template_dir, "game.jinja2")
return os.path.exists(game_template_path)
def scene_has_instructions_template(scene: 'Scene') -> bool:
def scene_has_instructions_template(scene: "Scene") -> bool:
"""Returns True if the scene has an instructions template."""
instructions_template_path = os.path.join(scene.template_dir, 'instructions.jinja2')
instructions_template_path = os.path.join(scene.template_dir, "instructions.jinja2")
return os.path.exists(instructions_template_path)

View File

@@ -2,21 +2,21 @@
Keep track of clients and agents
"""
import asyncio
import talemate.agents as agents
import talemate.client as clients
from talemate.emit import emit
from talemate.emit.signals import handlers
import talemate.client.bootstrap as bootstrap
import structlog
import talemate.agents as agents
import talemate.client as clients
import talemate.client.bootstrap as bootstrap
from talemate.emit import emit
from talemate.emit.signals import handlers
log = structlog.get_logger("talemate")
AGENTS = {}
CLIENTS = {}
def get_agent(typ: str, *create_args, **create_kwargs):
agent = AGENTS.get(typ)
@@ -75,48 +75,51 @@ def client_instances():
def agent_instances():
return AGENTS.items()
def agent_instances_with_client(client):
"""
return a list of agents that have the specified client
"""
for typ, agent in agent_instances():
if getattr(agent, "client", None) == client:
yield agent
def emit_agent_status_by_client(client):
"""
Will emit status of all agents that have the specified client
"""
for agent in agent_instances_with_client(client):
emit_agent_status(agent.__class__, agent)
async def emit_clients_status():
"""
Will emit status of all clients
"""
#log.debug("emit", type="client status")
# log.debug("emit", type="client status")
for client in CLIENTS.values():
if client:
await client.status()
def _sync_emit_clients_status(*args, **kwargs):
"""
Will emit status of all clients
in synchronous mode
"""
loop = asyncio.get_event_loop()
loop.run_until_complete(emit_clients_status())
loop.run_until_complete(emit_clients_status())
handlers["request_client_status"].connect(_sync_emit_clients_status)
async def emit_client_bootstraps():
emit(
"client_bootstraps",
data=list(await bootstrap.list_all())
)
emit("client_bootstraps", data=list(await bootstrap.list_all()))
def sync_emit_clients_status():
"""
@@ -126,15 +129,20 @@ def sync_emit_clients_status():
loop = asyncio.get_event_loop()
loop.run_until_complete(emit_clients_status())
async def sync_client_bootstraps():
"""
Will loop through all registered client bootstrap lists and spawn / update
Will loop through all registered client bootstrap lists and spawn / update
client instances from them.
"""
for service_name, func in bootstrap.LISTS.items():
async for client_bootstrap in func():
log.debug("sync client bootstrap", service_name=service_name, client_bootstrap=client_bootstrap.dict())
log.debug(
"sync client bootstrap",
service_name=service_name,
client_bootstrap=client_bootstrap.dict(),
)
client = get_client(
client_bootstrap.name,
type=client_bootstrap.client_type.value,
@@ -143,6 +151,7 @@ async def sync_client_bootstraps():
)
await client.status()
def emit_agent_status(cls, agent=None):
if not agent:
emit(
@@ -167,9 +176,10 @@ def emit_agents_status(*args, **kwargs):
"""
Will emit status of all agents
"""
#log.debug("emit", type="agent status")
# log.debug("emit", type="agent status")
for typ, cls in agents.AGENT_CLASSES.items():
agent = AGENTS.get(typ)
emit_agent_status(cls, agent)
handlers["request_agent_status"].connect(emit_agents_status)
handlers["request_agent_status"].connect(emit_agents_status)

View File

@@ -1,22 +1,26 @@
import json
import os
import structlog
from dotenv import load_dotenv
import talemate.events as events
import talemate.instance as instance
from talemate import Actor, Character, Player
from talemate.config import load_config
from talemate.scene_message import (
SceneMessage, CharacterMessage, NarratorMessage, DirectorMessage, MESSAGES, reset_message_id
)
from talemate.world_state import WorldState
from talemate.game_state import GameState
from talemate.context import SceneIsLoading
from talemate.emit import emit
from talemate.status import set_loading, LoadingStatus
import talemate.instance as instance
import structlog
from talemate.game_state import GameState
from talemate.scene_message import (
MESSAGES,
CharacterMessage,
DirectorMessage,
NarratorMessage,
SceneMessage,
reset_message_id,
)
from talemate.status import LoadingStatus, set_loading
from talemate.world_state import WorldState
__all__ = [
"load_scene",
@@ -29,6 +33,7 @@ __all__ = [
log = structlog.get_logger("talemate.load")
@set_loading("Loading scene...")
async def load_scene(scene, file_path, conv_client, reset: bool = False):
"""
@@ -61,8 +66,7 @@ async def load_scene_from_character_card(scene, file_path):
"""
Load a character card (tavern etc.) from the given file path.
"""
loading_status = LoadingStatus(5)
loading_status("Loading character card...")
@@ -85,59 +89,68 @@ async def load_scene_from_character_card(scene, file_path):
actor = Actor(character, conversation)
scene.name = character.name
loading_status("Initializing long-term memory...")
await memory.set_db()
await scene.add_actor(actor)
log.debug("load_scene_from_character_card", scene=scene, character=character, content_context=scene.context)
log.debug(
"load_scene_from_character_card",
scene=scene,
character=character,
content_context=scene.context,
)
loading_status("Determine character context...")
if not scene.context:
try:
scene.context = await creator.determine_content_context_for_character(character)
scene.context = await creator.determine_content_context_for_character(
character
)
log.debug("content_context", content_context=scene.context)
except Exception as e:
log.error("determine_content_context_for_character", error=e)
# attempt to convert to base attributes
try:
loading_status("Determine character attributes...")
_, character.base_attributes = await creator.determine_character_attributes(character)
_, character.base_attributes = await creator.determine_character_attributes(
character
)
# lowercase keys
character.base_attributes = {k.lower(): v for k, v in character.base_attributes.items()}
character.base_attributes = {
k.lower(): v for k, v in character.base_attributes.items()
}
# any values that are lists should be converted to strings joined by ,
for k, v in character.base_attributes.items():
if isinstance(v, list):
character.base_attributes[k] = ",".join(v)
# transfer description to character
if character.base_attributes.get("description"):
character.description = character.base_attributes.pop("description")
await character.commit_to_memory(scene.get_helper("memory").agent)
log.debug("base_attributes parsed", base_attributes=character.base_attributes)
except Exception as e:
log.warning("determine_character_attributes", error=e)
scene.description = character.description
if image:
scene.assets.set_cover_image_from_file_path(file_path)
character.cover_image = scene.assets.cover_image
try:
loading_status("Update world state ...")
await scene.world_state.request_update(initial_only=True)
await scene.world_state.request_update(initial_only=True)
except Exception as e:
log.error("world_state.request_update", error=e)
@@ -151,9 +164,9 @@ async def load_scene_from_data(
):
loading_status = LoadingStatus(1)
reset_message_id()
memory = scene.get_helper("memory").agent
scene.description = scene_data.get("description", "")
scene.intro = scene_data.get("intro", "") or scene.description
scene.name = scene_data.get("name", "Unknown Scene")
@@ -161,11 +174,10 @@ async def load_scene_from_data(
scene.filename = None
scene.goals = scene_data.get("goals", [])
scene.immutable_save = scene_data.get("immutable_save", False)
#reset = True
# reset = True
if not reset:
scene.goal = scene_data.get("goal", 0)
scene.memory_id = scene_data.get("memory_id", scene.memory_id)
scene.saved_memory_session_id = scene_data.get("saved_memory_session_id", None)
@@ -181,33 +193,37 @@ async def load_scene_from_data(
)
scene.assets.cover_image = scene_data.get("assets", {}).get("cover_image", None)
scene.assets.load_assets(scene_data.get("assets", {}).get("assets", {}))
scene.sync_time()
log.debug("scene time", ts=scene.ts)
loading_status("Initializing long-term memory...")
await memory.set_db()
await memory.remove_unsaved_memory()
await scene.world_state_manager.remove_all_empty_pins()
if not scene.memory_session_id:
scene.set_new_memory_session_id()
for ah in scene.archived_history:
if reset:
break
ts = ah.get("ts", "PT1S")
if not ah.get("ts"):
ah["ts"] = ts
scene.signals["archive_add"].send(
events.ArchiveEvent(scene=scene, event_type="archive_add", text=ah["text"], ts=ts)
events.ArchiveEvent(
scene=scene, event_type="archive_add", text=ah["text"], ts=ts
)
)
for character_name, character_data in scene_data.get("inactive_characters", {}).items():
for character_name, character_data in scene_data.get(
"inactive_characters", {}
).items():
scene.inactive_characters[character_name] = Character(**character_data)
for character_name, cs in scene.character_states.items():
@@ -215,10 +231,10 @@ async def load_scene_from_data(
for character_data in scene_data["characters"]:
character = Character(**character_data)
if character.name in scene.inactive_characters:
scene.inactive_characters.pop(character.name)
if not character.is_player:
agent = instance.get_agent("conversation", client=conv_client)
actor = Actor(character, agent)
@@ -226,13 +242,14 @@ async def load_scene_from_data(
actor = Player(character, None)
# Add the TestCharacter actor to the scene
await scene.add_actor(actor)
# the scene has been saved before (since we just loaded it), so we set the saved flag to True
# as long as the scene has a memory_id.
scene.saved = "memory_id" in scene_data
return scene
async def load_character_into_scene(scene, scene_json_path, character_name):
"""
Load a character from a scene json file and add it to the current scene.
@@ -244,10 +261,9 @@ async def load_character_into_scene(scene, scene_json_path, character_name):
# Load the json file
with open(scene_json_path, "r") as f:
scene_data = json.load(f)
agent = scene.get_helper("conversation").agent
# Find the character in the characters list
for character_data in scene_data["characters"]:
if character_data["name"] == character_name:
@@ -264,7 +280,9 @@ async def load_character_into_scene(scene, scene_json_path, character_name):
await scene.add_actor(actor)
break
else:
raise ValueError(f"Character '{character_name}' not found in the scene file '{scene_json_path}'")
raise ValueError(
f"Character '{character_name}' not found in the scene file '{scene_json_path}'"
)
return scene
@@ -340,49 +358,47 @@ def default_player_character():
def _load_history(history):
_history = []
for text in history:
if isinstance(text, str):
_history.append(_prepare_legacy_history(text))
elif isinstance(text, dict):
_history.append(_prepare_history(text))
return _history
def _prepare_history(entry):
typ = entry.pop("typ", "scene_message")
entry.pop("id", None)
if entry.get("source") == "":
entry.pop("source")
cls = MESSAGES.get(typ, SceneMessage)
return cls(**entry)
def _prepare_legacy_history(entry):
"""
Convers legacy history to new format
Legacy: list<str>
New: list<SceneMessage>
"""
if entry.startswith("*"):
cls = NarratorMessage
elif entry.startswith("Director instructs"):
cls = DirectorMessage
else:
cls = CharacterMessage
return cls(entry)
def creative_environment():
return {
@@ -392,6 +408,5 @@ def creative_environment():
"history": [],
"archived_history": [],
"character_states": {},
"characters": [
],
"characters": [],
}

View File

@@ -1 +1 @@
from .base import Prompt, LoopedPrompt
from .base import LoopedPrompt, Prompt

File diff suppressed because it is too large Load Diff

View File

@@ -1,30 +1,32 @@
from contextvars import ContextVar
import pydantic
current_prompt_context = ContextVar("current_content_context", default=None)
class PromptContextState(pydantic.BaseModel):
content: list[str] = pydantic.Field(default_factory=list)
def push(self, content:str, proxy:list[str]):
def push(self, content: str, proxy: list[str]):
if content not in self.content:
self.content.append(content)
proxy.append(content)
def has(self, content:str):
def has(self, content: str):
return content in self.content
def extend(self, content:list[str], proxy:list[str]):
def extend(self, content: list[str], proxy: list[str]):
for item in content:
self.push(item, proxy)
class PromptContext:
def __enter__(self):
self.state = PromptContextState()
self.token = current_prompt_context.set(self.state)
return self.state
def __exit__(self, *args):
current_prompt_context.reset(self.token)
return False
return False

View File

@@ -41,7 +41,7 @@ class CharacterHub:
if not os.path.exists("scenes/characters"):
os.makedirs("scenes/characters")
with open(f"scenes/characters/{filename}.png", "wb") as f:
f.write(result.content)

View File

@@ -2,11 +2,9 @@ import json
from talemate.scene_message import SceneMessage
class SceneEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, SceneMessage):
return obj.__dict__()
return super().default(obj)

View File

@@ -1,41 +1,40 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import os
import pydantic
import hashlib
import base64
import hashlib
import os
from typing import TYPE_CHECKING, Any
import pydantic
if TYPE_CHECKING:
from talemate import Scene
import structlog
__all__ = [
"Asset",
"SceneAssets"
]
__all__ = ["Asset", "SceneAssets"]
log = structlog.get_logger("talemate.scene_assets")
class Asset(pydantic.BaseModel):
id: str
file_type: str
media_type: str
def to_base64(self, asset_directory:str) -> str:
def to_base64(self, asset_directory: str) -> str:
"""
Returns the asset as a base64 encoded string.
"""
asset_path = os.path.join(asset_directory, f"{self.id}.{self.file_type}")
with open(asset_path, "rb") as f:
return base64.b64encode(f.read()).decode("utf-8")
class SceneAssets:
def __init__(self, scene:Scene):
class SceneAssets:
def __init__(self, scene: Scene):
self.scene = scene
self.assets = {}
self.cover_image = None
@@ -45,64 +44,64 @@ class SceneAssets:
"""
Returns the scene's asset path
"""
scene_save_dir = self.scene.save_dir
if not os.path.exists(scene_save_dir):
raise FileNotFoundError(f"Scene save directory does not exist: {scene_save_dir}")
raise FileNotFoundError(
f"Scene save directory does not exist: {scene_save_dir}"
)
asset_path = os.path.join(scene_save_dir, "assets")
if not os.path.exists(asset_path):
os.makedirs(asset_path)
return asset_path
def asset_path(self, asset_id:str) -> str:
def asset_path(self, asset_id: str) -> str:
"""
Returns the path to the asset with the given id.
"""
try:
return os.path.join(self.asset_directory, f"{asset_id}.{self.assets[asset_id].file_type}")
return os.path.join(
self.asset_directory, f"{asset_id}.{self.assets[asset_id].file_type}"
)
except KeyError:
log.error("asset_path", asset_id=asset_id, assets=self.assets)
return None
def dict(self, *args, **kwargs):
return {
"cover_image": self.cover_image,
"assets": {
asset.id: asset.dict() for asset in self.assets.values()
}
"assets": {asset.id: asset.dict() for asset in self.assets.values()},
}
def load_assets(self, assets_dict:dict):
"""
Loads assets from a dictionary.
"""
for asset_id, asset_dict in assets_dict.items():
self.assets[asset_id] = Asset(**asset_dict)
def set_cover_image(self, asset_bytes:bytes, file_extension:str, media_type:str):
def load_assets(self, assets_dict: dict):
"""
Loads assets from a dictionary.
"""
for asset_id, asset_dict in assets_dict.items():
self.assets[asset_id] = Asset(**asset_dict)
def set_cover_image(self, asset_bytes: bytes, file_extension: str, media_type: str):
# add the asset
asset = self.add_asset(asset_bytes, file_extension, media_type)
self.cover_image = asset.id
def set_cover_image_from_file_path(self, file_path:str):
def set_cover_image_from_file_path(self, file_path: str):
"""
Sets the cover image from file path, calling add_asset_from_file_path
"""
asset = self.add_asset_from_file_path(file_path)
self.cover_image = asset.id
def add_asset(self, asset_bytes:bytes, file_extension:str, media_type:str) -> Asset:
self.cover_image = asset.id
def add_asset(
self, asset_bytes: bytes, file_extension: str, media_type: str
) -> Asset:
"""
Takes the asset and stores it in the scene's assets folder.
"""
@@ -131,39 +130,36 @@ class SceneAssets:
self.assets[asset_id] = asset
return asset
def add_asset_from_image_data(self, image_data:str) -> Asset:
def add_asset_from_image_data(self, image_data: str) -> Asset:
"""
Will add an asset from an image data, extracting media type from the
data url and then decoding the base64 encoded data.
Will call add_asset
"""
media_type = image_data.split(";")[0].split(":")[1]
image_bytes = base64.b64decode(image_data.split(",")[1])
file_extension = media_type.split("/")[1]
return self.add_asset(image_bytes, file_extension, media_type)
def add_asset_from_file_path(self, file_path:str) -> Asset:
def add_asset_from_file_path(self, file_path: str) -> Asset:
"""
Will add an asset from a file path, first loading the file into memory.
and then calling add_asset
"""
file_bytes = None
with open(file_path, "rb") as f:
file_bytes = f.read()
file_extension = os.path.splitext(file_path)[1]
# guess media type from extension, currently only supports images
# for png, jpg and webp
if file_extension == ".png":
media_type = "image/png"
elif file_extension in [".jpg", ".jpeg"]:
@@ -172,50 +168,44 @@ class SceneAssets:
media_type = "image/webp"
else:
raise ValueError(f"Unsupported file extension: {file_extension}")
return self.add_asset(file_bytes, file_extension, media_type)
def get_asset(self, asset_id:str) -> Asset:
return self.add_asset(file_bytes, file_extension, media_type)
def get_asset(self, asset_id: str) -> Asset:
"""
Returns the asset with the given id.
"""
return self.assets[asset_id]
def get_asset_bytes(self, asset_id:str) -> bytes:
def get_asset_bytes(self, asset_id: str) -> bytes:
"""
Returns the bytes of the asset with the given id.
"""
asset_path = self.asset_path(asset_id)
with open(asset_path, "rb") as f:
return f.read()
def get_asset_bytes_as_base64(self, asset_id:str) -> str:
def get_asset_bytes_as_base64(self, asset_id: str) -> str:
"""
Returns the bytes of the asset with the given id as a base64 encoded string.
"""
bytes = self.get_asset_bytes(asset_id)
return base64.b64encode(bytes).decode("utf-8")
def remove_asset(self, asset_id:str):
def remove_asset(self, asset_id: str):
"""
Removes the asset with the given id.
"""
asset = self.assets.pop(asset_id)
asset_path = self.asset_directory
asset_file_path = os.path.join(asset_path, f"{asset_id}.{asset.file_type}")
os.remove(asset_file_path)
os.remove(asset_file_path)

View File

@@ -1,13 +1,16 @@
from dataclasses import dataclass, field
import isodate
_message_id = 0
def get_message_id():
global _message_id
_message_id += 1
return _message_id
def reset_message_id():
global _message_id
_message_id = 0
@@ -15,38 +18,37 @@ def reset_message_id():
@dataclass
class SceneMessage:
"""
Base class for all messages that are sent to the scene.
"""
# the mesage itself
message: str
# the id of the message
id: int = field(default_factory=get_message_id)
# the source of the message (e.g. "ai", "progress_story", "director")
source: str = ""
typ = "scene"
def __str__(self):
return self.message
def __int__(self):
return self.id
def __len__(self):
return len(self.message)
def __in__(self, other):
return (other in self.message)
return other in self.message
def __contains__(self, other):
return (self.message in other)
return self.message in other
def __dict__(self):
return {
"message": self.message,
@@ -54,66 +56,69 @@ class SceneMessage:
"typ": self.typ,
"source": self.source,
}
def __iter__(self):
return iter(self.message)
def split(self, *args, **kwargs):
return self.message.split(*args, **kwargs)
def startswith(self, *args, **kwargs):
return self.message.startswith(*args, **kwargs)
def endswith(self, *args, **kwargs):
return self.message.endswith(*args, **kwargs)
@property
def secondary_source(self):
return self.source
@dataclass
class CharacterMessage(SceneMessage):
typ = "character"
source: str = "ai"
def __str__(self):
return self.message
@property
def character_name(self):
return self.message.split(":", 1)[0]
@property
def secondary_source(self):
return self.character_name
@dataclass
class NarratorMessage(SceneMessage):
source: str = "progress_story"
typ = "narrator"
@dataclass
class DirectorMessage(SceneMessage):
typ = "director"
def __str__(self):
"""
The director message is a special case and needs to be transformed
from "Director instructs {charname}:" to "*{charname} inner monologue:"
"""
transformed_message = self.message.replace("Director instructs ", "")
char_name, message = transformed_message.split(":", 1)
return f"[Story progression instructions for {char_name}: {message}]"
@dataclass
class TimePassageMessage(SceneMessage):
ts: str = "PT0S"
source: str = "manual"
source: str = "manual"
typ = "time"
def __dict__(self):
return {
"message": self.message,
@@ -122,15 +127,17 @@ class TimePassageMessage(SceneMessage):
"source": self.source,
"ts": self.ts,
}
@dataclass
class ReinforcementMessage(SceneMessage):
typ = "reinforcement"
def __str__(self):
question, _ = self.source.split(":", 1)
return f"[Context state: {question}: {self.message}]"
MESSAGES = {
"scene": SceneMessage,
"character": CharacterMessage,
@@ -138,4 +145,4 @@ MESSAGES = {
"director": DirectorMessage,
"time": TimePassageMessage,
"reinforcement": ReinforcementMessage,
}
}

View File

@@ -1,40 +1,38 @@
import asyncio
import json
import os
import starlette.websockets
import websockets
import structlog
import traceback
import starlette.websockets
import structlog
import websockets
import talemate.instance as instance
from talemate import Scene
from talemate import VERSION, Scene
from talemate.config import load_config
from talemate.load import load_scene
from talemate.server.websocket_server import WebsocketHandler
from talemate.config import load_config
from talemate import VERSION
log = structlog.get_logger("talemate")
async def websocket_endpoint(websocket, path):
# Create a queue for outgoing messages
message_queue = asyncio.Queue()
handler = WebsocketHandler(websocket, message_queue)
scene_task = None
log.info("frontend connected")
try:
# Create a task to send messages from the queue
async def send_messages():
while True:
# check if there are messages in the queue
if message_queue.empty():
await asyncio.sleep(0.01)
continue
message = await message_queue.get()
await websocket.send(json.dumps(message))
@@ -49,15 +47,19 @@ async def websocket_endpoint(websocket, path):
send_status_task = asyncio.create_task(send_status())
# create a task that will retriece client boostrap information
async def send_client_bootstraps():
while True:
try:
await instance.sync_client_bootstraps()
except Exception as e:
log.error("send_client_bootstraps", error=e, traceback=traceback.format_exc())
log.error(
"send_client_bootstraps",
error=e,
traceback=traceback.format_exc(),
)
await asyncio.sleep(15)
send_client_bootstraps_task = asyncio.create_task(send_client_bootstraps())
while True:
@@ -66,7 +68,7 @@ async def websocket_endpoint(websocket, path):
action_type = data.get("type")
scene_data = None
log.debug("frontend message", action_type=action_type)
if action_type == "load_scene":
@@ -86,16 +88,18 @@ async def websocket_endpoint(websocket, path):
"message": "Scene file loaded ...",
"id": "scene.loaded",
"status": "success",
"data": {"hidden":True}
"data": {"hidden": True},
}
)
if scene_data and filename:
file_path = handler.handle_character_card_upload(scene_data, filename)
file_path = handler.handle_character_card_upload(
scene_data, filename
)
log.info("load_scene", file_path=file_path, reset=reset)
# Create a task to load the scene in the background
# Create a task to load the scene in the background
scene_task = asyncio.create_task(
handler.load_scene(
file_path, reset=reset, callback=scene_loading_done
@@ -140,11 +144,7 @@ async def websocket_endpoint(websocket, path):
elif action_type == "request_app_config":
log.info("request_app_config")
await message_queue.put(
{
"type": "app_config",
"data": load_config(),
"version": VERSION
}
{"type": "app_config", "data": load_config(), "version": VERSION}
)
else:
log.info("Routing to sub-handler", action_type=action_type)

View File

@@ -1,18 +1,16 @@
import os
import pydantic
import asyncio
import os
from typing import Union
import pydantic
import structlog
from talemate.prompts import Prompt
from talemate.tale_mate import Character, Actor, Player
from typing import Union
from talemate.tale_mate import Actor, Character, Player
log = structlog.get_logger("talemate.server.character_creator")
class StepData(pydantic.BaseModel):
template: str
is_player_character: bool
@@ -26,69 +24,70 @@ class StepData(pydantic.BaseModel):
description: str = None
questions: list[str] = []
scenario_context: str = None
class CharacterCreationData(pydantic.BaseModel):
base_attributes: dict[str, str] = {}
is_player_character: bool = False
character: object = None
initial_prompt: str = None
scenario_context: str = None
class CharacterCreatorServerPlugin:
router = "character_creator"
def __init__(self, websocket_handler):
self.websocket_handler = websocket_handler
self.character_creation_data = None
@property
def scene(self):
return self.websocket_handler.scene
async def handle(self, data:dict):
async def handle(self, data: dict):
action = data.get("action")
log.info("Character creator action", action=action)
if action == "submit":
step = data.get("step")
fn = getattr(self, f"handle_submit_step{step}", None)
if fn is None:
raise NotImplementedError(f"Unknown step {step}")
return await fn(data)
elif action == "request_templates":
return await self.handle_request_templates(data)
async def handle_request_templates(self, data:dict):
choices = Prompt.get("creator.character-human",{}).list_templates("character-attributes-*.jinja2")
async def handle_request_templates(self, data: dict):
choices = Prompt.get("creator.character-human", {}).list_templates(
"character-attributes-*.jinja2"
)
# each choice is a filename, we want to remove the extension and the directory
choices = [os.path.splitext(os.path.basename(c))[0] for c in choices]
# finally also strip the 'character-' prefix
choices = [c.replace("character-attributes-", "") for c in choices]
self.websocket_handler.queue_put({
"type": "character_creator",
"action": "send_templates",
"templates": sorted(choices),
"content_context": self.websocket_handler.config["creator"]["content_context"],
})
self.websocket_handler.queue_put(
{
"type": "character_creator",
"action": "send_templates",
"templates": sorted(choices),
"content_context": self.websocket_handler.config["creator"][
"content_context"
],
}
)
await asyncio.sleep(0.01)
def apply_step_data(self, data:dict):
def apply_step_data(self, data: dict):
step_data = StepData(**data)
if not self.character_creation_data:
self.character_creation_data = CharacterCreationData(
@@ -97,7 +96,7 @@ class CharacterCreatorServerPlugin:
is_player_character=step_data.is_player_character,
)
character=Character(
character = Character(
name="",
description="",
greeting_text="",
@@ -111,22 +110,25 @@ class CharacterCreatorServerPlugin:
character.gender = step_data.base_attributes.get("gender")
character.color = "red"
character.example_dialogue = step_data.dialogue_examples
return character, step_data
async def handle_submit_step2(self, data:dict):
async def handle_submit_step2(self, data: dict):
creator = self.scene.get_helper("creator").agent
character, step_data = self.apply_step_data(data)
self.emit_step_start(2)
def emit_attribute(name, value):
self.websocket_handler.queue_put({
"type": "character_creator",
"action": "base_attribute",
"name": name,
"value": value,
})
self.websocket_handler.queue_put(
{
"type": "character_creator",
"action": "base_attribute",
"name": name,
"value": value,
}
)
base_attributes = await creator.create_character_attributes(
step_data.character_prompt,
step_data.template,
@@ -136,53 +138,54 @@ class CharacterCreatorServerPlugin:
custom_attributes=step_data.custom_attributes,
predefined_attributes=step_data.base_attributes,
)
base_attributes["scenario_context"] = step_data.scenario_context
character.base_attributes = base_attributes
character.gender = base_attributes["gender"]
character.name = base_attributes["name"]
log.info("Created character", name=base_attributes.get("name"))
self.emit_step_done(2)
async def handle_submit_step3(self, data:dict):
async def handle_submit_step3(self, data: dict):
creator = self.scene.get_helper("creator").agent
character, step_data = self.apply_step_data(data)
self.emit_step_start(3)
description = await creator.create_character_description(
character,
content_context=step_data.scenario_context,
)
character.description = description
self.websocket_handler.queue_put({
"type": "character_creator",
"action": "description",
"description": description,
})
self.websocket_handler.queue_put(
{
"type": "character_creator",
"action": "description",
"description": description,
}
)
self.emit_step_done(3)
async def handle_submit_step4(self, data:dict):
async def handle_submit_step4(self, data: dict):
creator = self.scene.get_helper("creator").agent
character, step_data = self.apply_step_data(data)
def emit_detail(question, answer):
self.websocket_handler.queue_put({
"type": "character_creator",
"action": "detail",
"question": question,
"answer": answer,
})
self.websocket_handler.queue_put(
{
"type": "character_creator",
"action": "detail",
"question": question,
"answer": answer,
}
)
self.emit_step_start(4)
character_details = await creator.create_character_details(
@@ -193,28 +196,30 @@ class CharacterCreatorServerPlugin:
content_context=self.character_creation_data.scenario_context,
)
character.details = list(character_details.values())
self.emit_step_done(4)
async def handle_submit_step5(self, data:dict):
async def handle_submit_step5(self, data: dict):
creator = self.scene.get_helper("creator").agent
character, step_data = self.apply_step_data(data)
self.websocket_handler.queue_put({
"type": "character_creator",
"action": "set_generating_step",
"step": 5,
})
def emit_example(key, example):
self.websocket_handler.queue_put({
self.websocket_handler.queue_put(
{
"type": "character_creator",
"action": "example_dialogue",
"example": example,
})
"action": "set_generating_step",
"step": 5,
}
)
def emit_example(key, example):
self.websocket_handler.queue_put(
{
"type": "character_creator",
"action": "example_dialogue",
"example": example,
}
)
dialogue_guide = await creator.create_character_example_dialogue(
character,
step_data.template,
@@ -223,63 +228,61 @@ class CharacterCreatorServerPlugin:
content_context=self.character_creation_data.scenario_context,
example_callback=emit_example,
)
character.dialogue_guide = dialogue_guide
self.emit_step_done(5)
async def handle_submit_step6(self, data:dict):
async def handle_submit_step6(self, data: dict):
character, step_data = self.apply_step_data(data)
# check if acter with character name already exists
for actor in self.scene.actors:
if actor.character.name == character.name:
if character.is_player and not actor.character.is_player:
log.info("Character already exists, but is not a player", name=character.name)
log.info(
"Character already exists, but is not a player",
name=character.name,
)
await self.scene.remove_actor(actor)
break
log.info("Character already exists", name=character.name)
actor.character = character
self.scene.emit_status()
self.emit_step_done(6)
return
if character.is_player:
actor = Player(character, self.scene.get_helper("conversation").agent)
else:
actor = Actor(character, self.scene.get_helper("conversation").agent)
log.info("Adding actor", name=character.name, actor=actor)
character.base_attributes["scenario_context"] = step_data.scenario_context
character.base_attributes["_template"] = step_data.template
character.base_attributes["_prompt"] = step_data.character_prompt
await self.scene.add_actor(actor)
self.scene.emit_status()
self.emit_step_done(6)
def emit_step_start(self, step):
self.websocket_handler.queue_put({
"type": "character_creator",
"action": "set_generating_step",
"step": step,
})
self.websocket_handler.queue_put(
{
"type": "character_creator",
"action": "set_generating_step",
"step": step,
}
)
def emit_step_done(self, step):
self.websocket_handler.queue_put({
"type": "character_creator",
"action": "set_generating_step_done",
"step": step,
})
self.websocket_handler.queue_put(
{
"type": "character_creator",
"action": "set_generating_step_done",
"step": step,
}
)

View File

@@ -1,73 +1,80 @@
import os
import pydantic
import asyncio
import structlog
import json
import os
import pydantic
import structlog
from talemate.load import load_character_into_scene
log = structlog.get_logger("talemate.server.character_importer")
class ListCharactersData(pydantic.BaseModel):
scene_path: str
class ImportCharacterData(pydantic.BaseModel):
scene_path: str
character_name: str
class CharacterImporterServerPlugin:
router = "character_importer"
def __init__(self, websocket_handler):
self.websocket_handler = websocket_handler
@property
def scene(self):
return self.websocket_handler.scene
async def handle(self, data:dict):
async def handle(self, data: dict):
log.info("Character importer action", action=data.get("action"))
fn = getattr(self, f"handle_{data.get('action')}", None)
if fn is None:
return
await fn(data)
async def handle_list_characters(self, data):
list_characters_data = ListCharactersData(**data)
scene_path = list_characters_data.scene_path
with open(scene_path, "r") as f:
scene_data = json.load(f)
self.websocket_handler.queue_put({
"type": "character_importer",
"action": "list_characters",
"characters": [character["name"] for character in scene_data.get("characters", [])]
})
self.websocket_handler.queue_put(
{
"type": "character_importer",
"action": "list_characters",
"characters": [
character["name"] for character in scene_data.get("characters", [])
],
}
)
await asyncio.sleep(0)
async def handle_import(self, data):
import_character_data = ImportCharacterData(**data)
scene = self.websocket_handler.scene
await load_character_into_scene(
scene,
import_character_data.scene_path,
import_character_data.character_name,
)
scene.emit_status()
self.websocket_handler.queue_put({
"type": "character_importer",
"action": "import_character_done",
})
self.websocket_handler.queue_put(
{
"type": "character_importer",
"action": "import_character_done",
}
)

View File

@@ -1,165 +1,190 @@
import pydantic
import structlog
from talemate import VERSION
from talemate.client.registry import CLIENT_CLASSES
from talemate.config import Config as AppConfigData, load_config, save_config
from talemate.client.model_prompts import model_prompt
from talemate.client.registry import CLIENT_CLASSES
from talemate.config import Config as AppConfigData
from talemate.config import load_config, save_config
from talemate.emit import emit
log = structlog.get_logger("talemate.server.config")
class ConfigPayload(pydantic.BaseModel):
config: AppConfigData
class DefaultCharacterPayload(pydantic.BaseModel):
name: str
gender: str
description: str
color: str = "#3362bb"
class SetLLMTemplatePayload(pydantic.BaseModel):
template_file: str
model: str
class DetermineLLMTemplatePayload(pydantic.BaseModel):
model: str
class ConfigPlugin:
router = "config"
def __init__(self, websocket_handler):
self.websocket_handler = websocket_handler
async def handle(self, data:dict):
async def handle(self, data: dict):
log.info("Config action", action=data.get("action"))
fn = getattr(self, f"handle_{data.get('action')}", None)
if fn is None:
return
await fn(data)
async def handle_save(self, data):
app_config_data = ConfigPayload(**data)
current_config = load_config()
current_config.update(app_config_data.dict().get("config"))
save_config(current_config)
self.websocket_handler.config = current_config
self.websocket_handler.queue_put({
"type": "app_config",
"data": load_config(),
"version": VERSION
})
self.websocket_handler.queue_put({
"type": "config",
"action": "save_complete",
})
self.websocket_handler.queue_put(
{"type": "app_config", "data": load_config(), "version": VERSION}
)
self.websocket_handler.queue_put(
{
"type": "config",
"action": "save_complete",
}
)
async def handle_save_default_character(self, data):
log.info("Saving default character", data=data["data"])
payload = DefaultCharacterPayload(**data["data"])
current_config = load_config()
current_config["game"]["default_player_character"] = payload.model_dump()
log.info("Saving default character", character=current_config["game"]["default_player_character"])
log.info(
"Saving default character",
character=current_config["game"]["default_player_character"],
)
save_config(current_config)
self.websocket_handler.config = current_config
self.websocket_handler.queue_put({
"type": "app_config",
"data": load_config(),
"version": VERSION
})
self.websocket_handler.queue_put({
"type": "config",
"action": "save_default_character_complete",
})
self.websocket_handler.queue_put(
{"type": "app_config", "data": load_config(), "version": VERSION}
)
self.websocket_handler.queue_put(
{
"type": "config",
"action": "save_default_character_complete",
}
)
async def handle_request_std_llm_templates(self, data):
log.info("Requesting std llm templates")
self.websocket_handler.queue_put({
"type": "config",
"action": "std_llm_templates",
"data": {
"templates": model_prompt.std_templates,
self.websocket_handler.queue_put(
{
"type": "config",
"action": "std_llm_templates",
"data": {
"templates": model_prompt.std_templates,
},
}
})
)
async def handle_set_llm_template(self, data):
payload = SetLLMTemplatePayload(**data["data"])
copied_to = model_prompt.create_user_override(payload.template_file, payload.model)
log.info("Copied template", copied_to=copied_to, template=payload.template_file, model=payload.model)
prompt_template_example, prompt_template_file = model_prompt(payload.model, "sysmsg", "prompt<|BOT|>{LLM coercion}")
log.info("Prompt template example", prompt_template_example=prompt_template_example, prompt_template_file=prompt_template_file)
self.websocket_handler.queue_put({
"type": "config",
"action": "set_llm_template_complete",
"data": {
"prompt_template_example": prompt_template_example,
"has_prompt_template": True if prompt_template_example else False,
"template_file": prompt_template_file,
copied_to = model_prompt.create_user_override(
payload.template_file, payload.model
)
log.info(
"Copied template",
copied_to=copied_to,
template=payload.template_file,
model=payload.model,
)
prompt_template_example, prompt_template_file = model_prompt(
payload.model, "sysmsg", "prompt<|BOT|>{LLM coercion}"
)
log.info(
"Prompt template example",
prompt_template_example=prompt_template_example,
prompt_template_file=prompt_template_file,
)
self.websocket_handler.queue_put(
{
"type": "config",
"action": "set_llm_template_complete",
"data": {
"prompt_template_example": prompt_template_example,
"has_prompt_template": True if prompt_template_example else False,
"template_file": prompt_template_file,
},
}
})
)
async def handle_determine_llm_template(self, data):
payload = DetermineLLMTemplatePayload(**data["data"])
log.info("Determining LLM template", model=payload.model)
template = model_prompt.query_hf_for_prompt_template_suggestion(payload.model)
log.info("Template suggestion", template=template)
if not template:
emit("status", message="No template found for model", status="warning")
else:
await self.handle_set_llm_template({
"data": {
"template_file": template,
"model": payload.model,
await self.handle_set_llm_template(
{
"data": {
"template_file": template,
"model": payload.model,
}
}
})
self.websocket_handler.queue_put({
"type": "config",
"action": "determine_llm_template_complete",
"data": {
"template": template,
)
self.websocket_handler.queue_put(
{
"type": "config",
"action": "determine_llm_template_complete",
"data": {
"template": template,
},
}
})
)
async def handle_request_client_types(self, data):
log.info("Requesting client types")
clients = {
client_type: CLIENT_CLASSES[client_type].Meta().model_dump() for client_type in CLIENT_CLASSES
client_type: CLIENT_CLASSES[client_type].Meta().model_dump()
for client_type in CLIENT_CLASSES
}
self.websocket_handler.queue_put({
"type": "config",
"action": "client_types",
"data": clients,
})
self.websocket_handler.queue_put(
{
"type": "config",
"action": "client_types",
"data": clients,
}
)

View File

@@ -1,7 +1,8 @@
import uuid
from typing import Any, Union
import pydantic
import structlog
from typing import Union, Any
import uuid
from talemate.config import load_config, save_config
@@ -11,43 +12,40 @@ log = structlog.get_logger("talemate.server.quick_settings")
class SetQuickSettingsPayload(pydantic.BaseModel):
setting: str
value: Any
class QuickSettingsPlugin:
router = "quick_settings"
@property
def scene(self):
return self.websocket_handler.scene
def __init__(self, websocket_handler):
self.websocket_handler = websocket_handler
async def handle(self, data:dict):
async def handle(self, data: dict):
log.info("quick settings action", action=data.get("action"))
fn = getattr(self, f"handle_{data.get('action')}", None)
if fn is None:
return
await fn(data)
async def handle_set(self, data:dict):
async def handle_set(self, data: dict):
payload = SetQuickSettingsPayload(**data)
if payload.setting == "auto_save":
self.scene.config["game"]["general"]["auto_save"] = payload.value
elif payload.setting == "auto_progress":
self.scene.config["game"]["general"]["auto_progress"] = payload.value
else:
raise NotImplementedError(f"Setting {payload.setting} not implemented.")
save_config(self.scene.config)
self.websocket_handler.queue_put({
"type": self.router,
"action": "set_done",
"data": payload.model_dump()
})
self.websocket_handler.queue_put(
{"type": self.router, "action": "set_done", "data": payload.model_dump()}
)

View File

@@ -1,8 +1,8 @@
import os
import argparse
import asyncio
import os
import sys
import structlog
import websockets
@@ -10,13 +10,16 @@ from talemate.server.api import websocket_endpoint
log = structlog.get_logger("talemate.server.run")
def run_server(args):
"""
Run the talemate web server using the provided arguments.
:param args: command line arguments parsed by argparse
"""
start_server = websockets.serve(websocket_endpoint, args.host, args.port, max_size=2 ** 23)
start_server = websockets.serve(
websocket_endpoint, args.host, args.port, max_size=2**23
)
asyncio.get_event_loop().run_until_complete(start_server)
log.info("talemate backend started", host=args.host, port=args.port)
asyncio.get_event_loop().run_forever()

View File

@@ -1,152 +1,166 @@
import os
import pydantic
import asyncio
import structlog
import json
import os
from typing import Union
import pydantic
import structlog
from talemate.load import load_character_into_scene
log = structlog.get_logger("talemate.server.character_importer")
class ListScenesData(pydantic.BaseModel):
scene_path: str
class CreateSceneData(pydantic.BaseModel):
name: Union[str, None] = None
description: Union[str, None] = None
intro: Union[str, None] = None
intro: Union[str, None] = None
content_context: Union[str, None] = None
prompt: Union[str, None] = None
class SceneCreatorServerPlugin:
router = "scene_creator"
def __init__(self, websocket_handler):
self.websocket_handler = websocket_handler
@property
def scene(self):
return self.websocket_handler.scene
async def handle(self, data:dict):
async def handle(self, data: dict):
log.info("Scene importer action", action=data.get("action"))
fn = getattr(self, f"handle_{data.get('action')}", None)
if fn is None:
return
await fn(data)
async def handle_generate_description(self, data):
create_scene_data = CreateSceneData(**data)
scene = self.websocket_handler.scene
creator = scene.get_helper("creator").agent
self.websocket_handler.queue_put({
"type": "scene_creator",
"action": "set_generating",
})
self.websocket_handler.queue_put(
{
"type": "scene_creator",
"action": "set_generating",
}
)
description = await creator.create_scene_description(
create_scene_data.prompt,
create_scene_data.content_context,
)
log.info("Generated scene description", description=description)
scene.description = description
self.send_scene_update()
self.websocket_handler.queue_put({
"type": "scene_creator",
"action": "set_generating_done",
})
self.websocket_handler.queue_put(
{
"type": "scene_creator",
"action": "set_generating_done",
}
)
async def handle_generate_name(self, data):
create_scene_data = CreateSceneData(**data)
scene = self.websocket_handler.scene
creator = scene.get_helper("creator").agent
self.websocket_handler.queue_put({
"type": "scene_creator",
"action": "set_generating",
})
self.websocket_handler.queue_put(
{
"type": "scene_creator",
"action": "set_generating",
}
)
name = await creator.create_scene_name(
create_scene_data.prompt,
create_scene_data.content_context,
scene.description,
)
log.info("Generated scene name", name=name)
scene.name = name
self.send_scene_update()
self.websocket_handler.queue_put({
"type": "scene_creator",
"action": "set_generating_done",
})
self.websocket_handler.queue_put(
{
"type": "scene_creator",
"action": "set_generating_done",
}
)
async def handle_generate_intro(self, data):
create_scene_data = CreateSceneData(**data)
scene = self.websocket_handler.scene
creator = scene.get_helper("creator").agent
self.websocket_handler.queue_put({
"type": "scene_creator",
"action": "set_generating",
})
self.websocket_handler.queue_put(
{
"type": "scene_creator",
"action": "set_generating",
}
)
intro = await creator.create_scene_intro(
create_scene_data.prompt,
create_scene_data.content_context,
scene.description,
scene.name,
)
log.info("Generated scene intro", intro=intro)
scene.intro = intro
self.send_scene_update()
self.websocket_handler.queue_put({
"type": "scene_creator",
"action": "set_generating_done",
})
self.websocket_handler.queue_put(
{
"type": "scene_creator",
"action": "set_generating_done",
}
)
async def handle_generate(self, data):
create_scene_data = CreateSceneData(**data)
scene = self.websocket_handler.scene
creator = scene.get_helper("creator").agent
self.websocket_handler.queue_put({
"type": "scene_creator",
"action": "set_generating",
})
self.websocket_handler.queue_put(
{
"type": "scene_creator",
"action": "set_generating",
}
)
description = await creator.create_scene_description(
create_scene_data.prompt,
create_scene_data.content_context,
)
log.info("Generated scene description", description=description)
name = await creator.create_scene_name(
@@ -154,55 +168,61 @@ class SceneCreatorServerPlugin:
create_scene_data.content_context,
description,
)
log.info("Generated scene name", name=name)
intro = await creator.create_scene_intro(
create_scene_data.prompt,
create_scene_data.content_context,
description,
name,
)
log.info("Generated scene intro", intro=intro)
scene.name = name
scene.description = description
scene.intro = intro
scene.scenario_context = create_scene_data.content_context
self.send_scene_update()
self.websocket_handler.queue_put({
"type": "scene_creator",
"action": "set_generating_done",
})
self.websocket_handler.queue_put(
{
"type": "scene_creator",
"action": "set_generating_done",
}
)
async def handle_edit(self, data):
scene = self.websocket_handler.scene
create_scene_data = CreateSceneData(**data)
scene.description = create_scene_data.description
scene.name = create_scene_data.name
scene.intro = create_scene_data.intro
scene.scenario_context = create_scene_data.content_context
self.websocket_handler.queue_put({
"type": "scene_creator",
"action": "scene_saved",
})
self.websocket_handler.queue_put(
{
"type": "scene_creator",
"action": "scene_saved",
}
)
async def handle_load(self, data):
self.send_scene_update()
await asyncio.sleep(0)
def send_scene_update(self):
scene = self.websocket_handler.scene
self.websocket_handler.queue_put({
"type": "scene_creator",
"action": "scene_update",
"description": scene.description,
"name": scene.name,
"intro": scene.intro,
})
self.websocket_handler.queue_put(
{
"type": "scene_creator",
"action": "scene_update",
"description": scene.description,
"name": scene.name,
"intro": scene.intro,
}
)

View File

@@ -4,23 +4,21 @@ import talemate.instance as instance
log = structlog.get_logger("talemate.server.tts")
class TTSPlugin:
router = "tts"
def __init__(self, websocket_handler):
self.websocket_handler = websocket_handler
self.tts = None
async def handle(self, data:dict):
async def handle(self, data: dict):
action = data.get("action")
if action == "test":
return await self.handle_test(data)
async def handle_test(self, data:dict):
async def handle_test(self, data: dict):
tts_agent = instance.get_agent("tts")
await tts_agent.generate("Welcome to talemate!")
await tts_agent.generate("Welcome to talemate!")

View File

@@ -2,24 +2,29 @@ import asyncio
import base64
import os
import traceback
import structlog
import talemate.instance as instance
from talemate import Helper, Scene
from talemate.config import load_config, save_config, SceneAssetUpload
from talemate.client.registry import CLIENT_CLASSES
from talemate.config import SceneAssetUpload, load_config, save_config
from talemate.emit import Emission, Receiver, abort_wait_for_input, emit
from talemate.files import list_scenes_directory
from talemate.load import load_scene, load_scene_from_data, load_scene_from_character_card
from talemate.load import (
load_scene,
load_scene_from_character_card,
load_scene_from_data,
)
from talemate.scene_assets import Asset
from talemate.client.registry import CLIENT_CLASSES
from talemate.server import character_creator
from talemate.server import character_importer
from talemate.server import scene_creator
from talemate.server import config
from talemate.server import world_state_manager
from talemate.server import quick_settings
from talemate.server import (
character_creator,
character_importer,
config,
quick_settings,
scene_creator,
world_state_manager,
)
log = structlog.get_logger("talemate.server.websocket_server")
@@ -27,8 +32,6 @@ AGENT_INSTANCES = {}
class WebsocketHandler(Receiver):
def __init__(self, socket, out_queue, llm_clients=dict()):
self.agents = {typ: {"name": typ} for typ in instance.agent_types()}
self.socket = socket
@@ -40,7 +43,7 @@ class WebsocketHandler(Receiver):
for name, agent_config in self.config.get("agents", {}).items():
self.agents[name] = agent_config
self.llm_clients = self.config.get("clients", llm_clients)
instance.get_agent("memory", self.scene)
@@ -48,16 +51,26 @@ class WebsocketHandler(Receiver):
# unconveniently named function, this `connect` method is called
# to connect signals handlers to the websocket handler
self.connect()
self.connect_llm_clients()
self.routes = {
character_creator.CharacterCreatorServerPlugin.router: character_creator.CharacterCreatorServerPlugin(self),
character_importer.CharacterImporterServerPlugin.router: character_importer.CharacterImporterServerPlugin(self),
scene_creator.SceneCreatorServerPlugin.router: scene_creator.SceneCreatorServerPlugin(self),
self.routes = {
character_creator.CharacterCreatorServerPlugin.router: character_creator.CharacterCreatorServerPlugin(
self
),
character_importer.CharacterImporterServerPlugin.router: character_importer.CharacterImporterServerPlugin(
self
),
scene_creator.SceneCreatorServerPlugin.router: scene_creator.SceneCreatorServerPlugin(
self
),
config.ConfigPlugin.router: config.ConfigPlugin(self),
world_state_manager.WorldStateManagerPlugin.router: world_state_manager.WorldStateManagerPlugin(self),
quick_settings.QuickSettingsPlugin.router: quick_settings.QuickSettingsPlugin(self),
world_state_manager.WorldStateManagerPlugin.router: world_state_manager.WorldStateManagerPlugin(
self
),
quick_settings.QuickSettingsPlugin.router: quick_settings.QuickSettingsPlugin(
self
),
}
# self.request_scenes_list()
@@ -85,34 +98,36 @@ class WebsocketHandler(Receiver):
log.error("Error connecting to client", client_name=client_name, e=e)
continue
log.info("Configured client", client_name=client_name, client_type=client.client_type)
log.info(
"Configured client",
client_name=client_name,
client_type=client.client_type,
)
self.connect_agents()
def connect_agents(self):
if not self.llm_clients:
instance.emit_agents_status()
return
for agent_typ, agent_config in self.agents.items():
try:
client = self.llm_clients.get(agent_config.get("client"))["client"]
except TypeError as e:
client = None
if not client:
# select first client
print("selecting first client", self.llm_clients)
client = list(self.llm_clients.values())[0]["client"]
agent_config["client"] = client.name
log.debug("Linked agent", agent_typ=agent_typ, client=client.name)
agent = instance.get_agent(agent_typ, client=client)
agent.client = client
agent.apply_config(**agent_config)
instance.emit_agents_status()
def init_scene(self):
@@ -127,20 +142,21 @@ class WebsocketHandler(Receiver):
log.debug("init agent", agent_typ=agent_typ, agent_config=agent_config)
agent = instance.get_agent(agent_typ, **agent_config)
#if getattr(agent, "client", None):
# if getattr(agent, "client", None):
# self.llm_clients[agent.client.name] = agent.client
scene.add_helper(Helper(agent))
return scene
async def load_scene(self, path_or_data, reset=False, callback=None, file_name=None):
async def load_scene(
self, path_or_data, reset=False, callback=None, file_name=None
):
try:
if self.scene:
instance.get_agent("memory").close_db(self.scene)
self.scene.disconnect()
scene = self.init_scene()
if not scene:
@@ -172,18 +188,17 @@ class WebsocketHandler(Receiver):
existing = set(self.llm_clients.keys())
self.llm_clients = {}
log.info("Configuring clients", clients=clients)
for client in clients:
client.pop("status", None)
client_cls = CLIENT_CLASSES.get(client["type"])
if not client_cls:
log.error("Client type not found", client=client)
continue
client_config = self.llm_clients[client["name"]] = {
"name": client["name"],
"type": client["type"],
@@ -208,17 +223,17 @@ class WebsocketHandler(Receiver):
for name in removed:
log.debug("Destroying client", name=name)
instance.destroy_client(name)
self.config["clients"] = self.llm_clients
self.connect_llm_clients()
save_config(self.config)
instance.sync_emit_clients_status()
def configure_agents(self, agents):
self.agents = {typ: {} for typ in instance.agent_types()}
log.debug("Configuring agents")
for agent in agents:
@@ -235,7 +250,7 @@ class WebsocketHandler(Receiver):
if getattr(agent_instance, "actions", None):
self.agents[name]["actions"] = agent.get("actions", {})
agent_instance.apply_config(**self.agents[name])
log.debug("Configured agent", name=name)
continue
@@ -253,16 +268,21 @@ class WebsocketHandler(Receiver):
agent_instance = instance.get_agent(name, **self.agents[name])
agent_instance.client = self.llm_clients[agent["client"]]["client"]
if agent_instance.has_toggle:
self.agents[name]["enabled"] = agent["enabled"]
if getattr(agent_instance, "actions", None):
self.agents[name]["actions"] = agent.get("actions", {})
agent_instance.apply_config(**self.agents[name])
log.debug("Configured agent", name=name, client_name=self.llm_clients[agent["client"]]["name"], client=self.llm_clients[agent["client"]]["client"])
log.debug(
"Configured agent",
name=name,
client_name=self.llm_clients[agent["client"]]["name"],
client=self.llm_clients[agent["client"]]["client"],
)
self.config["agents"] = self.agents
save_config(self.config)
@@ -300,16 +320,15 @@ class WebsocketHandler(Receiver):
"character": emission.character.name if emission.character else "",
}
)
def handle_director(self, emission: Emission):
if emission.character:
character = emission.character.name
elif emission.message_object.source:
character = emission.message_object.source
else:
character = ""
self.queue_put(
{
"type": "director",
@@ -381,7 +400,7 @@ class WebsocketHandler(Receiver):
"status": emission.status,
}
)
def handle_config_saved(self, emission: Emission):
self.queue_put(
{
@@ -389,7 +408,7 @@ class WebsocketHandler(Receiver):
"data": emission.data,
}
)
def handle_archived_history(self, emission: Emission):
self.queue_put(
{
@@ -523,7 +542,6 @@ class WebsocketHandler(Receiver):
def request_scenes_list(self, query: str = ""):
scenes_list = list_scenes_directory()
if query:
filtered_list = [
@@ -547,8 +565,10 @@ class WebsocketHandler(Receiver):
)
def request_scene_history(self):
history = [archived_history["text"] for archived_history in self.scene.archived_history]
history = [
archived_history["text"] for archived_history in self.scene.archived_history
]
self.queue_put(
{
"type": "scene_history",
@@ -558,14 +578,13 @@ class WebsocketHandler(Receiver):
async def request_client_status(self):
await instance.emit_clients_status()
def request_scene_assets(self, asset_ids:list[str]):
def request_scene_assets(self, asset_ids: list[str]):
scene_assets = self.scene.assets
for asset_id in asset_ids:
asset = scene_assets.get_asset_bytes_as_base64(asset_id)
self.queue_put(
{
"type": "scene_asset",
@@ -574,16 +593,16 @@ class WebsocketHandler(Receiver):
"media_type": scene_assets.get_asset(asset_id).media_type,
}
)
def request_assets(self, assets:list[dict]):
def request_assets(self, assets: list[dict]):
# way to request scene assets without loading the scene
#
# assets is a list of dicts with keys:
# path must be turned into absolute path
# path must begin with Scene.scenes_dir()
_assets = {}
for asset_dict in assets:
try:
asset_id, asset = self._asset(**asset_dict)
@@ -591,32 +610,37 @@ class WebsocketHandler(Receiver):
log.error("request_assets", error=traceback.format_exc(), **asset_dict)
continue
_assets[asset_id] = asset
self.queue_put(
{
"type": "assets",
"assets": _assets,
}
)
)
def _asset(self, path: str, **asset):
absolute_path = os.path.abspath(path)
if not absolute_path.startswith(Scene.scenes_dir()):
log.error("_asset", error="Invalid path", path=absolute_path, scenes_dir=Scene.scenes_dir())
log.error(
"_asset",
error="Invalid path",
path=absolute_path,
scenes_dir=Scene.scenes_dir(),
)
return
asset_path = os.path.join(os.path.dirname(absolute_path), "assets")
asset = Asset(**asset)
return asset.id, {
"base64": asset.to_base64(asset_path),
"media_type": asset.media_type,
}
def add_scene_asset(self, data:dict):
def add_scene_asset(self, data: dict):
asset_upload = SceneAssetUpload(**data)
asset = self.scene.assets.add_asset_from_image_data(asset_upload.content)
if asset_upload.scene_cover_image:
self.scene.assets.cover_image = asset.id
self.scene.emit_status()
@@ -624,9 +648,12 @@ class WebsocketHandler(Receiver):
character = self.scene.get_character(asset_upload.character_cover_image)
old_cover_image = character.cover_image
character.cover_image = asset.id
if not self.scene.assets.cover_image or old_cover_image == self.scene.assets.cover_image:
if (
not self.scene.assets.cover_image
or old_cover_image == self.scene.assets.cover_image
):
self.scene.assets.cover_image = asset.id
self.scene.emit_status()
self.scene.emit_status()
self.request_scene_assets([character.cover_image])
self.queue_put(
{
@@ -637,16 +664,14 @@ class WebsocketHandler(Receiver):
"character": character.name,
}
)
def delete_message(self, message_id):
self.scene.delete_message(message_id)
def edit_message(self, message_id, new_text):
self.scene.edit_message(message_id, new_text)
def apply_scene_config(self, scene_config:dict):
def apply_scene_config(self, scene_config: dict):
self.scene.apply_scene_config(scene_config)
self.queue_put(
{
@@ -654,28 +679,25 @@ class WebsocketHandler(Receiver):
"data": self.scene.scene_config,
}
)
def handle_character_card_upload(self, image_data_url:str, filename:str) -> str:
def handle_character_card_upload(self, image_data_url: str, filename: str) -> str:
image_type = image_data_url.split(";")[0].split(":")[1]
image_data = base64.b64decode(image_data_url.split(",")[1])
characters_path = os.path.join("./scenes", "characters")
filepath = os.path.join(characters_path, filename)
with open(filepath, "wb") as f:
f.write(image_data)
return filepath
async def route(self, data:dict):
async def route(self, data: dict):
route = data["type"]
if route not in self.routes:
return
plugin = self.routes[route]
try:
await plugin.handle(data)
@@ -687,4 +709,4 @@ class WebsocketHandler(Receiver):
"type": "error",
"error": str(e),
}
)
)

View File

@@ -1,22 +1,30 @@
import uuid
from typing import Any, Union
import pydantic
import structlog
from typing import Union, Any
import uuid
from talemate.world_state.manager import WorldStateManager, WorldStateTemplates, StateReinforcementTemplate
from talemate.world_state.manager import (
StateReinforcementTemplate,
WorldStateManager,
WorldStateTemplates,
)
log = structlog.get_logger("talemate.server.world_state_manager")
class UpdateCharacterAttributePayload(pydantic.BaseModel):
name: str
attribute: str
value: str
class UpdateCharacterDetailPayload(pydantic.BaseModel):
name: str
detail: str
value: str
class SetCharacterDetailReinforcementPayload(pydantic.BaseModel):
name: str
question: str
@@ -25,462 +33,545 @@ class SetCharacterDetailReinforcementPayload(pydantic.BaseModel):
answer: str = ""
update_state: bool = False
insert: str = "sequential"
class CharacterDetailReinforcementPayload(pydantic.BaseModel):
name: str
question: str
reset: bool = False
class SaveWorldEntryPayload(pydantic.BaseModel):
id:str
id: str
text: str
meta: dict = {}
class DeleteWorldEntryPayload(pydantic.BaseModel):
id: str
class SetWorldEntryReinforcementPayload(pydantic.BaseModel):
question: str
instructions: Union[str, None] = None
interval: int = 10
answer: str = ""
update_state: bool = False
insert: str = "never"
insert: str = "never"
class WorldEntryReinforcementPayload(pydantic.BaseModel):
question: str
reset: bool = False
class QueryContextDBPayload(pydantic.BaseModel):
query: str
meta: dict = {}
class UpdateContextDBPayload(pydantic.BaseModel):
text: str
meta: dict = {}
id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
class DeleteContextDBPayload(pydantic.BaseModel):
id: Any
class UpdatePinPayload(pydantic.BaseModel):
entry_id: str
condition: Union[str, None] = None
condition_state: bool = False
active: bool = False
class RemovePinPayload(pydantic.BaseModel):
entry_id: str
entry_id: str
class SaveWorldStateTemplatePayload(pydantic.BaseModel):
template: StateReinforcementTemplate
class DeleteWorldStateTemplatePayload(pydantic.BaseModel):
template: StateReinforcementTemplate
class WorldStateManagerPlugin:
router = "world_state_manager"
@property
def scene(self):
return self.websocket_handler.scene
@property
def world_state_manager(self):
return WorldStateManager(self.scene)
def __init__(self, websocket_handler):
self.websocket_handler = websocket_handler
async def handle(self, data:dict):
async def handle(self, data: dict):
log.info("World state manager action", action=data.get("action"))
fn = getattr(self, f"handle_{data.get('action')}", None)
if fn is None:
return
await fn(data)
async def signal_operation_done(self):
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "operation_done",
"data": {}
})
self.websocket_handler.queue_put(
{"type": "world_state_manager", "action": "operation_done", "data": {}}
)
if self.scene.auto_save:
await self.scene.save(auto=True)
async def handle_get_character_list(self, data):
character_list = await self.world_state_manager.get_character_list()
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "character_list",
"data": character_list.model_dump()
})
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "character_list",
"data": character_list.model_dump(),
}
)
async def handle_get_character_details(self, data):
character_details = await self.world_state_manager.get_character_details(data["name"])
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "character_details",
"data": character_details.model_dump()
})
character_details = await self.world_state_manager.get_character_details(
data["name"]
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "character_details",
"data": character_details.model_dump(),
}
)
async def handle_get_world(self, data):
world = await self.world_state_manager.get_world()
log.debug("World", world=world)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "world",
"data": world.model_dump()
})
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "world",
"data": world.model_dump(),
}
)
async def handle_get_pins(self, data):
context_pins = await self.world_state_manager.get_pins()
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "pins",
"data": context_pins.model_dump()
})
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "pins",
"data": context_pins.model_dump(),
}
)
async def handle_get_templates(self, data):
templates = await self.world_state_manager.get_templates()
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "templates",
"data": templates.model_dump()
})
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "templates",
"data": templates.model_dump(),
}
)
async def handle_update_character_attribute(self, data):
payload = UpdateCharacterAttributePayload(**data)
await self.world_state_manager.update_character_attribute(payload.name, payload.attribute, payload.value)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "character_attribute_updated",
"data": payload.model_dump()
})
await self.world_state_manager.update_character_attribute(
payload.name, payload.attribute, payload.value
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "character_attribute_updated",
"data": payload.model_dump(),
}
)
# resend character details
await self.handle_get_character_details({"name":payload.name})
await self.handle_get_character_details({"name": payload.name})
await self.signal_operation_done()
async def handle_update_character_description(self, data):
payload = UpdateCharacterAttributePayload(**data)
await self.world_state_manager.update_character_description(payload.name, payload.value)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "character_description_updated",
"data": payload.model_dump()
})
# resend character details
await self.handle_get_character_details({"name":payload.name})
await self.signal_operation_done()
async def handle_update_character_detail(self, data):
payload = UpdateCharacterDetailPayload(**data)
await self.world_state_manager.update_character_detail(payload.name, payload.detail, payload.value)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "character_detail_updated",
"data": payload.model_dump()
})
await self.world_state_manager.update_character_description(
payload.name, payload.value
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "character_description_updated",
"data": payload.model_dump(),
}
)
# resend character details
await self.handle_get_character_details({"name":payload.name})
await self.handle_get_character_details({"name": payload.name})
await self.signal_operation_done()
async def handle_update_character_detail(self, data):
payload = UpdateCharacterDetailPayload(**data)
await self.world_state_manager.update_character_detail(
payload.name, payload.detail, payload.value
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "character_detail_updated",
"data": payload.model_dump(),
}
)
# resend character details
await self.handle_get_character_details({"name": payload.name})
await self.signal_operation_done()
async def handle_set_character_detail_reinforcement(self, data):
payload = SetCharacterDetailReinforcementPayload(**data)
await self.world_state_manager.add_detail_reinforcement(
payload.name,
payload.question,
payload.instructions,
payload.interval,
payload.name,
payload.question,
payload.instructions,
payload.interval,
payload.answer,
payload.insert,
payload.update_state
payload.update_state,
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "character_detail_reinforcement_set",
"data": payload.model_dump(),
}
)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "character_detail_reinforcement_set",
"data": payload.model_dump()
})
# resend character details
await self.handle_get_character_details({"name":payload.name})
await self.handle_get_character_details({"name": payload.name})
await self.signal_operation_done()
async def handle_run_character_detail_reinforcement(self, data):
payload = CharacterDetailReinforcementPayload(**data)
log.debug("Run character detail reinforcement", name=payload.name, question=payload.question, reset=payload.reset)
await self.world_state_manager.run_detail_reinforcement(
payload.name,
payload.question,
reset=payload.reset
log.debug(
"Run character detail reinforcement",
name=payload.name,
question=payload.question,
reset=payload.reset,
)
await self.world_state_manager.run_detail_reinforcement(
payload.name, payload.question, reset=payload.reset
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "character_detail_reinforcement_run",
"data": payload.model_dump(),
}
)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "character_detail_reinforcement_run",
"data": payload.model_dump()
})
# resend character details
await self.handle_get_character_details({"name":payload.name})
await self.handle_get_character_details({"name": payload.name})
await self.signal_operation_done()
async def handle_delete_character_detail_reinforcement(self, data):
payload = CharacterDetailReinforcementPayload(**data)
await self.world_state_manager.delete_detail_reinforcement(payload.name, payload.question)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "character_detail_reinforcement_deleted",
"data": payload.model_dump()
})
# resend character details
await self.handle_get_character_details({"name":payload.name})
await self.signal_operation_done()
await self.world_state_manager.delete_detail_reinforcement(
payload.name, payload.question
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "character_detail_reinforcement_deleted",
"data": payload.model_dump(),
}
)
# resend character details
await self.handle_get_character_details({"name": payload.name})
await self.signal_operation_done()
async def handle_save_world_entry(self, data):
payload = SaveWorldEntryPayload(**data)
log.debug("Save world entry", id=payload.id, text=payload.text, meta=payload.meta)
await self.world_state_manager.save_world_entry(payload.id, payload.text, payload.meta)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "world_entry_saved",
"data": payload.model_dump()
})
log.debug(
"Save world entry", id=payload.id, text=payload.text, meta=payload.meta
)
await self.world_state_manager.save_world_entry(
payload.id, payload.text, payload.meta
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "world_entry_saved",
"data": payload.model_dump(),
}
)
await self.handle_get_world({})
await self.signal_operation_done()
self.scene.world_state.emit()
async def handle_delete_world_entry(self, data):
payload = DeleteWorldEntryPayload(**data)
log.debug("Delete world entry", id=payload.id)
await self.world_state_manager.delete_context_db_entry(payload.id)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "world_entry_deleted",
"data": payload.model_dump()
})
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "world_entry_deleted",
"data": payload.model_dump(),
}
)
await self.handle_get_world({})
await self.signal_operation_done()
self.scene.world_state.emit()
self.scene.emit_status()
async def handle_set_world_state_reinforcement(self, data):
payload = SetWorldEntryReinforcementPayload(**data)
log.debug("Set world state reinforcement", question=payload.question, instructions=payload.instructions, interval=payload.interval, answer=payload.answer, insert=payload.insert, update_state=payload.update_state)
log.debug(
"Set world state reinforcement",
question=payload.question,
instructions=payload.instructions,
interval=payload.interval,
answer=payload.answer,
insert=payload.insert,
update_state=payload.update_state,
)
await self.world_state_manager.add_detail_reinforcement(
None,
payload.question,
payload.instructions,
payload.interval,
payload.question,
payload.instructions,
payload.interval,
payload.answer,
payload.insert,
payload.update_state
payload.update_state,
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "world_state_reinforcement_set",
"data": payload.model_dump(),
}
)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "world_state_reinforcement_set",
"data": payload.model_dump()
})
# resend world
await self.handle_get_world({})
await self.signal_operation_done()
async def handle_run_world_state_reinforcement(self, data):
payload = WorldEntryReinforcementPayload(**data)
await self.world_state_manager.run_detail_reinforcement(None, payload.question, payload.reset)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "world_state_reinforcement_ran",
"data": payload.model_dump()
})
await self.world_state_manager.run_detail_reinforcement(
None, payload.question, payload.reset
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "world_state_reinforcement_ran",
"data": payload.model_dump(),
}
)
# resend world
await self.handle_get_world({})
await self.signal_operation_done()
async def handle_delete_world_state_reinforcement(self, data):
payload = WorldEntryReinforcementPayload(**data)
await self.world_state_manager.delete_detail_reinforcement(None, payload.question)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "world_state_reinforcement_deleted",
"data": payload.model_dump()
})
await self.world_state_manager.delete_detail_reinforcement(
None, payload.question
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "world_state_reinforcement_deleted",
"data": payload.model_dump(),
}
)
# resend world
await self.handle_get_world({})
await self.signal_operation_done()
async def handle_query_context_db(self, data):
payload = QueryContextDBPayload(**data)
log.debug("Query context db", query=payload.query, meta=payload.meta)
context_db = await self.world_state_manager.get_context_db_entries(payload.query, **payload.meta)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "context_db_result",
"data": context_db.model_dump()
})
context_db = await self.world_state_manager.get_context_db_entries(
payload.query, **payload.meta
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "context_db_result",
"data": context_db.model_dump(),
}
)
await self.signal_operation_done()
async def handle_update_context_db(self, data):
payload = UpdateContextDBPayload(**data)
log.debug("Update context db", text=payload.text, meta=payload.meta, id=payload.id)
await self.world_state_manager.update_context_db_entry(payload.id, payload.text, payload.meta)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "context_db_updated",
"data": payload.model_dump()
})
log.debug(
"Update context db", text=payload.text, meta=payload.meta, id=payload.id
)
await self.world_state_manager.update_context_db_entry(
payload.id, payload.text, payload.meta
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "context_db_updated",
"data": payload.model_dump(),
}
)
await self.signal_operation_done()
async def handle_delete_context_db(self, data):
payload = DeleteContextDBPayload(**data)
log.debug("Delete context db", id=payload.id)
await self.world_state_manager.delete_context_db_entry(payload.id)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "context_db_deleted",
"data": payload.model_dump()
})
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "context_db_deleted",
"data": payload.model_dump(),
}
)
await self.signal_operation_done()
async def handle_set_pin(self, data):
payload = UpdatePinPayload(**data)
log.debug("Set pin", entry_id=payload.entry_id, condition=payload.condition, condition_state=payload.condition_state, active=payload.active)
await self.world_state_manager.set_pin(payload.entry_id, payload.condition, payload.condition_state, payload.active)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "pin_set",
"data": payload.model_dump()
})
log.debug(
"Set pin",
entry_id=payload.entry_id,
condition=payload.condition,
condition_state=payload.condition_state,
active=payload.active,
)
await self.world_state_manager.set_pin(
payload.entry_id, payload.condition, payload.condition_state, payload.active
)
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "pin_set",
"data": payload.model_dump(),
}
)
await self.handle_get_pins({})
await self.signal_operation_done()
await self.scene.load_active_pins()
self.scene.emit_status()
async def handle_remove_pin(self, data):
payload = RemovePinPayload(**data)
log.debug("Remove pin", entry_id=payload.entry_id)
await self.world_state_manager.remove_pin(payload.entry_id)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "pin_removed",
"data": payload.model_dump()
})
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "pin_removed",
"data": payload.model_dump(),
}
)
await self.handle_get_pins({})
await self.signal_operation_done()
await self.scene.load_active_pins()
self.scene.emit_status()
async def handle_save_template(self, data):
payload = SaveWorldStateTemplatePayload(**data)
log.debug("Save world state template", template=payload.template)
await self.world_state_manager.save_template(payload.template)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "template_saved",
"data": payload.model_dump()
})
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "template_saved",
"data": payload.model_dump(),
}
)
await self.handle_get_templates({})
await self.signal_operation_done()
async def handle_delete_template(self, data):
payload = DeleteWorldStateTemplatePayload(**data)
template = payload.template
log.debug("Delete world state template", template=template.name, template_type=template.type)
log.debug(
"Delete world state template",
template=template.name,
template_type=template.type,
)
await self.world_state_manager.remove_template(template.type, template.name)
self.websocket_handler.queue_put({
"type": "world_state_manager",
"action": "template_deleted",
"data": payload.model_dump()
})
self.websocket_handler.queue_put(
{
"type": "world_state_manager",
"action": "template_deleted",
"data": payload.model_dump(),
}
)
await self.handle_get_templates({})
await self.signal_operation_done()
await self.signal_operation_done()

View File

@@ -1,6 +1,7 @@
from talemate.emit import emit
import structlog
from talemate.emit import emit
__all__ = [
"set_loading",
"LoadingStatus",
@@ -10,11 +11,10 @@ log = structlog.get_logger("talemate.status")
class set_loading:
def __init__(self, message, set_busy:bool=True):
def __init__(self, message, set_busy: bool = True):
self.message = message
self.set_busy = set_busy
def __call__(self, fn):
async def wrapper(*args, **kwargs):
if self.set_busy:
@@ -23,15 +23,19 @@ class set_loading:
return await fn(*args, **kwargs)
finally:
emit("status", message="", status="idle")
return wrapper
class LoadingStatus:
def __init__(self, max_steps:int):
def __init__(self, max_steps: int):
self.max_steps = max_steps
self.current_step = 0
def __call__(self, message:str):
def __call__(self, message: str):
self.current_step += 1
emit("status", message=f"{message} [{self.current_step}/{self.max_steps}]", status="busy")
emit(
"status",
message=f"{message} [{self.current_step}/{self.max_steps}]",
status="busy",
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,23 +1,26 @@
import base64
import datetime
import io
import json
import re
import textwrap
import structlog
import isodate
import datetime
from typing import List, Union
from thefuzz import fuzz
import isodate
import structlog
from colorama import Back, Fore, Style, init
from PIL import Image
from nltk.tokenize import sent_tokenize
from PIL import Image
from thefuzz import fuzz
from talemate.scene_message import SceneMessage
log = structlog.get_logger("talemate.util")
# Initialize colorama
init(autoreset=True)
def fix_unquoted_keys(s):
unquoted_key_pattern = r"(?<!\\)(?:(?<=\{)|(?<=,))\s*(\w+)\s*:"
fixed_string = re.sub(
@@ -279,25 +282,25 @@ def replace_conditional(input_string: str, params) -> str:
return modified_string
def strip_partial_sentences(text:str) -> str:
def strip_partial_sentences(text: str) -> str:
# Sentence ending characters
sentence_endings = ['.', '!', '?', '"', "*"]
sentence_endings = [".", "!", "?", '"', "*"]
if not text:
return text
# Check if the last character is already a sentence ending
if text[-1] in sentence_endings:
return text
# Split the text into words
words = text.split()
# Iterate over the words in reverse order until a sentence ending is found
for i in range(len(words) - 1, -1, -1):
if words[i][-1] in sentence_endings:
return ' '.join(words[:i+1])
return " ".join(words[: i + 1])
# If no sentence ending is found, return the original text
return text
@@ -335,8 +338,8 @@ def clean_message(message: str) -> str:
message = message.replace("[", "*").replace("]", "*")
return message
def clean_dialogue(dialogue: str, main_name: str) -> str:
# re split by \n{not main_name}: with a max count of 1
pattern = r"\n(?!{}:).*".format(re.escape(main_name))
@@ -346,10 +349,11 @@ def clean_dialogue(dialogue: str, main_name: str) -> str:
dialogue = f"{main_name}: {dialogue}"
return clean_message(strip_partial_sentences(dialogue))
def clean_id(name: str) -> str:
"""
Cleans up a id name by removing all characters that aren't a-zA-Z0-9_-
Cleans up a id name by removing all characters that aren't a-zA-Z0-9_-
Spaces are allowed.
@@ -361,9 +365,10 @@ def clean_id(name: str) -> str:
"""
# Remove all characters that aren't a-zA-Z0-9_-
cleaned_name = re.sub(r"[^a-zA-Z0-9_\- ]", "", name)
return cleaned_name
def duration_to_timedelta(duration):
"""Convert an isodate.Duration object or a datetime.timedelta object to a datetime.timedelta object."""
# Check if the duration is already a timedelta object
@@ -375,6 +380,7 @@ def duration_to_timedelta(duration):
seconds = duration.tdelta.seconds
return datetime.timedelta(days=days, seconds=seconds)
def timedelta_to_duration(delta):
"""Convert a datetime.timedelta object to an isodate.Duration object."""
# Extract days and convert to years, months, and days
@@ -389,7 +395,15 @@ def timedelta_to_duration(delta):
seconds %= 3600
minutes = seconds // 60
seconds %= 60
return isodate.Duration(years=years, months=months, days=days, hours=hours, minutes=minutes, seconds=seconds)
return isodate.Duration(
years=years,
months=months,
days=days,
hours=hours,
minutes=minutes,
seconds=seconds,
)
def parse_duration_to_isodate_duration(duration_str):
"""Parse ISO 8601 duration string and ensure the result is an isodate.Duration."""
@@ -398,6 +412,7 @@ def parse_duration_to_isodate_duration(duration_str):
return timedelta_to_duration(parsed_duration)
return parsed_duration
def iso8601_diff(duration_str1, duration_str2):
# Parse the ISO 8601 duration strings ensuring they are isodate.Duration objects
duration1 = parse_duration_to_isodate_duration(duration_str1)
@@ -409,14 +424,14 @@ def iso8601_diff(duration_str1, duration_str2):
# Calculate the difference
difference_timedelta = abs(timedelta1 - timedelta2)
# Convert back to Duration for further processing
difference = timedelta_to_duration(difference_timedelta)
return difference
def iso8601_duration_to_human(iso_duration, suffix: str = " ago"):
# Parse the ISO8601 duration string into an isodate duration object
if not isinstance(iso_duration, isodate.Duration):
duration = isodate.parse_duration(iso_duration)
@@ -463,7 +478,7 @@ def iso8601_duration_to_human(iso_duration, suffix: str = " ago"):
# Construct the human-readable string
if len(components) > 1:
last = components.pop()
human_str = ', '.join(components) + ' and ' + last
human_str = ", ".join(components) + " and " + last
elif components:
human_str = components[0]
else:
@@ -471,16 +486,17 @@ def iso8601_duration_to_human(iso_duration, suffix: str = " ago"):
return f"{human_str}{suffix}"
def iso8601_diff_to_human(start, end):
if not start or not end:
return ""
diff = iso8601_diff(start, end)
return iso8601_duration_to_human(diff)
def iso8601_add(date_a:str, date_b:str) -> str:
def iso8601_add(date_a: str, date_b: str) -> str:
"""
Adds two ISO 8601 durations together.
"""
@@ -488,80 +504,84 @@ def iso8601_add(date_a:str, date_b:str) -> str:
if not date_a or not date_b:
return "PT0S"
new_ts = isodate.parse_duration(date_a.strip()) + isodate.parse_duration(date_b.strip())
new_ts = isodate.parse_duration(date_a.strip()) + isodate.parse_duration(
date_b.strip()
)
return isodate.duration_isoformat(new_ts)
def iso8601_correct_duration(duration: str) -> str:
# Split the string into date and time components using 'T' as the delimiter
parts = duration.split("T")
# Handle the date component
date_component = parts[0]
time_component = ""
# If there's a time component, process it
if len(parts) > 1:
time_component = parts[1]
# Check if the time component has any date values (Y, M, D) and move them to the date component
for char in "YD": # Removed 'M' from this loop
if char in time_component:
index = time_component.index(char)
date_component += time_component[:index+1]
time_component = time_component[index+1:]
date_component += time_component[: index + 1]
time_component = time_component[index + 1 :]
# If the date component contains any time values (H, M, S), move them to the time component
for char in "HMS":
if char in date_component:
index = date_component.index(char)
time_component = date_component[index:] + time_component
date_component = date_component[:index]
# Combine the corrected date and time components
corrected_duration = date_component
if time_component:
corrected_duration += "T" + time_component
return corrected_duration
def fix_faulty_json(data: str) -> str:
# Fix missing commas
data = re.sub(r'}\s*{', '},{', data)
data = re.sub(r']\s*{', '],{', data)
data = re.sub(r'}\s*\[', '},{', data)
data = re.sub(r']\s*\[', '],[', data)
data = re.sub(r"}\s*{", "},{", data)
data = re.sub(r"]\s*{", "],{", data)
data = re.sub(r"}\s*\[", "},{", data)
data = re.sub(r"]\s*\[", "],[", data)
# Fix trailing commas
data = re.sub(r',\s*}', '}', data)
data = re.sub(r',\s*]', ']', data)
data = re.sub(r",\s*}", "}", data)
data = re.sub(r",\s*]", "]", data)
try:
json.loads(data)
except json.JSONDecodeError:
try:
json.loads(data+"}")
return data+"}"
except json.JSONDecodeError:
json.loads(data + "}")
return data + "}"
except json.JSONDecodeError:
try:
json.loads(data+"]")
return data+"]"
json.loads(data + "]")
return data + "]"
except json.JSONDecodeError:
return data
return data
def extract_json(s):
"""
Extracts a JSON string from the beginning of the input string `s`.
Parameters:
s (str): The input string containing a JSON string at the beginning.
Returns:
str: The extracted JSON string.
dict: The parsed JSON object.
Raises:
ValueError: If a valid JSON string is not found.
"""
@@ -571,138 +591,154 @@ def extract_json(s):
json_string_start = None
s = s.lstrip() # Strip white spaces and line breaks from the beginning
i = 0
log.debug("extract_json", s=s)
# Iterate through the string.
while i < len(s):
# Count the opening and closing curly brackets.
if s[i] == '{' or s[i] == '[':
if s[i] == "{" or s[i] == "[":
bracket_stack.append(s[i])
open_brackets += 1
if json_string_start is None:
json_string_start = i
elif s[i] == '}' or s[i] == ']':
elif s[i] == "}" or s[i] == "]":
bracket_stack
close_brackets += 1
# Check if the brackets match, indicating a complete JSON string.
if open_brackets == close_brackets:
json_string = s[json_string_start:i+1]
json_string = s[json_string_start : i + 1]
# Try to parse the JSON string.
return json_string, json.loads(json_string)
i += 1
if json_string_start is None:
raise ValueError("No JSON string found.")
json_string = s[json_string_start:]
while bracket_stack:
char = bracket_stack.pop()
if char == '{':
json_string += '}'
elif char == '[':
json_string += ']'
if char == "{":
json_string += "}"
elif char == "[":
json_string += "]"
json_object = json.loads(json_string)
return json_string, json_object
def similarity_score(line: str, lines: list[str], similarity_threshold: int = 95) -> tuple[bool, int, str]:
def similarity_score(
line: str, lines: list[str], similarity_threshold: int = 95
) -> tuple[bool, int, str]:
"""
Checks if a line is similar to any of the lines in the list of lines.
Arguments:
line (str): The line to check.
lines (list): The list of lines to check against.
threshold (int): The similarity threshold to use when comparing lines.
Returns:
bool: Whether a similar line was found.
int: The similarity score of the line. If no similar line was found, the highest similarity score is returned.
str: The similar line that was found. If no similar line was found, None is returned.
"""
highest_similarity = 0
for existing_line in lines:
similarity = fuzz.ratio(line, existing_line)
highest_similarity = max(highest_similarity, similarity)
#print("SIMILARITY", similarity, existing_line[:32]+"...")
# print("SIMILARITY", similarity, existing_line[:32]+"...")
if similarity >= similarity_threshold:
return True, similarity, existing_line
return False, highest_similarity, None
def dedupe_sentences(line_a:str, line_b:str, similarity_threshold:int=95, debug:bool=False, split_on_comma:bool=True) -> str:
def dedupe_sentences(
line_a: str,
line_b: str,
similarity_threshold: int = 95,
debug: bool = False,
split_on_comma: bool = True,
) -> str:
"""
Will split both lines into sentences and then compare each sentence in line_a
against similar sentences in line_b. If a similar sentence is found, it will be
removed from line_a.
The similarity threshold is used to determine if two sentences are similar.
Arguments:
line_a (str): The first line.
line_b (str): The second line.
similarity_threshold (int): The similarity threshold to use when comparing sentences.
debug (bool): Whether to log debug messages.
split_on_comma (bool): Whether to split line_b sentences on commas as well.
Returns:
str: the cleaned line_a.
"""
line_a_sentences = sent_tokenize(line_a)
line_b_sentences = sent_tokenize(line_b)
cleaned_line_a_sentences = []
if split_on_comma:
# collect all sentences from line_b that contain a comma
line_b_sentences_with_comma = []
for line_b_sentence in line_b_sentences:
if "," in line_b_sentence:
line_b_sentences_with_comma.append(line_b_sentence)
# then split all sentences in line_b_sentences_with_comma on the comma
# and extend line_b_sentences with the split sentences, making sure
# to strip whitespace from the beginning and end of each sentence
for line_b_sentence in line_b_sentences_with_comma:
line_b_sentences.extend([s.strip() for s in line_b_sentence.split(",")])
for line_a_sentence in line_a_sentences:
similar_found = False
for line_b_sentence in line_b_sentences:
similarity = fuzz.ratio(line_a_sentence, line_b_sentence)
if similarity >= similarity_threshold:
if debug:
log.debug("DEDUPE SENTENCE", similarity=similarity, line_a_sentence=line_a_sentence, line_b_sentence=line_b_sentence)
log.debug(
"DEDUPE SENTENCE",
similarity=similarity,
line_a_sentence=line_a_sentence,
line_b_sentence=line_b_sentence,
)
similar_found = True
break
if not similar_found:
cleaned_line_a_sentences.append(line_a_sentence)
return " ".join(cleaned_line_a_sentences)
def dedupe_string_old(s: str, min_length: int = 32, similarity_threshold: int = 95, debug: bool = False) -> str:
def dedupe_string_old(
s: str, min_length: int = 32, similarity_threshold: int = 95, debug: bool = False
) -> str:
"""
Removes duplicate lines from a string.
Arguments:
s (str): The input string.
min_length (int): The minimum length of a line to be checked for duplicates.
similarity_threshold (int): The similarity threshold to use when comparing lines.
debug (bool): Whether to log debug messages.
Returns:
str: The deduplicated string.
"""
lines = s.split("\n")
deduped = []
for line in lines:
stripped_line = line.strip()
if len(stripped_line) > min_length:
@@ -712,33 +748,40 @@ def dedupe_string_old(s: str, min_length: int = 32, similarity_threshold: int =
if similarity >= similarity_threshold:
similar_found = True
if debug:
log.debug("DEDUPE", similarity=similarity, line=line, existing_line=existing_line)
log.debug(
"DEDUPE",
similarity=similarity,
line=line,
existing_line=existing_line,
)
break
if not similar_found:
deduped.append(line)
else:
deduped.append(line) # Allow shorter strings without dupe check
return "\n".join(deduped)
def dedupe_string(s: str, min_length: int = 32, similarity_threshold: int = 95, debug: bool = False) -> str:
def dedupe_string(
s: str, min_length: int = 32, similarity_threshold: int = 95, debug: bool = False
) -> str:
"""
Removes duplicate lines from a string going from the bottom up.
Arguments:
s (str): The input string.
min_length (int): The minimum length of a line to be checked for duplicates.
similarity_threshold (int): The similarity threshold to use when comparing lines.
debug (bool): Whether to log debug messages.
Returns:
str: The deduplicated string.
"""
lines = s.split("\n")
deduped = []
for line in reversed(lines):
stripped_line = line.strip()
if len(stripped_line) > min_length:
@@ -748,36 +791,42 @@ def dedupe_string(s: str, min_length: int = 32, similarity_threshold: int = 95,
if similarity >= similarity_threshold:
similar_found = True
if debug:
log.debug("DEDUPE", similarity=similarity, line=line, existing_line=existing_line)
log.debug(
"DEDUPE",
similarity=similarity,
line=line,
existing_line=existing_line,
)
break
if not similar_found:
deduped.append(line)
else:
deduped.append(line) # Allow shorter strings without dupe check
return "\n".join(reversed(deduped))
def remove_extra_linebreaks(s: str) -> str:
"""
Removes extra line breaks from a string.
Parameters:
s (str): The input string.
Returns:
str: The string with extra line breaks removed.
"""
return re.sub(r"\n{3,}", "\n\n", s)
def replace_exposition_markers(s:str) -> str:
def replace_exposition_markers(s: str) -> str:
s = s.replace("(", "*").replace(")", "*")
s = s.replace("[", "*").replace("]", "*")
return s
return s
def ensure_dialog_format(line:str, talking_character:str=None) -> str:
#if "*" not in line and '"' not in line:
def ensure_dialog_format(line: str, talking_character: str = None) -> str:
# if "*" not in line and '"' not in line:
# if talking_character:
# line = line[len(talking_character)+1:].lstrip()
# return f"{talking_character}: \"{line}\""
@@ -785,7 +834,7 @@ def ensure_dialog_format(line:str, talking_character:str=None) -> str:
#
if talking_character:
line = line[len(talking_character)+1:].lstrip()
line = line[len(talking_character) + 1 :].lstrip()
lines = []
@@ -793,52 +842,53 @@ def ensure_dialog_format(line:str, talking_character:str=None) -> str:
try:
_line = ensure_dialog_line_format(_line)
except Exception as exc:
log.error("ensure_dialog_format", msg="Error ensuring dialog line format", line=_line, exc_info=exc)
log.error(
"ensure_dialog_format",
msg="Error ensuring dialog line format",
line=_line,
exc_info=exc,
)
pass
lines.append(_line)
if len(lines) > 1:
line = "\n".join(lines)
else:
line = lines[0]
if talking_character:
line = f"{talking_character}: {line}"
return line
def ensure_dialog_line_format(line:str):
return line
def ensure_dialog_line_format(line: str):
"""
a Python function that standardizes the formatting of dialogue and action/thought
descriptions in text strings. This function is intended for use in a text-based
a Python function that standardizes the formatting of dialogue and action/thought
descriptions in text strings. This function is intended for use in a text-based
game where spoken dialogue is encased in double quotes (" ") and actions/thoughts are
encased in asterisks (* *). The function must correctly format strings, ensuring that
each spoken sentence and action/thought is properly encased
"""
i = 0
segments = []
segment = None
segment_open = None
line = line.strip()
line = line.replace('"*', '"').replace('*"', '"')
for i in range(len(line)):
c = line[i]
#print("segment_open", segment_open)
#print("segment", segment)
if c in ['"', '*']:
# print("segment_open", segment_open)
# print("segment", segment)
if c in ['"', "*"]:
if segment_open == c:
# open segment is the same as the current character
# closing
@@ -849,15 +899,15 @@ def ensure_dialog_line_format(line:str):
elif segment_open is not None and segment_open != c:
# open segment is not the same as the current character
# opening - close the current segment and open a new one
# if we are at the last character we append the segment
if i == len(line)-1 and segment.strip():
if i == len(line) - 1 and segment.strip():
segment += c
segments += [segment.strip()]
segment_open = None
segment = None
continue
segments += [segment.strip()]
segment_open = c
segment = c
@@ -871,27 +921,27 @@ def ensure_dialog_line_format(line:str):
segment = c
else:
segment += c
if segment is not None:
if segment.strip().strip("*").strip('"'):
segments += [segment.strip()]
for i in range(len(segments)):
segment = segments[i]
if segment in ['"', '*']:
if segment in ['"', "*"]:
if i > 0:
prev_segment = segments[i-1]
if prev_segment and prev_segment[-1] not in ['"', '*']:
segments[i-1] = f"{prev_segment}{segment}"
prev_segment = segments[i - 1]
if prev_segment and prev_segment[-1] not in ['"', "*"]:
segments[i - 1] = f"{prev_segment}{segment}"
segments[i] = ""
continue
for i in range(len(segments)):
segment = segments[i]
if not segment:
continue
if segment[0] == "*" and segment[-1] != "*":
segment += "*"
elif segment[-1] == "*" and segment[0] != "*":
@@ -900,43 +950,42 @@ def ensure_dialog_line_format(line:str):
segment += '"'
elif segment[-1] == '"' and segment[0] != '"':
segment = '"' + segment
elif segment[0] in ['"', '*'] and segment[-1] == segment[0]:
elif segment[0] in ['"', "*"] and segment[-1] == segment[0]:
continue
segments[i] = segment
for i in range(len(segments)):
segment = segments[i]
if not segment or segment[0] in ['"', '*']:
if not segment or segment[0] in ['"', "*"]:
continue
prev_segment = segments[i-1] if i > 0 else None
next_segment = segments[i+1] if i < len(segments)-1 else None
prev_segment = segments[i - 1] if i > 0 else None
next_segment = segments[i + 1] if i < len(segments) - 1 else None
if prev_segment and prev_segment[-1] == '"':
segments[i] = f"*{segment}*"
elif prev_segment and prev_segment[-1] == '*':
segments[i] = f"\"{segment}\""
elif prev_segment and prev_segment[-1] == "*":
segments[i] = f'"{segment}"'
elif next_segment and next_segment[0] == '"':
segments[i] = f"*{segment}*"
elif next_segment and next_segment[0] == '*':
segments[i] = f"\"{segment}\""
elif next_segment and next_segment[0] == "*":
segments[i] = f'"{segment}"'
for i in range(len(segments)):
segments[i] = clean_uneven_markers(segments[i], '"')
segments[i] = clean_uneven_markers(segments[i], '*')
segments[i] = clean_uneven_markers(segments[i], "*")
final = " ".join(segment for segment in segments if segment).strip()
final = final.replace('","', '').replace('"."', '')
final = final.replace('","', "").replace('"."', "")
return final
def clean_uneven_markers(chunk:str, marker:str):
def clean_uneven_markers(chunk: str, marker: str):
# if there is an uneven number of quotes, remove the last one if its
# at the end of the chunk. If its in the middle, add a quote to the endc
count = chunk.count(marker)
if count % 2 == 1:
if chunk.endswith(marker):
chunk = chunk[:-1]
@@ -946,5 +995,5 @@ def clean_uneven_markers(chunk:str, marker:str):
chunk = chunk.replace(marker, "")
else:
chunk += marker
return chunk
return chunk

View File

@@ -1,31 +1,36 @@
from pydantic import BaseModel, Field, field_validator
from talemate.emit import emit
import structlog
import traceback
from typing import Union, Any
from enum import Enum
from typing import Any, Union
import structlog
from pydantic import BaseModel, Field, field_validator
import talemate.instance as instance
from talemate.prompts import Prompt
import talemate.automated_action as automated_action
import talemate.instance as instance
from talemate.emit import emit
from talemate.prompts import Prompt
ANY_CHARACTER = "__any_character__"
log = structlog.get_logger("talemate")
class CharacterState(BaseModel):
snapshot: Union[str, None] = None
emotion: Union[str, None] = None
class ObjectState(BaseModel):
snapshot: Union[str, None] = None
class InsertionMode(Enum):
sequential = "sequential"
conversation_context = "conversation-context"
all_context = "all-context"
never = "never"
class Reinforcement(BaseModel):
question: str
answer: Union[str, None] = None
@@ -34,11 +39,10 @@ class Reinforcement(BaseModel):
character: Union[str, None] = None
instructions: Union[str, None] = None
insert: str = "sequential"
@property
def as_context_line(self) -> str:
if self.character:
if self.question.strip().endswith("?"):
return f"{self.character}: {self.question} {self.answer}"
else:
@@ -46,57 +50,61 @@ class Reinforcement(BaseModel):
if self.question.strip().endswith("?"):
return f"{self.question} {self.answer}"
return f"{self.question}: {self.answer}"
class ManualContext(BaseModel):
id: str
text: str
meta: dict[str, Any] = {}
class ContextPin(BaseModel):
entry_id: str
condition: Union[str, None] = None
condition_state: bool = False
active: bool = False
class WorldState(BaseModel):
# characters in the scene by name
characters: dict[str, CharacterState] = {}
# objects in the scene by name
items: dict[str, ObjectState] = {}
# location description
location: Union[str, None] = None
# reinforcers
reinforce: list[Reinforcement] = []
# pins
pins: dict[str, ContextPin] = {}
# manual context
manual_context: dict[str, ManualContext] = {}
@property
def agent(self):
return instance.get_agent("world_state")
@property
def scene(self):
return self.agent.scene
@property
def pretty_json(self):
return self.model_dump_json(indent=2)
@property
def as_list(self):
return self.render().as_list
def filter_reinforcements(self, character:str=ANY_CHARACTER, insert:list[str]=None) -> list[Reinforcement]:
def filter_reinforcements(
self, character: str = ANY_CHARACTER, insert: list[str] = None
) -> list[Reinforcement]:
"""
Returns a filtered list of Reinforcement objects based on character and insert criteria.
@@ -107,24 +115,23 @@ class WorldState(BaseModel):
"""
Returns a filtered set of results as list
"""
result = []
for reinforcement in self.reinforce:
if not reinforcement.answer:
continue
if character != ANY_CHARACTER and reinforcement.character != character:
continue
if insert and reinforcement.insert not in insert:
continue
result.append(reinforcement)
return result
def reset(self):
"""
Resets the WorldState instance to its initial state by clearing characters, items, and location.
@@ -134,8 +141,8 @@ class WorldState(BaseModel):
"""
self.characters = {}
self.items = {}
self.location = None
self.location = None
def emit(self, status="update"):
"""
Emits the current world state with the given status.
@@ -144,8 +151,8 @@ class WorldState(BaseModel):
- status: The status of the world state to emit, which influences the handling of the update event.
"""
emit("world_state", status=status, data=self.model_dump())
async def request_update(self, initial_only:bool=False):
async def request_update(self, initial_only: bool = False):
"""
Requests an update of the world state from the WorldState agent. If initial_only is true, emits current state without requesting if characters exist.
@@ -156,85 +163,97 @@ class WorldState(BaseModel):
if initial_only and self.characters:
self.emit()
return
# if auto is true, we need to check if agent has automatic update enabled
if initial_only and not self.agent.actions["update_world_state"].enabled:
self.emit()
return
self.emit(status="requested")
try:
world_state = await self.agent.request_world_state()
except Exception as e:
self.emit()
log.error("world_state.request_update", error=e, traceback=traceback.format_exc())
log.error(
"world_state.request_update", error=e, traceback=traceback.format_exc()
)
return
previous_characters = self.characters
previous_items = self.items
scene = self.agent.scene
character_names = scene.character_names
self.characters = {}
self.items = {}
for character_name, character in world_state.get("characters", {}).items():
# character name may not always come back exactly as we have
# it defined in the scene. We assign the correct name by checking occurences
# of both names in each other.
if character_name not in character_names:
for _character_name in character_names:
if _character_name.lower() in character_name.lower() or character_name.lower() in _character_name.lower():
log.debug("world_state adjusting character name", from_name=character_name, to_name=_character_name)
if (
_character_name.lower() in character_name.lower()
or character_name.lower() in _character_name.lower()
):
log.debug(
"world_state adjusting character name",
from_name=character_name,
to_name=_character_name,
)
character_name = _character_name
break
if not character:
continue
# if emotion is not set, see if a previous state exists
# and use that emotion
if "emotion" not in character:
log.debug("emotion not set", character_name=character_name, character=character, characters=previous_characters)
log.debug(
"emotion not set",
character_name=character_name,
character=character,
characters=previous_characters,
)
if character_name in previous_characters:
character["emotion"] = previous_characters[character_name].emotion
self.characters[character_name] = CharacterState(**character)
log.debug("world_state", character=character)
for item_name, item in world_state.get("items", {}).items():
if not item:
continue
self.items[item_name] = ObjectState(**item)
log.debug("world_state", item=item)
# deactivate persiting for now
# await self.persist()
self.emit()
async def persist(self):
"""
Persists the world state snapshots of characters and items into the memory agent.
TODO: neeeds re-thinking.
Its better to use state reinforcement to track states, persisting the small world
state snapshots most of the time does not have enough context to be useful.
Arguments:
- None
"""
memory = instance.get_agent("memory")
world_state = instance.get_agent("world_state")
# first we check if any of the characters were refered
# to with an alias
states = []
scene = self.agent.scene
@@ -247,10 +266,10 @@ class WorldState(BaseModel):
"typ": "world_state",
"character": character_name,
"ts": scene.ts,
}
},
}
)
for item_name in self.items.keys():
states.append(
{
@@ -260,18 +279,17 @@ class WorldState(BaseModel):
"typ": "world_state",
"item": item_name,
"ts": scene.ts,
}
},
}
)
log.debug("world_state.persist", states=states)
if not states:
return
await memory.add_many(states)
await memory.add_many(states)
async def request_update_inline(self):
"""
Requests an inline update of the world state from the WorldState agent and immediately emits the state.
@@ -279,22 +297,21 @@ class WorldState(BaseModel):
Arguments:
- None
"""
self.emit(status="requested")
world_state = await self.agent.request_world_state_inline()
self.emit()
async def add_reinforcement(
self,
question:str,
character:str=None,
instructions:str=None,
interval:int=10,
answer:str="",
insert:str="sequential",
self,
question: str,
character: str = None,
instructions: str = None,
interval: int = 10,
answer: str = "",
insert: str = "sequential",
) -> Reinforcement:
"""
Adds or updates a reinforcement in the world state. If a reinforcement with the same question and character exists, it is updated.
@@ -307,50 +324,59 @@ class WorldState(BaseModel):
- answer: The answer to the reinforcement question.
- insert: The method of inserting the reinforcement into the context.
"""
# if reinforcement already exists, update it
idx, reinforcement = await self.find_reinforcement(question, character)
if reinforcement:
# update the reinforcement object
reinforcement.instructions = instructions
reinforcement.interval = interval
reinforcement.answer = answer
old_insert_method = reinforcement.insert
reinforcement.insert = insert
# find the reinforcement message i nthe scene history and update the answer
if old_insert_method == "sequential":
message = self.agent.scene.find_message(typ="reinforcement", source=f"{question}:{character if character else ''}")
message = self.agent.scene.find_message(
typ="reinforcement",
source=f"{question}:{character if character else ''}",
)
if old_insert_method != insert and message:
# if it used to be sequential we need to remove its ReinforcmentMessage
# from the scene history
self.scene.pop_history(typ="reinforcement", source=message.source)
elif message:
message.message = answer
elif insert == "sequential":
# if it used to be something else and is now sequential, we need to run the state
# next loop
reinforcement.due = 0
# update the character detail if character name is specified
if character:
character = self.agent.scene.get_character(character)
await character.set_detail(question, answer)
return reinforcement
log.debug("world_state.add_reinforcement", question=question, character=character, instructions=instructions, interval=interval, answer=answer, insert=insert)
log.debug(
"world_state.add_reinforcement",
question=question,
character=character,
instructions=instructions,
interval=interval,
answer=answer,
insert=insert,
)
reinforcement = Reinforcement(
question=question,
character=character,
@@ -359,12 +385,12 @@ class WorldState(BaseModel):
answer=answer,
insert=insert,
)
self.reinforce.append(reinforcement)
return reinforcement
async def find_reinforcement(self, question:str, character:str=None):
async def find_reinforcement(self, question: str, character: str = None):
"""
Finds a reinforcement based on the question and character provided. Returns the index in the list and the reinforcement object.
@@ -373,11 +399,14 @@ class WorldState(BaseModel):
- character: The character to whom the reinforcement is linked. Use None for global reinforcements.
"""
for idx, reinforcement in enumerate(self.reinforce):
if reinforcement.question == question and reinforcement.character == character:
if (
reinforcement.question == question
and reinforcement.character == character
):
return idx, reinforcement
return None, None
def reinforcements_for_character(self, character:str):
def reinforcements_for_character(self, character: str):
"""
Returns a dictionary of reinforcements specifically for a given character.
@@ -385,13 +414,13 @@ class WorldState(BaseModel):
- character: The name of the character for whom reinforcements should be retrieved.
"""
reinforcements = {}
for reinforcement in self.reinforce:
if reinforcement.character == character:
reinforcements[reinforcement.question] = reinforcement
return reinforcements
def reinforcements_for_world(self):
"""
Returns a dictionary of global reinforcements not linked to any specific character.
@@ -400,56 +429,57 @@ class WorldState(BaseModel):
- None
"""
reinforcements = {}
for reinforcement in self.reinforce:
if not reinforcement.character:
reinforcements[reinforcement.question] = reinforcement
return reinforcements
async def remove_reinforcement(self, idx:int):
async def remove_reinforcement(self, idx: int):
"""
Removes a reinforcement from the world state.
Arguments:
- idx: The index of the reinforcement to remove.
"""
# find all instances of the reinforcement in the scene history
# and remove them
source=f"{self.reinforce[idx].question}:{self.reinforce[idx].character if self.reinforce[idx].character else ''}"
source = f"{self.reinforce[idx].question}:{self.reinforce[idx].character if self.reinforce[idx].character else ''}"
self.agent.scene.pop_history(typ="reinforcement", source=source, all=True)
self.reinforce.pop(idx)
def render(self):
"""
Renders the world state as a string.
"""
return Prompt.get(
"world_state.render",
vars={
"characters": self.characters,
"items": self.items,
"location": self.location,
}
},
)
async def commit_to_memory(self, memory_agent):
await memory_agent.add_many([
manual_context.model_dump() for manual_context in self.manual_context.values()
])
await memory_agent.add_many(
[
manual_context.model_dump()
for manual_context in self.manual_context.values()
]
)
def manual_context_for_world(self) -> dict[str, ManualContext]:
"""
Returns all manual context entries where meta["typ"] == "world_state"
"""
return {
manual_context.id: manual_context
for manual_context in self.manual_context.values()
if manual_context.meta.get("typ") == "world_state"
}
}

View File

@@ -1,65 +1,75 @@
from typing import TYPE_CHECKING, Any
import pydantic
import structlog
from talemate.config import StateReinforcementTemplate, WorldStateTemplates, save_config
from talemate.instance import get_agent
from talemate.config import WorldStateTemplates, StateReinforcementTemplate, save_config
from talemate.world_state import Reinforcement, ManualContext, ContextPin, InsertionMode
from talemate.world_state import ContextPin, InsertionMode, ManualContext, Reinforcement
if TYPE_CHECKING:
from talemate.tale_mate import Scene
log = structlog.get_logger("talemate.server.world_state_manager")
class CharacterSelect(pydantic.BaseModel):
name: str
active: bool = True
is_player: bool = False
class ContextDBEntry(pydantic.BaseModel):
text: str
meta: dict
id: Any
class ContextDB(pydantic.BaseModel):
entries: list[ContextDBEntry] = []
class CharacterDetails(pydantic.BaseModel):
name: str
active: bool = True
is_player: bool = False
description: str = ""
base_attributes: dict[str,str] = {}
details: dict[str,str] = {}
base_attributes: dict[str, str] = {}
details: dict[str, str] = {}
reinforcements: dict[str, Reinforcement] = {}
class World(pydantic.BaseModel):
entries: dict[str, ManualContext] = {}
reinforcements: dict[str, Reinforcement] = {}
class CharacterList(pydantic.BaseModel):
characters: dict[str, CharacterSelect] = {}
class HistoryEntry(pydantic.BaseModel):
text: str
start: int = None
end: int = None
ts: str = None
ts: str = None
class History(pydantic.BaseModel):
history: list[HistoryEntry] = []
class AnnotatedContextPin(pydantic.BaseModel):
pin: ContextPin
text: str
time_aware_text: str
class ContextPins(pydantic.BaseModel):
pins: dict[str, AnnotatedContextPin] = []
class WorldStateManager:
@property
def memory_agent(self):
"""
@@ -69,8 +79,8 @@ class WorldStateManager:
The memory agent instance responsible for managing memory-related operations.
"""
return get_agent("memory")
def __init__(self, scene:'Scene'):
def __init__(self, scene: "Scene"):
"""
Initializes the WorldStateManager with a given scene.
@@ -79,7 +89,7 @@ class WorldStateManager:
"""
self.scene = scene
self.world_state = scene.world_state
async def get_character_list(self) -> CharacterList:
"""
Retrieves a list of characters from the current scene.
@@ -87,18 +97,22 @@ class WorldStateManager:
Returns:
A CharacterList object containing the characters with their select properties from the scene.
"""
characters = CharacterList()
for character in self.scene.get_characters():
characters.characters[character.name] = CharacterSelect(name=character.name, active=True, is_player=character.is_player)
characters.characters[character.name] = CharacterSelect(
name=character.name, active=True, is_player=character.is_player
)
for character in self.scene.inactive_characters.values():
characters.characters[character.name] = CharacterSelect(name=character.name, active=False, is_player=character.is_player)
characters.characters[character.name] = CharacterSelect(
name=character.name, active=False, is_player=character.is_player
)
return characters
async def get_character_details(self, character_name:str) -> CharacterDetails:
async def get_character_details(self, character_name: str) -> CharacterDetails:
"""
Fetches and returns the details for a specific character by name.
@@ -108,21 +122,28 @@ class WorldStateManager:
Returns:
A CharacterDetails object containing the character's details, attributes, and reinforcements.
"""
character = self.scene.get_character(character_name)
details = CharacterDetails(name=character.name, active=True, description=character.description, is_player=character.is_player)
details = CharacterDetails(
name=character.name,
active=True,
description=character.description,
is_player=character.is_player,
)
for key, value in character.base_attributes.items():
details.base_attributes[key] = value
for key, value in character.details.items():
details.details[key] = value
details.reinforcements = self.world_state.reinforcements_for_character(character_name)
details.reinforcements = self.world_state.reinforcements_for_character(
character_name
)
return details
async def get_world(self) -> World:
"""
Retrieves the current state of the world, including entries and reinforcements.
@@ -132,10 +153,12 @@ class WorldStateManager:
"""
return World(
entries=self.world_state.manual_context_for_world(),
reinforcements=self.world_state.reinforcements_for_world()
reinforcements=self.world_state.reinforcements_for_world(),
)
async def get_context_db_entries(self, query:str, limit:int=20, **meta) -> ContextDB:
async def get_context_db_entries(
self, query: str, limit: int = 20, **meta
) -> ContextDB:
"""
Retrieves entries from the context database based on a query and metadata.
@@ -147,22 +170,24 @@ class WorldStateManager:
Returns:
A ContextDB object containing the found entries.
"""
if query.startswith("id:"):
_entries = await self.memory_agent.get_document(id=query[3:])
_entries = list(_entries.values())
else:
_entries = await self.memory_agent.multi_query([query], iterate=limit, max_tokens=9999999, **meta)
_entries = await self.memory_agent.multi_query(
[query], iterate=limit, max_tokens=9999999, **meta
)
entries = []
for entry in _entries:
entries.append(ContextDBEntry(text=entry.raw, meta=entry.meta, id=entry.id))
context_db = ContextDB(entries=entries)
return context_db
async def get_pins(self, active:bool=None) -> ContextPins:
async def get_pins(self, active: bool = None) -> ContextPins:
"""
Retrieves context pins that meet the specified activity condition.
@@ -172,31 +197,36 @@ class WorldStateManager:
Returns:
A ContextPins object containing the matching annotated context pins.
"""
pins = self.world_state.pins
candidates = [pin for pin in pins.values() if pin.active == active or active is None]
candidates = [
pin for pin in pins.values() if pin.active == active or active is None
]
_ids = [pin.entry_id for pin in candidates]
_pins = {}
documents = await self.memory_agent.get_document(id=_ids)
for pin in sorted(candidates, key=lambda x: x.active, reverse=True):
if pin.entry_id not in documents:
text = ""
time_aware_text = ""
else:
text = documents[pin.entry_id].raw
time_aware_text = str(documents[pin.entry_id])
annotated_pin = AnnotatedContextPin(pin=pin, text=text, time_aware_text=time_aware_text)
annotated_pin = AnnotatedContextPin(
pin=pin, text=text, time_aware_text=time_aware_text
)
_pins[pin.entry_id] = annotated_pin
return ContextPins(pins=_pins)
async def update_character_attribute(self, character_name:str, attribute:str, value:str):
async def update_character_attribute(
self, character_name: str, attribute: str, value: str
):
"""
Updates the attribute of a character to a new value.
@@ -207,8 +237,10 @@ class WorldStateManager:
"""
character = self.scene.get_character(character_name)
await character.set_base_attribute(attribute, value)
async def update_character_detail(self, character_name:str, detail:str, value:str):
async def update_character_detail(
self, character_name: str, detail: str, value: str
):
"""
Updates a specific detail of a character to a new value.
@@ -219,8 +251,8 @@ class WorldStateManager:
"""
character = self.scene.get_character(character_name)
await character.set_detail(detail, value)
async def update_character_description(self, character_name:str, description:str):
async def update_character_description(self, character_name: str, description: str):
"""
Updates the description of a character to a new value.
@@ -230,16 +262,16 @@ class WorldStateManager:
"""
character = self.scene.get_character(character_name)
await character.set_description(description)
async def add_detail_reinforcement(
self,
character_name:str,
question:str,
instructions:str=None,
interval:int=10,
answer:str="",
insert:str="sequential",
run_immediately:bool=False
self,
character_name: str,
question: str,
instructions: str = None,
interval: int = 10,
answer: str = "",
insert: str = "sequential",
run_immediately: bool = False,
) -> Reinforcement:
"""
Adds a detail reinforcement for a character with specified parameters.
@@ -262,16 +294,18 @@ class WorldStateManager:
reinforcement = await self.world_state.add_reinforcement(
question, character_name, instructions, interval, answer, insert
)
if run_immediately:
await world_state_agent.update_reinforcement(question, character_name)
else:
# if not running immediately, we need to emit the world state manually
self.world_state.emit()
return reinforcement
async def run_detail_reinforcement(self, character_name:str, question:str, reset:bool=False):
async def run_detail_reinforcement(
self, character_name: str, question: str, reset: bool = False
):
"""
Executes the detail reinforcement for a specific character and question.
@@ -280,9 +314,11 @@ class WorldStateManager:
question: The query/question that the reinforcement corresponds to.
"""
world_state_agent = get_agent("world_state")
await world_state_agent.update_reinforcement(question, character_name, reset=reset)
async def delete_detail_reinforcement(self, character_name:str, question:str):
await world_state_agent.update_reinforcement(
question, character_name, reset=reset
)
async def delete_detail_reinforcement(self, character_name: str, question: str):
"""
Deletes a detail reinforcement for a specified character and question.
@@ -290,13 +326,15 @@ class WorldStateManager:
character_name: The name of the character whose reinforcement is to be deleted.
question: The query/question of the reinforcement to be deleted.
"""
idx, reinforcement = await self.world_state.find_reinforcement(question, character_name)
idx, reinforcement = await self.world_state.find_reinforcement(
question, character_name
)
if idx is not None:
await self.world_state.remove_reinforcement(idx)
self.world_state.emit()
async def save_world_entry(self, entry_id:str, text:str, meta:dict):
async def save_world_entry(self, entry_id: str, text: str, meta: dict):
"""
Saves a manual world entry with specified text and metadata.
@@ -308,8 +346,8 @@ class WorldStateManager:
meta["source"] = "manual"
meta["typ"] = "world_state"
await self.update_context_db_entry(entry_id, text, meta)
async def update_context_db_entry(self, entry_id:str, text:str, meta:dict):
async def update_context_db_entry(self, entry_id: str, text: str, meta: dict):
"""
Updates an entry in the context database with new text and metadata.
@@ -318,47 +356,42 @@ class WorldStateManager:
text: The new text content for the world entry.
meta: A dictionary containing updated metadata for the world entry.
"""
if meta.get("source") == "manual":
# manual context needs to be updated in the world state
self.world_state.manual_context[entry_id] = ManualContext(
text=text,
meta=meta,
id=entry_id
text=text, meta=meta, id=entry_id
)
elif meta.get("typ") == "details":
# character detail needs to be mirrored to the
# character detail needs to be mirrored to the
# character object in the scene
character_name = meta.get("character")
character = self.scene.get_character(character_name)
character.details[meta.get("detail")] = text
await self.memory_agent.add_many([
{
"id": entry_id,
"text": text,
"meta": meta
}
])
async def delete_context_db_entry(self, entry_id:str):
await self.memory_agent.add_many([{"id": entry_id, "text": text, "meta": meta}])
async def delete_context_db_entry(self, entry_id: str):
"""
Deletes a specific entry from the context database using its identifier.
Arguments:
entry_id: The identifier of the world entry to be deleted.
"""
await self.memory_agent.delete({
"ids": entry_id
})
await self.memory_agent.delete({"ids": entry_id})
if entry_id in self.world_state.manual_context:
del self.world_state.manual_context[entry_id]
await self.remove_pin(entry_id)
async def set_pin(self, entry_id:str, condition:str=None, condition_state:bool=False, active:bool=False):
async def set_pin(
self,
entry_id: str,
condition: str = None,
condition_state: bool = False,
active: bool = False,
):
"""
Creates or updates a pin on a context entry with conditional activation.
@@ -368,33 +401,32 @@ class WorldStateManager:
condition_state: The boolean state that enables the pin; defaults to False.
active: A flag indicating whether the pin should be active; defaults to False.
"""
if not condition:
condition = None
condition_state = False
pin = ContextPin(
entry_id=entry_id,
condition=condition,
condition_state=condition_state,
active=active
active=active,
)
self.world_state.pins[entry_id] = pin
async def remove_all_empty_pins(self):
"""
Removes all pins that come back with empty `text` attributes from get_pins.
"""
pins = await self.get_pins()
for pin in pins.pins.values():
if not pin.text:
await self.remove_pin(pin.pin.entry_id)
async def remove_pin(self, entry_id:str):
async def remove_pin(self, entry_id: str):
"""
Removes an existing pin from a context entry using its identifier.
@@ -403,8 +435,7 @@ class WorldStateManager:
"""
if entry_id in self.world_state.pins:
del self.world_state.pins[entry_id]
async def get_templates(self) -> WorldStateTemplates:
"""
Retrieves the current world state templates from scene configuration.
@@ -415,9 +446,8 @@ class WorldStateManager:
templates = self.scene.config["game"]["world_state"]["templates"]
world_state_templates = WorldStateTemplates(**templates)
return world_state_templates
async def save_template(self, template:StateReinforcementTemplate):
async def save_template(self, template: StateReinforcementTemplate):
"""
Saves a state reinforcement template to the scene configuration.
@@ -428,17 +458,19 @@ class WorldStateManager:
If the template is set to auto-create, it will be applied immediately.
"""
config = self.scene.config
template_type = template.type
config["game"]["world_state"]["templates"][template_type][template.name] = template.model_dump()
config["game"]["world_state"]["templates"][template_type][
template.name
] = template.model_dump()
save_config(self.scene.config)
if template.auto_create:
await self.auto_apply_template(template)
async def remove_template(self, template_type:str, template_name:str):
async def remove_template(self, template_type: str, template_name: str):
"""
Removes a specific state reinforcement template from scene configuration.
@@ -450,14 +482,18 @@ class WorldStateManager:
If the specified template is not found, logs a warning.
"""
config = self.scene.config
try:
del config["game"]["world_state"]["templates"][template_type][template_name]
save_config(self.scene.config)
except KeyError:
log.warning("world state template not found", template_type=template_type, template_name=template_name)
log.warning(
"world state template not found",
template_type=template_type,
template_name=template_name,
)
pass
async def apply_all_auto_create_templates(self):
"""
Applies all auto-create state reinforcement templates.
@@ -467,18 +503,18 @@ class WorldStateManager:
"""
templates = self.scene.config["game"]["world_state"]["templates"]
world_state_templates = WorldStateTemplates(**templates)
candidates = []
for template in world_state_templates.state_reinforcement.values():
if template.auto_create:
candidates.append(template)
for template in candidates:
log.info("applying template", template=template)
await self.auto_apply_template(template)
async def auto_apply_template(self, template:StateReinforcementTemplate):
async def auto_apply_template(self, template: StateReinforcementTemplate):
"""
Automatically applies a state reinforcement template based on its type.
@@ -490,8 +526,10 @@ class WorldStateManager:
"""
fn = getattr(self, f"auto_apply_template_{template.type}")
await fn(template)
async def auto_apply_template_state_reinforcement(self, template:StateReinforcementTemplate):
async def auto_apply_template_state_reinforcement(
self, template: StateReinforcementTemplate
):
"""
Applies a state reinforcement template to characters based on the template's state type.
@@ -501,21 +539,27 @@ class WorldStateManager:
Note:
The characters to apply the template to are determined by the state_type in the template.
"""
characters = []
if template.state_type == "npc":
characters = [character.name for character in self.scene.get_npc_characters()]
characters = [
character.name for character in self.scene.get_npc_characters()
]
elif template.state_type == "character":
characters = [character.name for character in self.scene.get_characters()]
elif template.state_type == "player":
characters = [self.scene.get_player_character().name]
for character_name in characters:
await self.apply_template_state_reinforcement(template, character_name)
async def apply_template_state_reinforcement(self, template:StateReinforcementTemplate, character_name:str=None, run_immediately:bool=False) -> Reinforcement:
async def apply_template_state_reinforcement(
self,
template: StateReinforcementTemplate,
character_name: str = None,
run_immediately: bool = False,
) -> Reinforcement:
"""
Applies a state reinforcement template to a specific character, if provided.
@@ -530,22 +574,30 @@ class WorldStateManager:
Raises:
ValueError: If a character name is required but not provided.
"""
if not character_name and template.state_type in ["npc", "character", "player"]:
raise ValueError("Character name required for this template type.")
player_name = self.scene.get_player_character().name
formatted_query = template.query.format(character_name=character_name, player_name=player_name)
formatted_instructions = template.instructions.format(character_name=character_name, player_name=player_name) if template.instructions else None
formatted_query = template.query.format(
character_name=character_name, player_name=player_name
)
formatted_instructions = (
template.instructions.format(
character_name=character_name, player_name=player_name
)
if template.instructions
else None
)
if character_name:
details = await self.get_character_details(character_name)
# if reinforcement already exists, skip
if formatted_query in details.reinforcements:
return None
return await self.add_detail_reinforcement(
character_name,
formatted_query,
@@ -553,4 +605,4 @@ class WorldStateManager:
template.interval,
insert=template.insert,
run_immediately=run_immediately,
)
)

View File

@@ -6,13 +6,14 @@
<span class="headline">{{ title() }}</span>
</v-card-title>
<v-card-text>
<v-form ref="form" v-model="formIsValid">
<v-container>
<v-row>
<v-col cols="6">
<v-select v-model="client.type" :disabled="!typeEditable()" :items="clientChoices" label="Client Type" @update:model-value="resetToDefaults"></v-select>
</v-col>
<v-col cols="6">
<v-text-field v-model="client.name" label="Client Name"></v-text-field>
<v-text-field v-model="client.name" label="Client Name" :rules="[rules.required]"></v-text-field>
</v-col>
</v-row>
<v-row v-if="clientMeta().experimental">
@@ -24,7 +25,7 @@
<v-col cols="12">
<v-row>
<v-col :cols="clientMeta().enable_api_auth ? 7 : 12">
<v-text-field v-model="client.api_url" v-if="requiresAPIUrl(client)" label="API URL"></v-text-field>
<v-text-field v-model="client.api_url" v-if="requiresAPIUrl(client)" :rules="[rules.required]" label="API URL"></v-text-field>
</v-col>
<v-col cols="5">
<v-text-field type="password" v-model="client.api_key" v-if="requiresAPIUrl(client) && clientMeta().enable_api_auth" label="API Key"></v-text-field>
@@ -36,7 +37,7 @@
</v-row>
<v-row>
<v-col cols="4">
<v-text-field v-model="client.max_token_length" v-if="requiresAPIUrl(client)" type="number" label="Context Length"></v-text-field>
<v-text-field v-model="client.max_token_length" v-if="requiresAPIUrl(client)" type="number" label="Context Length" :rules="[rules.required]"></v-text-field>
</v-col>
<v-col cols="8" v-if="!typeEditable() && client.data && client.data.prompt_template_example !== null && client.model_name && clientMeta().requires_prompt_template">
<v-combobox ref="promptTemplateComboBox" label="Prompt Template" v-model="client.data.template_file" @update:model-value="setPromptTemplate" :items="promptTemplates"></v-combobox>
@@ -54,11 +55,12 @@
</v-col>
</v-row>
</v-container>
</v-form>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="primary" text @click="close" prepend-icon="mdi-cancel">Cancel</v-btn>
<v-btn color="primary" text @click="save" prepend-icon="mdi-check-circle-outline">Save</v-btn>
<v-btn color="primary" text @click="save" prepend-icon="mdi-check-circle-outline" :disabled="!formIsValid">Save</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -77,12 +79,19 @@ export default {
],
data() {
return {
formIsValid: false,
promptTemplates: [],
clientTypes: [],
clientChoices: [],
localDialog: this.state.dialog,
client: { ...this.state.currentClient },
defaultValuesByCLientType: {}
defaultValuesByCLientType: {},
rules: {
required: value => !!value || 'Field is required.',
},
rulesMaxTokenLength: [
v => !!v || 'Context length is required',
],
};
},
watch: {