mirror of
https://github.com/vegu-ai/talemate.git
synced 2025-12-24 15:39:34 +01:00
Compare commits
9 Commits
| SHA1 |
|---|
| eb251d6e37 |
| 4ba635497b |
| bdbf14c1ed |
| c340fc085c |
| 94f8d0f242 |
| 1d8a9b113c |
| 1837796852 |
| c5c53c056e |
| f1b1190f0b |
README.md (67)
@@ -7,13 +7,16 @@ Allows you to play roleplay scenarios with large language models.

> :warning: **It does not run any large language models itself but relies on existing APIs. Currently supports OpenAI, text-generation-webui and LMStudio.**

> :warning: **It does not run any large language models itself but relies on existing APIs. Currently supports OpenAI, text-generation-webui and LMStudio. 0.18.0 also adds support for generic OpenAI API implementations, but generation quality will vary.**

This means you need to either have:

- an [OpenAI](https://platform.openai.com/overview) api key
- OR set up local (or remote via runpod) LLM inference via one of these options:
- set up local (or remote via runpod) LLM inference via:
  - [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui)
  - [LMStudio](https://lmstudio.ai/)
  - any other OpenAI API implementation that implements the `v1/completions` endpoint
    - tested llamacpp with the `api_like_OAI.py` wrapper
    - let me know if you have tested any other implementations and they failed / worked or landed somewhere in between

## Current features
@@ -35,6 +38,7 @@ This means you need to either have:

- Automatically keep track and reinforce selected character and world truths / states.
- narrative tools
- creative tools
- manage multiple NPCs
- AI backed character creation with template support (jinja2)
- AI backed scenario creation
- context management
@@ -77,7 +81,7 @@ There is also a [troubleshooting guide](docs/troubleshoot.md) that might help.

### Windows

1. Download and install Python 3.10 or Python 3.11 from the [official Python website](https://www.python.org/downloads/windows/). :warning: python3.12 is currently not supported.
1. Download and install Node.js from the [official Node.js website](https://nodejs.org/en/download/). This will also install npm.
1. Download and install Node.js v20 from the [official Node.js website](https://nodejs.org/en/download/). This will also install npm. :warning: v21 is currently not supported.
1. Download the Talemate project to your local machine. Download from [the Releases page](https://github.com/vegu-ai/talemate/releases).
1. Unpack the download and run `install.bat` by double clicking it. This will set up the project on your local machine.
1. Once the installation is complete, you can start the backend and frontend servers by running `start.bat`.
@@ -87,45 +91,14 @@ There is also a [troubleshooting guide](docs/troubleshoot.md) that might help.

`python 3.10` or `python 3.11` is required. :warning: `python 3.12` is not supported yet.

`nodejs v19 or v20` :warning: `v21` is not supported yet.

1. `git clone git@github.com:vegu-ai/talemate`
1. `cd talemate`
1. `source install.sh`
1. Start the backend: `python src/talemate/server/run.py runserver --host 0.0.0.0 --port 5050`.
1. Open a new terminal, navigate to the `talemate_frontend` directory, and start the frontend server by running `npm run serve`.

## Configuration

### OpenAI

To set your OpenAI API key, open `config.yaml` in any text editor and uncomment / add

```yaml
openai:
  api_key: sk-my-api-key-goes-here
```

You will need to restart the backend for this change to take effect.

### RunPod

To set your RunPod API key, open `config.yaml` in any text editor and uncomment / add

```yaml
runpod:
  api_key: my-api-key-goes-here
```

You will need to restart the backend for this change to take effect.

Once the API key is set, Pods loaded from text-generation-webui templates (or TheBloke's RunPod LLM template) will be automatically added to your client list in Talemate.

**ATTENTION**: Talemate is not a suitable way to determine whether your pod is currently running. **Always** check the RunPod dashboard to see if your pod is running or not.

## Recommended Models

(as of 2023.10.25)

Any of the top models in any of the size classes here should work well:
https://www.reddit.com/r/LocalLLaMA/comments/17fhp9k/huge_llm_comparisontest_39_models_tested_7b70b/

## Connecting to an LLM

On the right hand side click the "Add Client" button. If there is no button, you may need to toggle the client options by clicking this button:
@@ -140,13 +113,33 @@ In the modal if you're planning to connect to text-generation-webui, you can lik

(screenshot)

#### Recommended Models

Any of the top models in any of the size classes here should work well (I wouldn't recommend going lower than 7B):

https://www.reddit.com/r/LocalLLaMA/comments/18yp9u4/llm_comparisontest_api_edition_gpt4_vs_gemini_vs/

### OpenAI

If you want to add an OpenAI client, just change the client type and select the appropriate model.

(screenshot)

### Ready to go

If you are setting this up for the first time, you should now see the client, but it will have a red dot next to it, stating that it requires an API key.

(screenshot)

Click the `SET API KEY` button. This will open a modal where you can enter your API key.

(screenshot)

Click `Save` and after a moment the client should have a green dot next to it, indicating that it is ready to go.

(screenshot)

## Ready to go

You will know you are good to go when the client and all the agents have a green dot next to them.
BIN docs/img/0.18.0/openai-api-key-1.png (new file, 5.6 KiB, binary file not shown)
BIN docs/img/0.18.0/openai-api-key-2.png (new file, 24 KiB, binary file not shown)
BIN docs/img/0.18.0/openai-api-key-3.png (new file, 4.7 KiB, binary file not shown)
@@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "talemate"
version = "0.18.0"
version = "0.18.2"
description = "AI-backed roleplay and narrative tools"
authors = ["FinalWombat"]
license = "GNU Affero General Public License v3.0"
@@ -2,4 +2,4 @@ from .agents import Agent
from .client import TextGeneratorWebuiClient
from .tale_mate import *

VERSION = "0.18.0"
VERSION = "0.18.2"
@@ -1,11 +1,11 @@
from .base import Agent
from .creator import CreatorAgent
from .conversation import ConversationAgent
from .creator import CreatorAgent
from .director import DirectorAgent
from .editor import EditorAgent
from .memory import ChromaDBMemoryAgent, MemoryAgent
from .narrator import NarratorAgent
from .registry import AGENT_CLASSES, get_agent_class, register
from .summarize import SummarizeAgent
from .editor import EditorAgent
from .tts import TTSAgent
from .world_state import WorldStateAgent
from .tts import TTSAgent
@@ -1,21 +1,21 @@
from __future__ import annotations

import asyncio
import dataclasses
import re
from abc import ABC
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import pydantic
import structlog
from blinker import signal

import talemate.emit.async_signals
import talemate.instance as instance
import talemate.util as util
from talemate.agents.context import ActiveAgent
from talemate.emit import emit
from talemate.events import GameLoopStartEvent
import talemate.emit.async_signals
import dataclasses
import pydantic
import structlog

__all__ = [
    "Agent",
@@ -37,26 +37,27 @@ class AgentActionConfig(pydantic.BaseModel):
    scope: str = "global"
    choices: Union[list[dict[str, str]], None] = None
    note: Union[str, None] = None

    class Config:
        arbitrary_types_allowed = True


class AgentAction(pydantic.BaseModel):
    enabled: bool = True
    label: str
    description: str = ""
    config: Union[dict[str, AgentActionConfig], None] = None


def set_processing(fn):
    """
    decorator that emits the agent status as processing while the function
    is running.

    Done via a try - finally block to ensure the status is reset even if
    the function fails.
    """

    async def wrapper(self, *args, **kwargs):
        with ActiveAgent(self, fn):
            try:
@@ -69,9 +70,9 @@ def set_processing(fn):
                # not sure why this happens
                # some concurrency error?
                log.error("error emitting agent status", exc=exc)

    wrapper.__name__ = fn.__name__

    return wrapper
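The try/finally pattern the docstring describes can be shown in isolation. This is a minimal sketch, not Talemate's actual implementation: `DemoAgent` and `status_log` are illustrative stand-ins for the real agent and its status emission.

```python
import asyncio

def set_processing(fn):
    # Simplified stand-in for the decorator above: record "busy" before the
    # call and always restore "idle" in a finally block, even on failure.
    async def wrapper(self, *args, **kwargs):
        self.status_log.append("busy")
        try:
            return await fn(self, *args, **kwargs)
        finally:
            self.status_log.append("idle")

    wrapper.__name__ = fn.__name__
    return wrapper

class DemoAgent:
    def __init__(self):
        self.status_log = []

    @set_processing
    async def act(self):
        return "done"

agent = DemoAgent()
result = asyncio.run(agent.act())
print(result, agent.status_log)  # done ['busy', 'idle']
```

Because the status reset lives in `finally`, a raising `act()` would still append `"idle"` before the exception propagates.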
@@ -97,16 +98,14 @@ class Agent(ABC):
    def verbose_name(self):
        return self.agent_type.capitalize()

    @property
    def ready(self):
        if not getattr(self.client, "enabled", True):
            return False

        if self.client and self.client.current_status in ["error", "warning"]:
            return False

        return self.client is not None

    @property
@@ -123,20 +122,20 @@ class Agent(ABC):
        # by default, agents are enabled, an agent class that
        # is disableable should override this property
        return True

    @property
    def disable(self):
        # by default, agents are enabled, an agent class that
        # is disableable should override this property to
        # is disableable should override this property to
        # disable the agent
        pass

    @property
    def has_toggle(self):
        # by default, agents do not have toggles to enable / disable
        # an agent class that is disableable should override this property
        return False

    @property
    def experimental(self):
        # by default, agents are not experimental, an agent class that
@@ -153,85 +152,92 @@ class Agent(ABC):
            "requires_llm_client": cls.requires_llm_client,
        }
        actions = getattr(agent, "actions", None)

        if actions:
            config_options["actions"] = {k: v.model_dump() for k, v in actions.items()}
        else:
            config_options["actions"] = {}

        return config_options

    def apply_config(self, *args, **kwargs):
        if self.has_toggle and "enabled" in kwargs:
            self.is_enabled = kwargs.get("enabled", False)

        if not getattr(self, "actions", None):
            return

        for action_key, action in self.actions.items():

            if not kwargs.get("actions"):
                continue

            action.enabled = kwargs.get("actions", {}).get(action_key, {}).get("enabled", False)

            action.enabled = (
                kwargs.get("actions", {}).get(action_key, {}).get("enabled", False)
            )

            if not action.config:
                continue

            for config_key, config in action.config.items():
                try:
                    config.value = kwargs.get("actions", {}).get(action_key, {}).get("config", {}).get(config_key, {}).get("value", config.value)
                    config.value = (
                        kwargs.get("actions", {})
                        .get(action_key, {})
                        .get("config", {})
                        .get(config_key, {})
                        .get("value", config.value)
                    )
                except AttributeError:
                    pass
    async def on_game_loop_start(self, event:GameLoopStartEvent):

    async def on_game_loop_start(self, event: GameLoopStartEvent):
        """
        Finds all ActionConfigs that have a scope of "scene" and resets them to their default values
        """

        if not getattr(self, "actions", None):
            return

        for _, action in self.actions.items():
            if not action.config:
                continue

            for _, config in action.config.items():
                if config.scope == "scene":
                    # if default_value is None, just use the `type` of the current
                    # if default_value is None, just use the `type` of the current
                    # value
                    if config.default_value is None:
                        default_value = type(config.value)()
                    else:
                        default_value = config.default_value

                    log.debug("resetting config", config=config, default_value=default_value)

                    log.debug(
                        "resetting config", config=config, default_value=default_value
                    )
                    config.value = default_value

        await self.emit_status()
    async def emit_status(self, processing: bool = None):

        # should keep a count of processing requests, and when the
        # number is 0 status is "idle", if the number is greater than 0
        # status is "busy"
        #
        # increase / decrease based on value of `processing`

        if getattr(self, "processing", None) is None:
            self.processing = 0

        if not processing:
            self.processing -= 1
            self.processing = max(0, self.processing)
        else:
            self.processing += 1

        status = "busy" if self.processing > 0 else "idle"
        if not self.enabled:
            status = "disabled"

        emit(
            "agent_status",
            message=self.verbose_name or "",
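The counter scheme described in the comments above (count in-flight requests; zero means idle, positive means busy, disabled overrides both) can be sketched on its own. `StatusTracker` is an illustrative stand-in, not Talemate's `Agent` class, and it returns the status instead of emitting a signal.

```python
class StatusTracker:
    # Each call with processing=True increments the in-flight counter;
    # any other call decrements it, clamped at zero. Status is "busy"
    # while the count is positive, "idle" at zero, "disabled" if disabled.
    def __init__(self, enabled=True):
        self.enabled = enabled
        self.processing = 0

    def emit_status(self, processing=None):
        if not processing:
            self.processing = max(0, self.processing - 1)
        else:
            self.processing += 1
        status = "busy" if self.processing > 0 else "idle"
        if not self.enabled:
            status = "disabled"
        return status

tracker = StatusTracker()
print(tracker.emit_status(True))   # busy
print(tracker.emit_status(True))   # busy
print(tracker.emit_status(False))  # busy (one request still in flight)
print(tracker.emit_status(False))  # idle
```

The clamp at zero mirrors the `max(0, ...)` in the excerpt, so stray extra decrements cannot push the counter negative.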
@@ -245,8 +251,9 @@ class Agent(ABC):

    def connect(self, scene):
        self.scene = scene
        talemate.emit.async_signals.get("game_loop_start").connect(self.on_game_loop_start)

        talemate.emit.async_signals.get("game_loop_start").connect(
            self.on_game_loop_start
        )

    def clean_result(self, result):
        if "#" in result:
@@ -291,23 +298,28 @@ class Agent(ABC):

            current_memory_context.append(memory)
        return current_memory_context

    # LLM client related methods. These are called during or after the client
    # sends the prompt to the API.

    def inject_prompt_paramters(self, prompt_param:dict, kind:str, agent_function_name:str):
    def inject_prompt_paramters(
        self, prompt_param: dict, kind: str, agent_function_name: str
    ):
        """
        Injects prompt parameters before the client sends off the prompt
        Override as needed.
        """
        pass

    def allow_repetition_break(self, kind:str, agent_function_name:str, auto:bool=False):
    def allow_repetition_break(
        self, kind: str, agent_function_name: str, auto: bool = False
    ):
        """
        Returns True if repetition breaking is allowed, False otherwise.
        """
        return False


@dataclasses.dataclass
class AgentEmission:
    agent: Agent
    agent: Agent
@@ -1,3 +1,3 @@
"""
Code has been moved.
"""
"""
@@ -1,6 +1,6 @@

from typing import Callable, TYPE_CHECKING
import contextvars
from typing import TYPE_CHECKING, Callable

import pydantic

__all__ = [
@@ -9,25 +9,26 @@ __all__ = [

active_agent = contextvars.ContextVar("active_agent", default=None)


class ActiveAgentContext(pydantic.BaseModel):
    agent: object
    fn: Callable

    class Config:
        arbitrary_types_allowed=True

        arbitrary_types_allowed = True

    @property
    def action(self):
        return self.fn.__name__


class ActiveAgent:

    def __init__(self, agent, fn):
        self.agent = ActiveAgentContext(agent=agent, fn=fn)

    def __enter__(self):
        self.token = active_agent.set(self.agent)

    def __exit__(self, *args, **kwargs):
        active_agent.reset(self.token)
        return False
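The `ActiveAgent` context manager above is a standard `contextvars` pattern: set the variable on entry, keep the token, and reset on exit. A minimal sketch, with a plain string standing in for the pydantic wrapper:

```python
import contextvars

# Publish "which agent is currently running" through a ContextVar and
# restore the previous value when the with-block exits.
active_agent = contextvars.ContextVar("active_agent", default=None)

class ActiveAgent:
    def __init__(self, agent):
        self.agent = agent

    def __enter__(self):
        self.token = active_agent.set(self.agent)

    def __exit__(self, *args):
        active_agent.reset(self.token)
        return False  # do not swallow exceptions

with ActiveAgent("conversation"):
    inside = active_agent.get()   # "conversation"
outside = active_agent.get()      # back to None after the block

print(inside, outside)
```

Using `reset(token)` rather than `set(None)` restores whatever value was active before entry, which keeps nested `with` blocks correct.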
@@ -1,40 +1,48 @@
from __future__ import annotations

import dataclasses
import re
import random
import re
from datetime import datetime
from typing import TYPE_CHECKING, Optional, Union

import structlog

import talemate.client as client
import talemate.emit.async_signals
import talemate.instance as instance
import talemate.util as util
import structlog
from talemate.client.context import (
    client_context_attribute,
    set_client_context_attribute,
    set_conversation_context_attribute,
)
from talemate.emit import emit
import talemate.emit.async_signals
from talemate.scene_message import CharacterMessage, DirectorMessage
from talemate.prompts import Prompt
from talemate.events import GameLoopEvent
from talemate.client.context import set_conversation_context_attribute, client_context_attribute, set_client_context_attribute
from talemate.prompts import Prompt
from talemate.scene_message import CharacterMessage, DirectorMessage

from .base import Agent, AgentEmission, set_processing, AgentAction, AgentActionConfig
from .base import Agent, AgentAction, AgentActionConfig, AgentEmission, set_processing
from .registry import register

if TYPE_CHECKING:
    from talemate.tale_mate import Character, Scene, Actor

    from talemate.tale_mate import Actor, Character, Scene

log = structlog.get_logger("talemate.agents.conversation")


@dataclasses.dataclass
class ConversationAgentEmission(AgentEmission):
    actor: Actor
    character: Character
    generation: list[str]


talemate.emit.async_signals.register(
    "agent.conversation.before_generate",
    "agent.conversation.generated"
    "agent.conversation.before_generate", "agent.conversation.generated"
)


@register()
class ConversationAgent(Agent):
    """
@@ -45,7 +53,7 @@ class ConversationAgent(Agent):

    agent_type = "conversation"
    verbose_name = "Conversation"

    min_dialogue_length = 75

    def __init__(
@@ -60,28 +68,28 @@ class ConversationAgent(Agent):
        self.logging_enabled = logging_enabled
        self.logging_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.current_memory_context = None

        # several agents extend this class, but we only want to initialize
        # these actions for the conversation agent

        if self.agent_type != "conversation":
            return

        self.actions = {
            "generation_override": AgentAction(
                enabled = True,
                label = "Generation Override",
                description = "Override generation parameters",
                config = {
                enabled=True,
                label="Generation Override",
                description="Override generation parameters",
                config={
                    "length": AgentActionConfig(
                        type="number",
                        label="Generation Length (tokens)",
                        description="Maximum number of tokens to generate for a conversation response.",
                        value=96,
                        value=96,
                        min=32,
                        max=512,
                        step=32,
                    ),#
                    ), #
                    "instructions": AgentActionConfig(
                        type="text",
                        label="Instructions",
@@ -96,24 +104,24 @@ class ConversationAgent(Agent):
                        min=0.0,
                        max=1.0,
                        step=0.1,
                    )
                }
            ),
                },
            ),
            "auto_break_repetition": AgentAction(
                enabled = True,
                label = "Auto Break Repetition",
                description = "Will attempt to automatically break AI repetition.",
                enabled=True,
                label="Auto Break Repetition",
                description="Will attempt to automatically break AI repetition.",
            ),
            "natural_flow": AgentAction(
                enabled = True,
                label = "Natural Flow",
                description = "Will attempt to generate a more natural flow of conversation between multiple characters.",
                config = {
                enabled=True,
                label="Natural Flow",
                description="Will attempt to generate a more natural flow of conversation between multiple characters.",
                config={
                    "max_auto_turns": AgentActionConfig(
                        type="number",
                        label="Max. Auto Turns",
                        description="The maximum number of turns the AI is allowed to generate before it stops and waits for the player to respond.",
                        value=4,
                        value=4,
                        min=1,
                        max=100,
                        step=1,
@@ -122,31 +130,40 @@ class ConversationAgent(Agent):
                        type="number",
                        label="Max. Idle Turns",
                        description="The maximum number of turns a character can go without speaking before they are considered overdue to speak.",
                        value=8,
                        value=8,
                        min=1,
                        max=100,
                        step=1,
                    ),
                }
            ),
                },
            ),
            "use_long_term_memory": AgentAction(
                enabled = True,
                label = "Long Term Memory",
                description = "Will augment the conversation prompt with long term memory.",
                config = {
                enabled=True,
                label="Long Term Memory",
                description="Will augment the conversation prompt with long term memory.",
                config={
                    "retrieval_method": AgentActionConfig(
                        type="text",
                        label="Context Retrieval Method",
                        description="How relevant context is retrieved from the long term memory.",
                        value="direct",
                        choices=[
                            {"label": "Context queries based on recent dialogue (fast)", "value": "direct"},
                            {"label": "Context queries generated by AI", "value": "queries"},
                            {"label": "AI compiled question and answers (slow)", "value": "questions"},
                        ]
                            {
                                "label": "Context queries based on recent dialogue (fast)",
                                "value": "direct",
                            },
                            {
                                "label": "Context queries generated by AI",
                                "value": "queries",
                            },
                            {
                                "label": "AI compiled question and answers (slow)",
                                "value": "questions",
                            },
                        ],
                    ),
                }
            ),
                },
            ),
        }

    def connect(self, scene):
@@ -154,40 +171,37 @@ class ConversationAgent(Agent):
        talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)

    def last_spoken(self):

        """
        Returns the last time each character spoke
        """

        last_turn = {}
        turns = 0
        character_names = self.scene.character_names
        max_idle_turns = self.actions["natural_flow"].config["max_idle_turns"].value

        for idx in range(len(self.scene.history) - 1, -1, -1):

            if isinstance(self.scene.history[idx], CharacterMessage):

                if turns >= max_idle_turns:
                    break

                character = self.scene.history[idx].character_name

                if character in character_names:
                    last_turn[character] = turns
                    character_names.remove(character)

                if not character_names:
                    break

            turns += 1

        if character_names and turns >= max_idle_turns:
            for character in character_names:
                last_turn[character] = max_idle_turns
                last_turn[character] = max_idle_turns

        return last_turn

    def repeated_speaker(self):
        """
        Counts the amount of times the most recent speaker has spoken in a row
@@ -203,125 +217,164 @@ class ConversationAgent(Agent):
            else:
                break
        return count
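The counting idea behind `repeated_speaker` can be modelled in a few lines. This is a simplified sketch: the history here is just a list of speaker names, not Talemate's message objects.

```python
# Walk the history backwards and count how many consecutive most-recent
# messages share one speaker; stop at the first different speaker.
def repeated_speaker(history):
    count = 0
    last = None
    for speaker in reversed(history):
        if last is None:
            last = speaker
        if speaker == last:
            count += 1
        else:
            break
    return count

print(repeated_speaker(["Ana", "Bob", "Bob", "Bob"]))  # 3
print(repeated_speaker(["Ana", "Bob"]))                # 1
print(repeated_speaker([]))                            # 0
```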
    async def on_game_loop(self, event:GameLoopEvent):

    async def on_game_loop(self, event: GameLoopEvent):
        await self.apply_natural_flow()

    async def apply_natural_flow(self, force: bool = False, npcs_only: bool = False):
        """
        If the natural flow action is enabled, this will attempt to determine
        the ideal character to talk next.

        This will let the AI pick a character to talk to, but if the AI can't figure
        it out it will apply rules based on max_idle_turns and max_auto_turns.

        If all fails it will just pick a random character.

        Repetition is also taken into account, so if a character has spoken twice in a row
        they will not be picked again until someone else has spoken.
        """
        scene = self.scene

        if not scene.auto_progress and not force:
            # we only apply natural flow if auto_progress is enabled
            return

        if self.actions["natural_flow"].enabled and len(scene.character_names) > 2:

            # last time each character spoke (turns ago)
            max_idle_turns = self.actions["natural_flow"].config["max_idle_turns"].value
            max_auto_turns = self.actions["natural_flow"].config["max_auto_turns"].value
            last_turn = self.last_spoken()
            player_name = scene.get_player_character().name
            last_turn_player = last_turn.get(player_name, 0)

            if last_turn_player >= max_auto_turns and not npcs_only:
                self.scene.next_actor = scene.get_player_character().name
                log.debug("conversation_agent.natural_flow", next_actor="player", overdue=True, player_character=scene.get_player_character().name)
                log.debug(
                    "conversation_agent.natural_flow",
                    next_actor="player",
                    overdue=True,
                    player_character=scene.get_player_character().name,
                )
                return

            log.debug("conversation_agent.natural_flow", last_turn=last_turn)

            # determine random character to talk, this will be the fallback in case
            # the AI can't figure out who should talk next

            if scene.prev_actor:

                # we dont want to talk to the same person twice in a row
                character_names = scene.character_names
                character_names.remove(scene.prev_actor)

                if npcs_only:
                    character_names = [c for c in character_names if c != player_name]

                random_character_name = random.choice(character_names)
            else:
                character_names = scene.character_names
                character_names = scene.character_names
                # no one has talked yet, so we just pick a random character

                if npcs_only:
                    character_names = [c for c in character_names if c != player_name]

                random_character_name = random.choice(scene.character_names)

            overdue_characters = [character for character, turn in last_turn.items() if turn >= max_idle_turns]

            overdue_characters = [
                character
                for character, turn in last_turn.items()
                if turn >= max_idle_turns
            ]

            if npcs_only:
                overdue_characters = [c for c in overdue_characters if c != player_name]

            if overdue_characters and self.scene.history:
                # Pick a random character from the overdue characters
                scene.next_actor = random.choice(overdue_characters)
            elif scene.history:
                scene.next_actor = None

                # AI will attempt to figure out who should talk next
                next_actor = await self.select_talking_actor(character_names)
                next_actor = next_actor.strip().strip('"').strip(".")

                for character_name in scene.character_names:
                    if next_actor.lower() in character_name.lower() or character_name.lower() in next_actor.lower():
                    if (
                        next_actor.lower() in character_name.lower()
                        or character_name.lower() in next_actor.lower()
                    ):
                        scene.next_actor = character_name
                        break

                if not scene.next_actor:
                    # AI couldn't figure out who should talk next, so we just pick a random character
                    log.debug("conversation_agent.natural_flow", next_actor="random", random_character_name=random_character_name)
                    log.debug(
                        "conversation_agent.natural_flow",
                        next_actor="random",
                        random_character_name=random_character_name,
                    )
                    scene.next_actor = random_character_name
                else:
                    log.debug("conversation_agent.natural_flow", next_actor="picked", ai_next_actor=scene.next_actor)
                    log.debug(
                        "conversation_agent.natural_flow",
                        next_actor="picked",
                        ai_next_actor=scene.next_actor,
                    )
            else:
                # always start with main character (TODO: configurable?)
                player_character = scene.get_player_character()
                log.debug("conversation_agent.natural_flow", next_actor="main_character", main_character=player_character)
                scene.next_actor = player_character.name if player_character else random_character_name

                scene.log.debug("conversation_agent.natural_flow", next_actor=scene.next_actor)
                log.debug(
                    "conversation_agent.natural_flow",
                    next_actor="main_character",
                    main_character=player_character,
                )
                scene.next_actor = (
                    player_character.name if player_character else random_character_name
                )

                scene.log.debug(
                    "conversation_agent.natural_flow", next_actor=scene.next_actor
                )

            # same character cannot go thrice in a row, if this is happening, pick a random character that
            # isnt the same as the last character

            if self.repeated_speaker() >= 2 and self.scene.prev_actor == self.scene.next_actor:
                scene.next_actor = random.choice([c for c in scene.character_names if c != scene.prev_actor])
                scene.log.debug("conversation_agent.natural_flow", next_actor="random (repeated safeguard)", random_character_name=scene.next_actor)

            if (
                self.repeated_speaker() >= 2
                and self.scene.prev_actor == self.scene.next_actor
            ):
                scene.next_actor = random.choice(
                    [c for c in scene.character_names if c != scene.prev_actor]
                )
                scene.log.debug(
                    "conversation_agent.natural_flow",
                    next_actor="random (repeated safeguard)",
                    random_character_name=scene.next_actor,
                )

        else:
            scene.next_actor = None
@set_processing
|
||||
async def select_talking_actor(self, character_names: list[str]=None):
|
||||
result = await Prompt.request("conversation.select-talking-actor", self.client, "conversation_select_talking_actor", vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"character_names": character_names or self.scene.character_names,
|
||||
"character_names_formatted": ", ".join(character_names or self.scene.character_names),
|
||||
})
|
||||
|
||||
async def select_talking_actor(self, character_names: list[str] = None):
|
||||
result = await Prompt.request(
|
||||
"conversation.select-talking-actor",
|
||||
self.client,
|
||||
"conversation_select_talking_actor",
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"character_names": character_names or self.scene.character_names,
|
||||
"character_names_formatted": ", ".join(
|
||||
character_names or self.scene.character_names
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return result

    async def build_prompt_default(
        self,
@@ -335,17 +388,17 @@ class ConversationAgent(Agent):
        # we subtract 200 to account for the response

        scene = character.actor.scene

        total_token_budget = self.client.max_token_length - 200
        scene_and_dialogue_budget = total_token_budget - 500
        long_term_memory_budget = min(int(total_token_budget * 0.05), 200)
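For example, with a 4096-token client the budgets above work out as follows (a sketch of the same arithmetic, with `max_token_length` as an assumed example value):

```python
max_token_length = 4096  # example client limit

total_token_budget = max_token_length - 200           # reserve 200 for the response
scene_and_dialogue_budget = total_token_budget - 500  # reserve 500 for the rest of the prompt
long_term_memory_budget = min(int(total_token_budget * 0.05), 200)  # 5%, capped at 200

print(total_token_budget, scene_and_dialogue_budget, long_term_memory_budget)
```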

        scene_and_dialogue = scene.context_history(
            budget=scene_and_dialogue_budget,
            keep_director=True,
            sections=False,
        )

        memory = await self.build_prompt_default_memory(character)

        main_character = scene.main_character.character
@@ -360,36 +413,39 @@ class ConversationAgent(Agent):
            )
        else:
            formatted_names = character_names[0] if character_names else ""

        try:
            director_message = isinstance(scene_and_dialogue[-1], DirectorMessage)
        except IndexError:
            director_message = False

        extra_instructions = ""
        if self.actions["generation_override"].enabled:
            extra_instructions = (
                self.actions["generation_override"].config["instructions"].value
            )

        prompt = Prompt.get(
            "conversation.dialogue",
            vars={
                "scene": scene,
                "max_tokens": self.client.max_token_length,
                "scene_and_dialogue_budget": scene_and_dialogue_budget,
                "scene_and_dialogue": scene_and_dialogue,
                "memory": memory,
                "characters": list(scene.get_characters()),
                "main_character": main_character,
                "formatted_names": formatted_names,
                "talking_character": character,
                "partial_message": char_message,
                "director_message": director_message,
                "extra_instructions": extra_instructions,
            },
        )

        return str(prompt)

    async def build_prompt_default_memory(self, character: Character):
        """
        Builds long term memory for the conversation prompt

@@ -404,39 +460,56 @@ class ConversationAgent(Agent):
        if not self.actions["use_long_term_memory"].enabled:
            return []

        if self.current_memory_context:
            return self.current_memory_context

        self.current_memory_context = ""

        retrieval_method = (
            self.actions["use_long_term_memory"].config["retrieval_method"].value
        )

        if retrieval_method != "direct":
            world_state = instance.get_agent("world_state")
            history = self.scene.context_history(
                min_dialogue=3,
                max_dialogue=15,
                keep_director=False,
                sections=False,
                add_archieved_history=False,
            )
            text = "\n".join(history)

            log.debug(
                "conversation_agent.build_prompt_default_memory",
                direct=False,
                version=retrieval_method,
            )

            if retrieval_method == "questions":
                self.current_memory_context = (
                    await world_state.analyze_text_and_extract_context(
                        text, f"continue the conversation as {character.name}"
                    )
                ).split("\n")
            elif retrieval_method == "queries":
                self.current_memory_context = (
                    await world_state.analyze_text_and_extract_context_via_queries(
                        text, f"continue the conversation as {character.name}"
                    )
                )

        else:
            history = list(map(str, self.scene.collect_messages(max_iterations=3)))
            log.debug(
                "conversation_agent.build_prompt_default_memory",
                history=history,
                direct=True,
            )
            memory = instance.get_agent("memory")

            context = await memory.multi_query(history, max_tokens=500, iterate=5)

            self.current_memory_context = context

        return self.current_memory_context

    async def build_prompt(self, character, char_message: str = ""):
@@ -445,10 +518,9 @@ class ConversationAgent(Agent):
        return await fn(character, char_message=char_message)

    def clean_result(self, result, character):
        if "#" in result:
            result = result.split("#")[0]

        result = result.replace(" :", ":")
        result = result.replace("[", "*").replace("]", "*")
        result = result.replace("(", "*").replace(")", "*")
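The normalization in `clean_result` can be exercised in isolation; a simplified sketch of the replacements above (without the `character` argument, and with a final `strip()` added for the example):

```python
def clean_result(result: str) -> str:
    # drop everything after a stray "#" marker
    if "#" in result:
        result = result.split("#")[0]

    # tighten spaced colons and normalize brackets/parens to asterisks
    result = result.replace(" :", ":")
    result = result.replace("[", "*").replace("]", "*")
    result = result.replace("(", "*").replace(")", "*")
    return result.strip()
```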
@@ -459,15 +531,19 @@ class ConversationAgent(Agent):
    def set_generation_overrides(self):
        if not self.actions["generation_override"].enabled:
            return

        set_conversation_context_attribute(
            "length", self.actions["generation_override"].config["length"].value
        )

        if self.actions["generation_override"].config["jiggle"].value > 0.0:
            nuke_repetition = client_context_attribute("nuke_repetition")
            if nuke_repetition == 0.0:
                # we only apply the agent override if some other mechanism isn't already
                # setting the nuke_repetition value
                nuke_repetition = (
                    self.actions["generation_override"].config["jiggle"].value
                )
                set_client_context_attribute("nuke_repetition", nuke_repetition)

    @set_processing
@@ -479,10 +555,14 @@ class ConversationAgent(Agent):
        self.current_memory_context = None

        character = actor.character

        emission = ConversationAgentEmission(
            agent=self, generation="", actor=actor, character=character
        )
        await talemate.emit.async_signals.get(
            "agent.conversation.before_generate"
        ).send(emission)

        self.set_generation_overrides()

        result = await self.client.send_prompt(await self.build_prompt(character))
@@ -505,7 +585,7 @@ class ConversationAgent(Agent):

        result = self.clean_result(result, character)

        total_result += " " + result

        if len(total_result) == 0 and max_loops < 10:
            max_loops += 1
@@ -529,7 +609,7 @@ class ConversationAgent(Agent):

        # Removes partial sentence at the end
        total_result = util.clean_dialogue(total_result, main_name=character.name)

        # Remove "{character.name}:" - all occurrences
        total_result = total_result.replace(f"{character.name}:", "")

@@ -548,13 +628,17 @@ class ConversationAgent(Agent):
        )

        response_message = util.parse_messages_from_str(total_result, [character.name])

        log.info("conversation agent", result=response_message)

        emission = ConversationAgentEmission(
            agent=self, generation=response_message, actor=actor, character=character
        )
        await talemate.emit.async_signals.get("agent.conversation.generated").send(
            emission
        )

        # log.info("conversation agent", generation=emission.generation)

        messages = [CharacterMessage(message) for message in emission.generation]

@@ -563,15 +647,17 @@ class ConversationAgent(Agent):

        return messages

    def allow_repetition_break(
        self, kind: str, agent_function_name: str, auto: bool = False
    ):
        if auto and not self.actions["auto_break_repetition"].enabled:
            return False

        return agent_function_name == "converse"

    def inject_prompt_paramters(
        self, prompt_param: dict, kind: str, agent_function_name: str
    ):
        if prompt_param.get("extra_stopping_strings") is None:
            prompt_param["extra_stopping_strings"] = []
        prompt_param["extra_stopping_strings"] += ["["]

@@ -3,22 +3,23 @@ from __future__ import annotations

import json
import os

import talemate.client as client
from talemate.agents.base import Agent, set_processing
from talemate.agents.registry import register
from talemate.emit import emit
from talemate.prompts import Prompt

from .character import CharacterCreatorMixin
from .scenario import ScenarioCreatorMixin


@register()
class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
    """
    Creates characters and scenarios and other fun stuff!
    """

    agent_type = "creator"
    verbose_name = "Creator"

@@ -78,12 +79,14 @@ class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):
        # Remove duplicates while preserving the order for list type keys
        for key, value in merged_data.items():
            if isinstance(value, list):
                merged_data[key] = [
                    x for i, x in enumerate(value) if x not in value[:i]
                ]

        merged_data["context"] = context

        return merged_data
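The comprehension used above is a common order-preserving dedupe; as a standalone sketch:

```python
def dedupe_preserve_order(values: list) -> list:
    # keep only the first occurrence of each item, preserving list order
    return [x for i, x in enumerate(values) if x not in values[:i]]
```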

    def load_templates_old(self, names: list, template_type: str = "character") -> dict:
        """
        Loads multiple character creation templates from ./templates/character and merges them in order.
@@ -128,8 +131,10 @@ class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):

            if "context" in template_data["instructions"]:
                context = template_data["instructions"]["context"]

            merged_instructions[name]["questions"] = [
                q[0] for q in template_data.get("questions", [])
            ]

        # Remove duplicates while preserving the order
        merged_template = [
@@ -158,24 +163,33 @@ class CreatorAgent(CharacterCreatorMixin, ScenarioCreatorMixin, Agent):

        return rv

    @set_processing
    async def generate_json_list(
        self,
        text: str,
        count: int = 20,
        first_item: str = None,
    ):
        _, json_list = await Prompt.request(
            f"creator.generate-json-list",
            self.client,
            "create",
            vars={
                "text": text,
                "first_item": first_item,
                "count": count,
            },
        )
        return json_list.get("items", [])

    @set_processing
    async def generate_title(self, text: str):
        title = await Prompt.request(
            f"creator.generate-title",
            self.client,
            "create_short",
            vars={
                "text": text,
            },
        )
        return title

@@ -1,42 +1,48 @@
from __future__ import annotations

import asyncio
import random
import re
from typing import TYPE_CHECKING, Callable

import structlog

import talemate.util as util
from talemate.agents.base import set_processing
from talemate.emit import emit
from talemate.exceptions import LLMAccuracyError
from talemate.prompts import LoopedPrompt, Prompt

if TYPE_CHECKING:
    from talemate.tale_mate import Character

log = structlog.get_logger("talemate.agents.creator.character")


def validate(k, v):
    if k and k.lower() == "gender":
        return v.lower().strip()
    if k and k.lower() == "age":
        try:
            return int(v.split("\n")[0].strip())
        except (ValueError, TypeError):
            raise LLMAccuracyError(
                "Was unable to get a valid age from the response", model_name=None
            )

    return v.strip().strip("\n")
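The validator above normalizes values per key; a minimal self-contained sketch of the same rules (using a plain `ValueError` in place of talemate's `LLMAccuracyError`):

```python
def validate(k, v):
    # gender values are lowercased; age must parse as an integer;
    # everything else is just whitespace-trimmed
    if k and k.lower() == "gender":
        return v.lower().strip()
    if k and k.lower() == "age":
        try:
            return int(v.split("\n")[0].strip())
        except (ValueError, TypeError):
            raise ValueError("could not parse a valid age from the response")
    return v.strip().strip("\n")
```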

DEFAULT_CONTENT_CONTEXT = "a fun and engaging adventure aimed at an adult audience."


class CharacterCreatorMixin:
    """
    Adds character creation functionality to the creator agent
    """

    ## NEW

    @set_processing
    async def create_character_attributes(
        self,
@@ -48,8 +54,6 @@ class CharacterCreatorMixin:
        custom_attributes: dict[str, str] = dict(),
        predefined_attributes: dict[str, str] = dict(),
    ):
        def spice(prompt, spices):
            # generate number from 0 to 1 and if its smaller than use_spice
            # select a random spice from the list and return it formatted
@@ -57,69 +61,74 @@ class CharacterCreatorMixin:
            if random.random() < use_spice:
                spice = random.choice(spices)
                return prompt.format(spice=spice)
            return ""
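The `spice` helper rolls a random chance and formats a prompt with a random spice; a self-contained sketch (with `use_spice` made an explicit parameter, since above it is a closure variable):

```python
import random

def spice(prompt: str, spices: list, use_spice: float) -> str:
    # with probability use_spice, pick a random spice and format it into the prompt
    if random.random() < use_spice:
        return prompt.format(spice=random.choice(spices))
    return ""
```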

        # drop any empty attributes from predefined_attributes
        predefined_attributes = {k: v for k, v in predefined_attributes.items() if v}

        prompt = Prompt.get(
            f"creator.character-attributes-{template}",
            vars={
                "character_prompt": character_prompt,
                "template": template,
                "spice": spice,
                "content_context": content_context,
                "custom_attributes": custom_attributes,
                "character_sheet": LoopedPrompt(
                    validate_value=validate,
                    on_update=attribute_callback,
                    generated=predefined_attributes,
                ),
            },
        )
        await prompt.loop(self.client, "character_sheet", kind="create_concise")

        return prompt.vars["character_sheet"].generated

    @set_processing
    async def create_character_description(
        self,
        character: Character,
        content_context: str = DEFAULT_CONTENT_CONTEXT,
    ):
        description = await Prompt.request(
            f"creator.character-description",
            self.client,
            "create",
            vars={
                "character": character,
                "content_context": content_context,
            },
        )

        return description.strip()

    @set_processing
    async def create_character_details(
        self,
        character: Character,
        template: str,
        detail_callback: Callable = lambda question, answer: None,
        questions: list[str] = None,
        content_context: str = DEFAULT_CONTENT_CONTEXT,
    ):
        prompt = Prompt.get(
            f"creator.character-details-{template}",
            vars={
                "character_details": LoopedPrompt(
                    validate_value=validate,
                    on_update=detail_callback,
                ),
                "template": template,
                "content_context": content_context,
                "character": character,
                "custom_questions": questions or [],
            },
        )
        await prompt.loop(self.client, "character_details", kind="create_concise")
        return prompt.vars["character_details"].generated

    @set_processing
    async def create_character_example_dialogue(
        self,
@@ -131,97 +140,116 @@ class CharacterCreatorMixin:
        example_callback: Callable = lambda example: None,
        rules_callback: Callable = lambda rules: None,
    ):
        dialogue_rules = await Prompt.request(
            f"creator.character-dialogue-rules",
            self.client,
            "create",
            vars={
                "guide": guide,
                "character": character,
                "examples": examples or [],
                "content_context": content_context,
            },
        )

        log.info("dialogue_rules", dialogue_rules=dialogue_rules)

        if rules_callback:
            rules_callback(dialogue_rules)

        example_dialogue_prompt = Prompt.get(
            f"creator.character-example-dialogue-{template}",
            vars={
                "guide": guide,
                "character": character,
                "examples": examples or [],
                "content_context": content_context,
                "dialogue_rules": dialogue_rules,
                "generated_examples": LoopedPrompt(
                    validate_value=validate,
                    on_update=example_callback,
                ),
            },
        )

        await example_dialogue_prompt.loop(
            self.client, "generated_examples", kind="create"
        )

        return example_dialogue_prompt.vars["generated_examples"].generated

    @set_processing
    async def determine_content_context_for_character(
        self,
        character: Character,
    ):
        content_context = await Prompt.request(
            f"creator.determine-content-context",
            self.client,
            "create",
            vars={
                "character": character,
            },
        )
        return content_context.strip()

    @set_processing
    async def determine_character_attributes(
        self,
        character: Character,
    ):
        attributes = await Prompt.request(
            f"creator.determine-character-attributes",
            self.client,
            "analyze_long",
            vars={
                "character": character,
            },
        )
        return attributes

    @set_processing
    async def determine_character_description(
        self, character: Character, text: str = ""
    ):
        description = await Prompt.request(
            f"creator.determine-character-description",
            self.client,
            "create",
            vars={
                "character": character,
                "scene": self.scene,
                "text": text,
                "max_tokens": self.client.max_token_length,
            },
        )
        return description.strip()

    @set_processing
    async def determine_character_goals(
        self,
        character: Character,
        goal_instructions: str,
    ):
        goals = await Prompt.request(
            f"creator.determine-character-goals",
            self.client,
            "create",
            vars={
                "character": character,
                "scene": self.scene,
                "goal_instructions": goal_instructions,
                "npc_name": character.name,
                "player_name": self.scene.get_player_character().name,
                "max_tokens": self.client.max_token_length,
            },
        )

        log.debug("determine_character_goals", goals=goals, character=character)
        await character.set_detail("goals", goals.strip())

        return goals.strip()

    @set_processing
    async def generate_character_from_text(
        self,
@@ -229,11 +257,8 @@ class CharacterCreatorMixin:
        template: str,
        content_context: str = DEFAULT_CONTENT_CONTEXT,
    ):
        base_attributes = await self.create_character_attributes(
            character_prompt=text,
            template=template,
            content_context=content_context,
        )

@@ -1,36 +1,36 @@
import random
import re

from talemate.agents.base import set_processing
from talemate.emit import emit, wait_for_input_yesno
from talemate.prompts import Prompt


class ScenarioCreatorMixin:
    """
    Adds scenario creation functionality to the creator agent
    """

    @set_processing
    async def create_scene_description(
        self,
        prompt: str,
        content_context: str,
    ):
        """
        Creates a new scene.

        Arguments:

        prompt (str): The prompt to use to create the scene.

        content_context (str): The content context to use for the scene.

        callback (callable): A callback to call when the scene has been created.
        """
        scene = self.scene

        description = await Prompt.request(
            "creator.scenario-description",
            self.client,
@@ -40,35 +40,32 @@ class ScenarioCreatorMixin:
                "content_context": content_context,
                "max_tokens": self.client.max_token_length,
                "scene": scene,
            },
        )
        description = description.strip()

        return description

    @set_processing
    async def create_scene_name(
        self,
        prompt: str,
        content_context: str,
        description: str,
    ):
        """
        Generates a scene name.

        Arguments:

        prompt (str): The prompt to use to generate the scene name.

        content_context (str): The content context to use for the scene.

        description (str): The description of the scene.
        """
        scene = self.scene

        name = await Prompt.request(
            "creator.scenario-name",
            self.client,
@@ -78,37 +75,35 @@ class ScenarioCreatorMixin:
                "content_context": content_context,
                "description": description,
                "scene": scene,
            },
        )
        name = name.strip().strip(".!").replace('"', "")
        return name

    @set_processing
    async def create_scene_intro(
        self,
        prompt: str,
        content_context: str,
        description: str,
        name: str,
    ):
        """
        Generates a scene introduction.

        Arguments:

        prompt (str): The prompt to use to generate the scene introduction.

        content_context (str): The content context to use for the scene.

        description (str): The description of the scene.

        name (str): The name of the scene.
        """

        scene = self.scene

        intro = await Prompt.request(
            "creator.scenario-intro",
            self.client,
@@ -119,17 +114,19 @@ class ScenarioCreatorMixin:
                "description": description,
                "name": name,
                "scene": scene,
            },
        )
        intro = intro.strip()
        return intro

    @set_processing
    async def determine_scenario_description(self, text: str):
        description = await Prompt.request(
            f"creator.determine-scenario-description",
            self.client,
            "analyze_long",
            vars={
                "text": text,
            },
        )
        return description

@@ -1,200 +1,251 @@
from __future__ import annotations

import asyncio
import random
import re
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import structlog

import talemate.automated_action as automated_action
import talemate.emit.async_signals
import talemate.instance as instance
import talemate.util as util
from talemate.agents.conversation import ConversationAgentEmission
from talemate.automated_action import AutomatedAction
from talemate.emit import emit, wait_for_input
from talemate.events import GameLoopActorIterEvent, GameLoopStartEvent, SceneStateEvent
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, NarratorMessage

from .base import Agent, AgentAction, AgentActionConfig, set_processing
from .registry import register

if TYPE_CHECKING:
    from talemate import Actor, Character, Player, Scene

log = structlog.get_logger("talemate.agent.director")


@register()
class DirectorAgent(Agent):
    agent_type = "director"
    verbose_name = "Director"

    def __init__(self, client, **kwargs):
        self.is_enabled = True
        self.client = client
        self.next_direct_character = {}
        self.next_direct_scene = 0
        self.actions = {
            "direct": AgentAction(
                enabled=True,
                label="Direct",
                description="Will attempt to direct the scene. Runs automatically after AI dialogue (n turns).",
                config={
                    "turns": AgentActionConfig(
                        type="number",
                        label="Turns",
                        description="Number of turns to wait before directing the scene",
                        value=5,
                        min=1,
                        max=100,
                        step=1,
                    ),
                    "direct_scene": AgentActionConfig(
                        type="bool",
                        label="Direct Scene",
                        description="If enabled, the scene will be directed through narration",
                        value=True,
                    ),
                    "direct_actors": AgentActionConfig(
                        type="bool",
                        label="Direct Actors",
                        description="If enabled, direction will be given to actors based on their goals.",
                        value=True,
                    ),
                },
            ),
        }
|
||||
|
||||
|
||||
    @property
    def enabled(self):
        return self.is_enabled

    @property
    def has_toggle(self):
        return True

    @property
    def experimental(self):
        return True

    def connect(self, scene):
        super().connect(scene)
        talemate.emit.async_signals.get("agent.conversation.before_generate").connect(
            self.on_conversation_before_generate
        )
        talemate.emit.async_signals.get("game_loop_actor_iter").connect(
            self.on_player_dialog
        )
        talemate.emit.async_signals.get("scene_init").connect(self.on_scene_init)

    async def on_scene_init(self, event: SceneStateEvent):
        """
        If game state instructions specify to be run at the start of the game loop
        we will run them here.
        """

        if not self.enabled:
            if self.scene.game_state.has_scene_instructions:
                self.is_enabled = True
                log.warning("on_scene_init - enabling director", scene=self.scene)
            else:
                return

        if not self.scene.game_state.has_scene_instructions:
            return

        if not self.scene.game_state.ops.run_on_start:
            return

        log.info("on_game_loop_start - running game state instructions")
        await self.run_gamestate_instructions()
    async def on_conversation_before_generate(self, event: ConversationAgentEmission):
        log.info("on_conversation_before_generate", director_enabled=self.enabled)
        if not self.enabled:
            return

        await self.direct(event.character)

    async def on_player_dialog(self, event: GameLoopActorIterEvent):
        if not self.enabled:
            return

        if not self.scene.game_state.has_scene_instructions:
            return

        if not event.actor.character.is_player:
            return

        if event.game_loop.had_passive_narration:
            log.debug(
                "director.on_player_dialog",
                skip=True,
                had_passive_narration=event.game_loop.had_passive_narration,
            )
            return

        event.game_loop.had_passive_narration = await self.direct(None)

    async def direct(self, character: Character) -> bool:
        if not self.actions["direct"].enabled:
            return False

        if character:
            if not self.actions["direct"].config["direct_actors"].value:
                log.info(
                    "direct",
                    skip=True,
                    reason="direct_actors disabled",
                    character=character,
                )
                return False

            # character direction, see if there are character goals
            # defined
            character_goals = character.get_detail("goals")
            if not character_goals:
                log.info("direct", skip=True, reason="no goals", character=character)
                return False

            next_direct = self.next_direct_character.get(character.name, 0)

            if (
                next_direct % self.actions["direct"].config["turns"].value != 0
                or next_direct == 0
            ):
                log.info(
                    "direct", skip=True, next_direct=next_direct, character=character
                )
                self.next_direct_character[character.name] = next_direct + 1
                return False

            self.next_direct_character[character.name] = 0
            await self.direct_scene(character, character_goals)
            return True
        else:
            if not self.actions["direct"].config["direct_scene"].value:
                log.info("direct", skip=True, reason="direct_scene disabled")
                return False

            # no character, see if there are NPC characters at all
            # if not we always want to direct narration
            always_direct = not self.scene.npc_character_names

            next_direct = self.next_direct_scene

            if (
                next_direct % self.actions["direct"].config["turns"].value != 0
                or next_direct == 0
            ):
                if not always_direct:
                    log.info("direct", skip=True, next_direct=next_direct)
                    self.next_direct_scene += 1
                    return False

            self.next_direct_scene = 0
            await self.direct_scene(None, None)
            return True
    @set_processing
    async def run_gamestate_instructions(self):
        """
        Run game state instructions, if they exist.
        """

        if not self.scene.game_state.has_scene_instructions:
            return

        await self.direct_scene(None, None)

    @set_processing
    async def direct_scene(self, character: Character, prompt: str):
        if not character and self.scene.game_state.game_won:
            # we are not directing a character, and the game has been won
            # so we don't need to direct the scene any further
            return

        if character:
            # direct a character

            response = await Prompt.request(
                "director.direct-character",
                self.client,
                "director",
                vars={
                    "max_tokens": self.client.max_token_length,
                    "scene": self.scene,
                    "prompt": prompt,
                    "character": character,
                    "player_character": self.scene.get_player_character(),
                    "game_state": self.scene.game_state,
                },
            )

            if "#" in response:
                response = response.split("#")[0]

            log.info(
                "direct_character",
                character=character,
                prompt=prompt,
                response=response,
            )

            response = response.strip().split("\n")[0].strip()
            # response += f" (current story goal: {prompt})"
            message = DirectorMessage(response, source=character.name)
            emit("director", message, character=character)
            self.scene.push_history(message)
@@ -204,24 +255,38 @@ class DirectorAgent(Agent):

    @set_processing
    async def persist_character(
        self,
        name: str,
        content: str = None,
        attributes: str = None,
    ):
        world_state = instance.get_agent("world_state")
        creator = instance.get_agent("creator")
        self.scene.log.debug("persist_character", name=name)

        character = self.scene.Character(name=name)
        character.color = random.choice(
            [
                "#F08080",
                "#FFD700",
                "#90EE90",
                "#ADD8E6",
                "#DDA0DD",
                "#FFB6C1",
                "#FAFAD2",
                "#D3D3D3",
                "#B0E0E6",
                "#FFDEAD",
            ]
        )

        if not attributes:
            attributes = await world_state.extract_character_sheet(
                name=name, text=content
            )
        else:
            attributes = world_state._parse_character_sheet(attributes)

        self.scene.log.debug("persist_character", attributes=attributes)

        character.base_attributes = attributes
@@ -232,35 +297,54 @@ class DirectorAgent(Agent):

        self.scene.log.debug("persist_character", description=description)

        actor = self.scene.Actor(
            character=character, agent=instance.get_agent("conversation")
        )

        await self.scene.add_actor(actor)
        self.scene.emit_status()

        return character

    @set_processing
    async def update_content_context(
        self, content: str = None, extra_choices: list[str] = None
    ):
        if not content:
            content = "\n".join(
                self.scene.context_history(sections=False, min_dialogue=25, budget=2048)
            )

        response = await Prompt.request(
            "world_state.determine-content-context",
            self.client,
            "analyze_freeform",
            vars={
                "content": content,
                "extra_choices": extra_choices or [],
            },
        )

        self.scene.context = response.strip()
        self.scene.emit_status()

    def inject_prompt_paramters(
        self, prompt_param: dict, kind: str, agent_function_name: str
    ):
        log.debug(
            "inject_prompt_paramters",
            prompt_param=prompt_param,
            kind=kind,
            agent_function_name=agent_function_name,
        )
        character_names = [f"\n{c.name}:" for c in self.scene.get_characters()]
        if prompt_param.get("extra_stopping_strings") is None:
            prompt_param["extra_stopping_strings"] = []
        prompt_param["extra_stopping_strings"] += character_names + ["#"]
        if agent_function_name == "update_content_context":
            prompt_param["extra_stopping_strings"] += ["\n"]

    def allow_repetition_break(
        self, kind: str, agent_function_name: str, auto: bool = False
    ):
        return True
@@ -1,30 +1,30 @@
from __future__ import annotations

import asyncio
import re
import time
import traceback
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import structlog

import talemate.data_objects as data_objects
import talemate.emit.async_signals
import talemate.util as util
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, TimePassageMessage

from .base import Agent, AgentAction, AgentActionConfig, set_processing
from .registry import register

if TYPE_CHECKING:
    from talemate.agents.conversation import ConversationAgentEmission
    from talemate.agents.narrator import NarratorAgentEmission
    from talemate.tale_mate import Actor, Character, Scene

log = structlog.get_logger("talemate.agents.editor")
@register()
class EditorAgent(Agent):
    """
@@ -35,175 +35,195 @@ class EditorAgent(Agent):

    agent_type = "editor"
    verbose_name = "Editor"

    def __init__(self, client, **kwargs):
        self.client = client
        self.is_enabled = True
        self.actions = {
            "edit_dialogue": AgentAction(
                enabled=False,
                label="Edit dialogue",
                description="Will attempt to improve the quality of dialogue based on the character and scene. Runs automatically after each AI dialogue.",
            ),
            "fix_exposition": AgentAction(
                enabled=True,
                label="Fix exposition",
                description="Will attempt to fix exposition and emotes, making sure they are displayed in italics. Runs automatically after each AI dialogue.",
                config={
                    "narrator": AgentActionConfig(
                        type="bool",
                        label="Fix narrator messages",
                        description="Will attempt to fix exposition issues in narrator messages",
                        value=True,
                    ),
                },
            ),
            "add_detail": AgentAction(
                enabled=False,
                label="Add detail",
                description="Will attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.",
            ),
        }
    @property
    def enabled(self):
        return self.is_enabled

    @property
    def has_toggle(self):
        return True

    @property
    def experimental(self):
        return True

    def connect(self, scene):
        super().connect(scene)
        talemate.emit.async_signals.get("agent.conversation.generated").connect(
            self.on_conversation_generated
        )
        talemate.emit.async_signals.get("agent.narrator.generated").connect(
            self.on_narrator_generated
        )

    async def on_conversation_generated(self, emission: ConversationAgentEmission):
        """
        Called when a conversation is generated
        """

        if not self.enabled:
            return

        log.info("editing conversation", emission=emission)

        edited = []
        for text in emission.generation:
            edit = await self.add_detail(text, emission.character)

            edit = await self.edit_conversation(edit, emission.character)

            edit = await self.fix_exposition(edit, emission.character)

            edited.append(edit)

        emission.generation = edited
    async def on_narrator_generated(self, emission: NarratorAgentEmission):
        """
        Called when a narrator message is generated
        """

        if not self.enabled:
            return

        log.info("editing narrator", emission=emission)

        edited = []

        for text in emission.generation:
            edit = await self.fix_exposition_on_narrator(text)
            edited.append(edit)

        emission.generation = edited

    @set_processing
    async def edit_conversation(self, content: str, character: Character):
        """
        Edits a conversation
        """

        if not self.actions["edit_dialogue"].enabled:
            return content

        response = await Prompt.request(
            "editor.edit-dialogue",
            self.client,
            "edit_dialogue",
            vars={
                "content": content,
                "character": character,
                "scene": self.scene,
                "max_length": self.client.max_token_length,
            },
        )

        response = response.split("[end]")[0]

        response = util.replace_exposition_markers(response)
        response = util.clean_dialogue(response, main_name=character.name)
        response = util.strip_partial_sentences(response)

        return response
    @set_processing
    async def fix_exposition(self, content: str, character: Character):
        """
        Edits a text to make sure all narrative exposition and emotes are encased in *
        """

        if not self.actions["fix_exposition"].enabled:
            return content

        if not character.is_player:
            if '"' not in content and "*" not in content:
                content = util.strip_partial_sentences(content)
                character_prefix = f"{character.name}: "
                message = content.split(character_prefix)[1]
                content = f"{character_prefix}*{message.strip('*')}*"
                return content
            elif '"' in content:
                # silly hack to clean up some LLMs that always start with a quote
                # even though the immediate next thing is a narration (indicated by *)
                content = content.replace(
                    f'{character.name}: "*', f"{character.name}: *"
                )

        content = util.clean_dialogue(content, main_name=character.name)
        content = util.strip_partial_sentences(content)
        content = util.ensure_dialog_format(content, talking_character=character.name)

        return content

    @set_processing
    async def fix_exposition_on_narrator(self, content: str):
        if not self.actions["fix_exposition"].enabled:
            return content

        if not self.actions["fix_exposition"].config["narrator"].value:
            return content

        content = util.strip_partial_sentences(content)

        if '"' not in content:
            content = f"*{content.strip('*')}*"
        else:
            content = util.ensure_dialog_format(content)

        return content
    @set_processing
    async def add_detail(self, content: str, character: Character):
        """
        Edits a text to increase its length and add extra detail and exposition
        """

        if not self.actions["add_detail"].enabled:
            return content

        response = await Prompt.request(
            "editor.add-detail",
            self.client,
            "edit_add_detail",
            vars={
                "content": content,
                "character": character,
                "scene": self.scene,
                "max_length": self.client.max_token_length,
            },
        )

        response = util.replace_exposition_markers(response)
        response = util.clean_dialogue(response, main_name=character.name)
        response = util.strip_partial_sentences(response)

        return response
@@ -1,19 +1,21 @@
from __future__ import annotations

import asyncio
import functools
import os
import shutil
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import structlog
from chromadb.config import Settings

import talemate.events as events
import talemate.util as util
from talemate.agents.base import set_processing
from talemate.config import load_config
from talemate.context import scene_is_loading
from talemate.emit import emit
from talemate.emit.signals import handlers

try:
    import chromadb
@@ -30,17 +32,18 @@ if not chromadb:
from .base import Agent


class MemoryDocument(str):
    def __new__(cls, text, meta, id, raw):
        inst = super().__new__(cls, text)

        inst.meta = meta
        inst.id = id
        inst.raw = raw

        return inst


class MemoryAgent(Agent):
    """
    An agent that can be used to maintain and access a memory of the world
@@ -52,10 +55,11 @@ class MemoryAgent(Agent):

    @property
    def readonly(self):
        if scene_is_loading.get() and not getattr(
            self.scene, "_memory_never_persisted", False
        ):
            return True

        return False

    @property
@@ -72,9 +76,9 @@ class MemoryAgent(Agent):
        self.memory_tracker = {}
        self.config = load_config()
        self._ready_to_add = False

        handlers["config_saved"].connect(self.on_config_saved)

    def on_config_saved(self, event):
        openai_key = self.openai_api_key
        self.config = load_config()
@@ -92,35 +96,68 @@ class MemoryAgent(Agent):
        raise NotImplementedError()
    @set_processing
    async def add(self, text, character=None, uid=None, ts: str = None, **kwargs):
        if not text:
            return
        if self.readonly:
            log.debug("memory agent", status="readonly")
            return

        while not self._ready_to_add:
            await asyncio.sleep(0.1)

        log.debug(
            "memory agent add",
            text=text[:50],
            character=character,
            uid=uid,
            ts=ts,
            **kwargs,
        )

        loop = asyncio.get_running_loop()

        try:
            await loop.run_in_executor(
                None,
                functools.partial(self._add, text, character, uid=uid, ts=ts, **kwargs),
            )
        except AttributeError as e:
            # not sure how this sometimes happens.
            # chromadb model None
            # race condition because we are forcing async context onto it?
            log.error(
                "memory agent",
                error="failed to add memory",
                details=e,
                text=text[:50],
                character=character,
                uid=uid,
                ts=ts,
                **kwargs,
            )
            await asyncio.sleep(1.0)
            try:
                await loop.run_in_executor(
                    None,
                    functools.partial(
                        self._add, text, character, uid=uid, ts=ts, **kwargs
                    ),
                )
            except Exception as e:
                log.error(
                    "memory agent",
                    error="failed to add memory (retried)",
                    details=e,
                    text=text[:50],
                    character=character,
                    uid=uid,
                    ts=ts,
                    **kwargs,
                )
    def _add(self, text, character=None, ts: str = None, **kwargs):
        raise NotImplementedError()

    @set_processing
@@ -131,44 +168,46 @@ class MemoryAgent(Agent):

        while not self._ready_to_add:
            await asyncio.sleep(0.1)

        log.debug("memory agent add many", len=len(objects))

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._add_many, objects)

    def _add_many(self, objects: list[dict]):
        """
        Add multiple objects to the memory
        """
        raise NotImplementedError()

    def _delete(self, meta: dict):
        """
        Delete an object from the memory
        """
        raise NotImplementedError()

    @set_processing
    async def delete(self, meta: dict):
        """
        Delete an object from the memory
        """
        if self.readonly:
            log.debug("memory agent", status="readonly")
            return

        while not self._ready_to_add:
            await asyncio.sleep(0.1)

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._delete, meta)

    @set_processing
    async def get(self, text, character=None, **query):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, functools.partial(self._get, text, character, **query)
        )

    def _get(self, text, character=None, **query):
        raise NotImplementedError()
@@ -177,12 +216,14 @@ class MemoryAgent(Agent):
    async def get_document(self, id):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._get_document, id)

    def _get_document(self, id):
        raise NotImplementedError()

    def on_archive_add(self, event: events.ArchiveEvent):
        asyncio.ensure_future(
            self.add(event.text, uid=event.memory_id, ts=event.ts, typ="history")
        )

    def on_character_state(self, event: events.CharacterStateEvent):
        asyncio.ensure_future(
@@ -222,10 +263,10 @@ class MemoryAgent(Agent):
        """

        memory_context = []

        if not query:
            return memory_context

        for memory in await self.get(query):
            if memory in memory_context:
                continue
@@ -239,17 +280,26 @@ class MemoryAgent(Agent):
            break
        return memory_context
    async def query(
        self,
        query: str,
        max_tokens: int = 1000,
        filter: Callable = lambda x: True,
        **where,
    ):
        """
        Get the character memory context for a given character
        """
        try:
            return (
                await self.multi_query(
                    [query], max_tokens=max_tokens, filter=filter, **where
                )
            )[0]
        except IndexError:
            return None

    async def multi_query(
        self,
        queries: list[str],
@@ -258,7 +308,7 @@ class MemoryAgent(Agent):
        filter: Callable = lambda x: True,
        formatter: Callable = lambda x: x,
        limit: int = 10,
        **where,
    ):
        """
        Get the character memory context for a given character
@@ -266,10 +316,9 @@ class MemoryAgent(Agent):

        memory_context = []
        for query in queries:
            if not query:
                continue

            i = 0
            for memory in await self.get(formatter(query), limit=limit, **where):
                if memory in memory_context:
@@ -296,15 +345,13 @@ from .registry import register
@register(condition=lambda: chromadb is not None)
|
||||
class ChromaDBMemoryAgent(MemoryAgent):
|
||||
|
||||
requires_llm_client = False
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
|
||||
if self.embeddings == "openai" and not self.openai_api_key:
|
||||
return False
|
||||
|
||||
|
||||
if getattr(self, "db_client", None):
|
||||
return True
|
||||
return False
|
||||
@@ -313,80 +360,84 @@ class ChromaDBMemoryAgent(MemoryAgent):
    def status(self):
        if self.ready:
            return "active" if not getattr(self, "processing", False) else "busy"

        if self.embeddings == "openai" and not self.openai_api_key:
            return "error"

        return "waiting"

    @property
    def agent_details(self):
        if self.embeddings == "openai" and not self.openai_api_key:
            return "No OpenAI API key set"

        return f"ChromaDB: {self.embeddings}"

    @property
    def embeddings(self):
        """
        Returns which embeddings to use

        will read from TM_CHROMADB_EMBEDDINGS env variable and default to 'default' using
        the default embeddings specified by chromadb.

        other values are

        - openai: use openai embeddings
        - instructor: use instructor embeddings

        for `openai`:

        you will also need to provide an `OPENAI_API_KEY` env variable

        for `instructor`:

        you will also need to provide which instructor model to use with the `TM_INSTRUCTOR_MODEL` env variable, which defaults to hkunlp/instructor-xl

        additionally you can provide the `TM_INSTRUCTOR_DEVICE` env variable to specify which device to use, which defaults to cpu
        """

        embeddings = self.config.get("chromadb").get("embeddings")

        assert embeddings in ["default", "openai", "instructor"], f"Unknown embeddings {embeddings}"
        assert embeddings in [
            "default",
            "openai",
            "instructor",
        ], f"Unknown embeddings {embeddings}"

        return embeddings

    @property
    def USE_OPENAI(self):
        return self.embeddings == "openai"

    @property
    def USE_INSTRUCTOR(self):
        return self.embeddings == "instructor"

    @property
    def db_name(self):
        return getattr(self, "collection_name", "<unnamed>")

    @property
    def openai_api_key(self):
        return self.config.get("openai",{}).get("api_key")
        return self.config.get("openai", {}).get("api_key")

    def make_collection_name(self, scene):
        if self.USE_OPENAI:
            suffix = "-openai"
        elif self.USE_INSTRUCTOR:
            suffix = "-instructor"
            model = self.config.get("chromadb").get("instructor_model", "hkunlp/instructor-xl")
            model = self.config.get("chromadb").get(
                "instructor_model", "hkunlp/instructor-xl"
            )
            if "xl" in model:
                suffix += "-xl"
            elif "large" in model:
                suffix += "-large"
        else:
            suffix = ""

        return f"{scene.memory_id}-tm{suffix}"

    async def count(self):
@@ -399,9 +450,8 @@ class ChromaDBMemoryAgent(MemoryAgent):
        await loop.run_in_executor(None, self._set_db)

    def _set_db(self):
        self._ready_to_add = False

        if not getattr(self, "db_client", None):
            log.info("chromadb agent", status="setting up db client to persistent db")
            self.db_client = chromadb.PersistentClient(
@@ -409,49 +459,60 @@ class ChromaDBMemoryAgent(MemoryAgent):
            )

        openai_key = self.openai_api_key

        self.collection_name = collection_name = self.make_collection_name(self.scene)

        log.info("chromadb agent", status="setting up db", collection_name=collection_name)
        log.info(
            "chromadb agent", status="setting up db", collection_name=collection_name
        )

        if self.USE_OPENAI:
            if not openai_key:
                raise ValueError("You must provide an the openai ai key in the config if you want to use it for chromadb embeddings")
                raise ValueError(
                    "You must provide an the openai ai key in the config if you want to use it for chromadb embeddings"
                )

            log.info(
                "crhomadb", status="using openai", openai_key=openai_key[:5] + "..."
            )
            openai_ef = embedding_functions.OpenAIEmbeddingFunction(
                api_key = openai_key,
                api_key=openai_key,
                model_name="text-embedding-ada-002",
            )
            self.db = self.db_client.get_or_create_collection(
                collection_name, embedding_function=openai_ef
            )
        elif self.USE_INSTRUCTOR:
            instructor_device = self.config.get("chromadb").get("instructor_device", "cpu")
            instructor_model = self.config.get("chromadb").get("instructor_model", "hkunlp/instructor-xl")

            log.info("chromadb", status="using instructor", model=instructor_model, device=instructor_device)
            instructor_device = self.config.get("chromadb").get(
                "instructor_device", "cpu"
            )
            instructor_model = self.config.get("chromadb").get(
                "instructor_model", "hkunlp/instructor-xl"
            )

            log.info(
                "chromadb",
                status="using instructor",
                model=instructor_model,
                device=instructor_device,
            )

            # ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-mpnet-base-v2")
            ef = embedding_functions.InstructorEmbeddingFunction(
                model_name=instructor_model, device=instructor_device
            )

            log.info("chromadb", status="embedding function ready")

            self.db = self.db_client.get_or_create_collection(
                collection_name, embedding_function=ef
            )

            log.info("chromadb", status="instructor db ready")
        else:
            log.info("chromadb", status="using default embeddings")
            self.db = self.db_client.get_or_create_collection(collection_name)

        self.scene._memory_never_persisted = self.db.count() == 0
        log.info("chromadb agent", status="db ready")
        self._ready_to_add = True
@@ -459,17 +520,21 @@ class ChromaDBMemoryAgent(MemoryAgent):
    def clear_db(self):
        if not self.db:
            return

        log.info("chromadb agent", status="clearing db", collection_name=self.collection_name)
        log.info(
            "chromadb agent", status="clearing db", collection_name=self.collection_name
        )

        self.db.delete(where={"source": "talemate"})

    def drop_db(self):
        if not self.db:
            return

        log.info("chromadb agent", status="dropping db", collection_name=self.collection_name)
        log.info(
            "chromadb agent", status="dropping db", collection_name=self.collection_name
        )

        try:
            self.db_client.delete_collection(self.collection_name)
        except ValueError as exc:
@@ -479,31 +544,43 @@ class ChromaDBMemoryAgent(MemoryAgent):
    def close_db(self, scene):
        if not self.db:
            return

        log.info("chromadb agent", status="closing db", collection_name=self.collection_name)
        log.info(
            "chromadb agent", status="closing db", collection_name=self.collection_name
        )

        if not scene.saved and not scene.saved_memory_session_id:
            # scene was never saved so we can discard the memory
            collection_name = self.make_collection_name(scene)
            log.info("chromadb agent", status="discarding memory", collection_name=collection_name)
            log.info(
                "chromadb agent",
                status="discarding memory",
                collection_name=collection_name,
            )
            try:
                self.db_client.delete_collection(collection_name)
            except ValueError as exc:
                log.error("chromadb agent", error="failed to delete collection", details=exc)
                log.error(
                    "chromadb agent", error="failed to delete collection", details=exc
                )
        elif not scene.saved:
            # scene was saved but memory was never persisted
            # so we need to remove the memory from the db
            self._remove_unsaved_memory()

        self.db = None

    def _add(self, text, character=None, uid=None, ts:str=None, **kwargs):
    def _add(self, text, character=None, uid=None, ts: str = None, **kwargs):
        metadatas = []
        ids = []
        scene = self.scene

        if character:
            meta = {"character": character.name, "source": "talemate", "session": scene.memory_session_id}
            meta = {
                "character": character.name,
                "source": "talemate",
                "session": scene.memory_session_id,
            }
            if ts:
                meta["ts"] = ts
            meta.update(kwargs)
@@ -513,7 +590,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
            id = uid or f"{character.name}-{self.memory_tracker[character.name]}"
            ids = [id]
        else:
            meta = {"character": "__narrator__", "source": "talemate", "session": scene.memory_session_id}
            meta = {
                "character": "__narrator__",
                "source": "talemate",
                "session": scene.memory_session_id,
            }
            if ts:
                meta["ts"] = ts
            meta.update(kwargs)
@@ -523,17 +604,16 @@ class ChromaDBMemoryAgent(MemoryAgent):
            id = uid or f"__narrator__-{self.memory_tracker['__narrator__']}"
            ids = [id]

        #log.debug("chromadb agent add", text=text, meta=meta, id=id)
        # log.debug("chromadb agent add", text=text, meta=meta, id=id)

        self.db.upsert(documents=[text], metadatas=metadatas, ids=ids)

    def _add_many(self, objects: list[dict]):
        documents = []
        metadatas = []
        ids = []
        scene = self.scene

        if not objects:
            return

@@ -552,52 +632,50 @@ class ChromaDBMemoryAgent(MemoryAgent):
            ids.append(uid)
        self.db.upsert(documents=documents, metadatas=metadatas, ids=ids)

    def _delete(self, meta:dict):
    def _delete(self, meta: dict):
        if "ids" in meta:
            log.debug("chromadb agent delete", ids=meta["ids"])
            self.db.delete(ids=meta["ids"])
            return

        where = {"$and": [{k:v} for k,v in meta.items()]}
        where = {"$and": [{k: v} for k, v in meta.items()]}
        self.db.delete(where=where)
        log.debug("chromadb agent delete", meta=meta, where=where)

    def _get(self, text, character=None, limit:int=15, **kwargs):
    def _get(self, text, character=None, limit: int = 15, **kwargs):
        where = {}

        # this doesn't work because chromadb currently doesn't match
        # non existing fields with $ne (or so it seems)
        # where.setdefault("$and", [{"pin_only": {"$ne": True}}])

        where.setdefault("$and", [])

        character_filtered = False

        for k,v in kwargs.items():
        for k, v in kwargs.items():
            if k == "character":
                character_filtered = True
            where["$and"].append({k: v})

        if character and not character_filtered:
            where["$and"].append({"character": character.name})

        if len(where["$and"]) == 1:
            where = where["$and"][0]
        elif not where["$and"]:
            where = None

        #log.debug("crhomadb agent get", text=text, where=where)
        # log.debug("crhomadb agent get", text=text, where=where)

        _results = self.db.query(query_texts=[text], where=where, n_results=limit)

        #import json
        #print(json.dumps(_results["ids"], indent=2))
        #print(json.dumps(_results["distances"], indent=2))

        # import json
        # print(json.dumps(_results["ids"], indent=2))
        # print(json.dumps(_results["distances"], indent=2))

        results = []

        max_distance = 1.5
        if self.USE_INSTRUCTOR:
            max_distance = 1
@@ -606,24 +684,24 @@ class ChromaDBMemoryAgent(MemoryAgent):

        for i in range(len(_results["distances"][0])):
            distance = _results["distances"][0][i]

            doc = _results["documents"][0][i]
            meta = _results["metadatas"][0][i]
            ts = meta.get("ts")

            # skip pin_only entries
            if meta.get("pin_only", False):
                continue

            if distance < max_distance:
                date_prefix = self.convert_ts_to_date_prefix(ts)
                raw = doc

                if date_prefix:
                    doc = f"{date_prefix}: {doc}"

                doc = MemoryDocument(doc, meta, _results["ids"][0][i], raw)

                results.append(doc)
            else:
                break
@@ -635,45 +713,55 @@ class ChromaDBMemoryAgent(MemoryAgent):

        return results

    def convert_ts_to_date_prefix(self, ts):
        if not ts:
            return None
        try:
            return util.iso8601_diff_to_human(ts, self.scene.ts)
        except Exception as e:
            log.error("chromadb agent", error="failed to get date prefix", details=e, ts=ts, scene_ts=self.scene.ts)
            log.error(
                "chromadb agent",
                error="failed to get date prefix",
                details=e,
                ts=ts,
                scene_ts=self.scene.ts,
            )
            return None

    def _get_document(self, id) -> dict:
        result = self.db.get(ids=[id] if isinstance(id, str) else id)
        documents = {}

        for idx, doc in enumerate(result["documents"]):
            date_prefix = self.convert_ts_to_date_prefix(result["metadatas"][idx].get("ts"))
            date_prefix = self.convert_ts_to_date_prefix(
                result["metadatas"][idx].get("ts")
            )
            if date_prefix:
                doc = f"{date_prefix}: {doc}"
            documents[result["ids"][idx]] = MemoryDocument(doc, result["metadatas"][idx], result["ids"][idx], doc)
            documents[result["ids"][idx]] = MemoryDocument(
                doc, result["metadatas"][idx], result["ids"][idx], doc
            )

        return documents

    @set_processing
    async def remove_unsaved_memory(self):
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._remove_unsaved_memory)

    def _remove_unsaved_memory(self):
        scene = self.scene

        if not scene.memory_session_id:
            return

        if scene.saved_memory_session_id == self.scene.memory_session_id:
            return

        log.info("chromadb agent", status="removing unsaved memory", session_id=scene.memory_session_id)
        log.info(
            "chromadb agent",
            status="removing unsaved memory",
            session_id=scene.memory_session_id,
        )

        self._delete({"session": scene.memory_session_id, "source": "talemate"})

@@ -1,41 +1,44 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, List, Optional, Union
import dataclasses
import structlog
import random
import talemate.util as util
from talemate.emit import emit
import talemate.emit.async_signals
from talemate.prompts import Prompt
from talemate.agents.base import set_processing as _set_processing, Agent, AgentAction, AgentActionConfig, AgentEmission
from talemate.agents.world_state import TimePassageEmission
from talemate.scene_message import NarratorMessage
from talemate.events import GameLoopActorIterEvent
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import structlog

import talemate.client as client
import talemate.emit.async_signals
import talemate.util as util
from talemate.agents.base import Agent, AgentAction, AgentActionConfig, AgentEmission
from talemate.agents.base import set_processing as _set_processing
from talemate.agents.world_state import TimePassageEmission
from talemate.emit import emit
from talemate.events import GameLoopActorIterEvent
from talemate.prompts import Prompt
from talemate.scene_message import NarratorMessage

from .registry import register

if TYPE_CHECKING:
    from talemate.tale_mate import Actor, Player, Character
    from talemate.tale_mate import Actor, Character, Player

log = structlog.get_logger("talemate.agents.narrator")


@dataclasses.dataclass
class NarratorAgentEmission(AgentEmission):
    generation: list[str] = dataclasses.field(default_factory=list)

talemate.emit.async_signals.register(
    "agent.narrator.generated"
)

talemate.emit.async_signals.register("agent.narrator.generated")

def set_processing(fn):
    """
    Custom decorator that emits the agent status as processing while the function
    is running and then emits the result of the function as a NarratorAgentEmission
    """

    @_set_processing
    async def wrapper(self, *args, **kwargs):
        response = await fn(self, *args, **kwargs)
@@ -45,68 +48,70 @@ def set_processing(fn):
        )
        await talemate.emit.async_signals.get("agent.narrator.generated").send(emission)
        return emission.generation[0]

    wrapper.__name__ = fn.__name__
    return wrapper

@register()
class NarratorAgent(Agent):

    """
    Handles narration of the story
    """

    agent_type = "narrator"
    verbose_name = "Narrator"

    def __init__(
        self,
        client: client.TaleMateClient,
        **kwargs,
    ):
        self.client = client

        # agent actions

        self.actions = {
            "generation_override": AgentAction(
                enabled = True,
                label = "Generation Override",
                description = "Override generation parameters",
                config = {
                enabled=True,
                label="Generation Override",
                description="Override generation parameters",
                config={
                    "instructions": AgentActionConfig(
                        type="text",
                        label="Instructions",
                        value="Never wax poetic.",
                        description="Extra instructions to give to the AI for narrative generation.",
                    ),
                }
                },
            ),
            "auto_break_repetition": AgentAction(
                enabled = True,
                label = "Auto Break Repetition",
                description = "Will attempt to automatically break AI repetition.",
                enabled=True,
                label="Auto Break Repetition",
                description="Will attempt to automatically break AI repetition.",
            ),
            "narrate_time_passage": AgentAction(
                enabled=True,
                label="Narrate Time Passage",
                enabled=True,
                label="Narrate Time Passage",
                description="Whenever you indicate passage of time, narrate right after",
                config = {
                config={
                    "ask_for_prompt": AgentActionConfig(
                        type="bool",
                        label="Guide time narration via prompt",
                        label="Guide time narration via prompt",
                        description="Ask the user for a prompt to generate the time passage narration",
                        value=True,
                    )
                }
                },
            ),
            "narrate_dialogue": AgentAction(
                enabled=False,
                label="Narrate after Dialogue",
                enabled=False,
                label="Narrate after Dialogue",
                description="Narrator will get a chance to narrate after every line of dialogue",
                config = {
                config={
                    "ai_dialog": AgentActionConfig(
                        type="number",
                        label="AI Dialogue",
                        label="AI Dialogue",
                        description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
                        value=0.0,
                        min=0.0,
@@ -115,7 +120,7 @@ class NarratorAgent(Agent):
                    ),
                    "player_dialog": AgentActionConfig(
                        type="number",
                        label="Player Dialogue",
                        label="Player Dialogue",
                        description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
                        value=0.1,
                        min=0.0,
@@ -124,34 +129,32 @@ class NarratorAgent(Agent):
                    ),
                    "generate_dialogue": AgentActionConfig(
                        type="bool",
                        label="Allow Dialogue in Narration",
                        label="Allow Dialogue in Narration",
                        description="Allow the narrator to generate dialogue in narration",
                        value=False,
                    ),
                }
                },
            ),
        }

    @property
    def extra_instructions(self):
        if self.actions["generation_override"].enabled:
            return self.actions["generation_override"].config["instructions"].value
        return ""

    def clean_result(self, result):
        """
        Cleans the result of a narration
        """

        result = result.strip().strip(":").strip()

        if "#" in result:
            result = result.split("#")[0]

        character_names = [c.name for c in self.scene.get_characters()]

        cleaned = []
        for line in result.split("\n"):
            for character_name in character_names:
@@ -160,71 +163,83 @@ class NarratorAgent(Agent):
                cleaned.append(line)

        result = "\n".join(cleaned)
        #result = util.strip_partial_sentences(result)
        # result = util.strip_partial_sentences(result)
        return result

    def connect(self, scene):
        """
        Connect to signals
        """

        super().connect(scene)
        talemate.emit.async_signals.get("agent.world_state.time").connect(self.on_time_passage)
        talemate.emit.async_signals.get("agent.world_state.time").connect(
            self.on_time_passage
        )
        talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.on_dialog)

    async def on_time_passage(self, event:TimePassageEmission):
    async def on_time_passage(self, event: TimePassageEmission):
        """
        Handles time passage narration, if enabled
        """

        if not self.actions["narrate_time_passage"].enabled:
            return

        response = await self.narrate_time_passage(event.duration, event.human_duration, event.narrative)
        narrator_message = NarratorMessage(response, source=f"narrate_time_passage:{event.duration};{event.narrative}")
        response = await self.narrate_time_passage(
            event.duration, event.human_duration, event.narrative
        )
        narrator_message = NarratorMessage(
            response, source=f"narrate_time_passage:{event.duration};{event.narrative}"
        )
        emit("narrator", narrator_message)
        self.scene.push_history(narrator_message)

    async def on_dialog(self, event:GameLoopActorIterEvent):
    async def on_dialog(self, event: GameLoopActorIterEvent):
        """
        Handles dialogue narration, if enabled
        """

        if not self.actions["narrate_dialogue"].enabled:
            return

        if event.game_loop.had_passive_narration:
            log.debug("narrate on dialog", skip=True, had_passive_narration=event.game_loop.had_passive_narration)
            log.debug(
                "narrate on dialog",
                skip=True,
                had_passive_narration=event.game_loop.had_passive_narration,
            )
            return

        narrate_on_ai_chance = self.actions["narrate_dialogue"].config["ai_dialog"].value
        narrate_on_player_chance = self.actions["narrate_dialogue"].config["player_dialog"].value
        narrate_on_ai_chance = (
            self.actions["narrate_dialogue"].config["ai_dialog"].value
        )
        narrate_on_player_chance = (
            self.actions["narrate_dialogue"].config["player_dialog"].value
        )
        narrate_on_ai = random.random() < narrate_on_ai_chance
        narrate_on_player = random.random() < narrate_on_player_chance

        log.debug(
            "narrate on dialog",
            narrate_on_ai=narrate_on_ai,
            narrate_on_ai_chance=narrate_on_ai_chance,
            "narrate on dialog",
            narrate_on_ai=narrate_on_ai,
            narrate_on_ai_chance=narrate_on_ai_chance,
            narrate_on_player=narrate_on_player,
            narrate_on_player_chance=narrate_on_player_chance,
        )

        if event.actor.character.is_player and not narrate_on_player:
            return

        if not event.actor.character.is_player and not narrate_on_ai:
            return

        response = await self.narrate_after_dialogue(event.actor.character)
        narrator_message = NarratorMessage(response, source=f"narrate_dialogue:{event.actor.character.name}")
        narrator_message = NarratorMessage(
            response, source=f"narrate_dialogue:{event.actor.character.name}"
        )
        emit("narrator", narrator_message)
        self.scene.push_history(narrator_message)

        event.game_loop.had_passive_narration = True

    @set_processing
@@ -237,22 +252,22 @@ class NarratorAgent(Agent):
            "narrator.narrate-scene",
            self.client,
            "narrate",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "extra_instructions": self.extra_instructions,
            }
            },
        )

        response = response.strip("*")
        response = util.strip_partial_sentences(response)

        response = f"*{response.strip('*')}*"

        return response

    @set_processing
    async def progress_story(self, narrative_direction:str=None):
    async def progress_story(self, narrative_direction: str = None):
        """
        Narrate the scene
        """
@@ -260,18 +275,20 @@ class NarratorAgent(Agent):
        scene = self.scene
        pc = scene.get_player_character()
        npcs = list(scene.get_npc_characters())
        npc_names= ", ".join([npc.name for npc in npcs])
        npc_names = ", ".join([npc.name for npc in npcs])

        if narrative_direction is None:
            narrative_direction = "Slightly move the current scene forward."

        self.scene.log.info("narrative_direction", narrative_direction=narrative_direction)
        self.scene.log.info(
            "narrative_direction", narrative_direction=narrative_direction
        )

        response = await Prompt.request(
            "narrator.narrate-progress",
            self.client,
            "narrate",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "narrative_direction": narrative_direction,
@@ -279,7 +296,7 @@ class NarratorAgent(Agent):
                "npcs": npcs,
                "npc_names": npc_names,
                "extra_instructions": self.extra_instructions,
            }
            },
        )

        self.scene.log.info("progress_story", response=response)
@@ -291,11 +308,13 @@ class NarratorAgent(Agent):
        if response.count("*") % 2 != 0:
            response = response.replace("*", "")
            response = f"*{response}*"

        return response

    @set_processing
    async def narrate_query(self, query:str, at_the_end:bool=False, as_narrative:bool=True):
    async def narrate_query(
        self, query: str, at_the_end: bool = False, as_narrative: bool = True
    ):
        """
        Narrate a specific query
        """
@@ -303,21 +322,21 @@ class NarratorAgent(Agent):
            "narrator.narrate-query",
            self.client,
            "narrate",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "query": query,
                "at_the_end": at_the_end,
                "as_narrative": as_narrative,
                "extra_instructions": self.extra_instructions,
            }
            },
        )
        log.info("narrate_query", response=response)
        response = self.clean_result(response.strip())
        log.info("narrate_query (after clean)", response=response)
        if as_narrative:
            response = f"*{response}*"

        return response

    @set_processing
@@ -330,12 +349,12 @@ class NarratorAgent(Agent):
            "narrator.narrate-character",
            self.client,
            "narrate",
            vars = {
            vars={
                "scene": self.scene,
                "character": character,
                "max_tokens": self.client.max_token_length,
                "extra_instructions": self.extra_instructions,
            }
            },
        )

        response = self.clean_result(response.strip())
@@ -345,54 +364,55 @@ class NarratorAgent(Agent):

    @set_processing
    async def augment_context(self):
        """
        Takes a context history generated via scene.context_history() and augments it with additional information
        by asking and answering questions with help from the long term memory.
        """
        memory = self.scene.get_helper("memory").agent

        questions = await Prompt.request(
            "narrator.context-questions",
            self.client,
            "narrate",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "extra_instructions": self.extra_instructions,
            }
            },
        )

        self.scene.log.info("context_questions", questions=questions)

        questions = [q for q in questions.split("\n") if q.strip()]

        memory_context = await memory.multi_query(
            questions, iterate=2, max_tokens=self.client.max_token_length - 1000
        )

        answers = await Prompt.request(
            "narrator.context-answers",
            self.client,
            "narrate",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "memory": memory_context,
                "questions": questions,
                "extra_instructions": self.extra_instructions,
            }
            },
        )

        self.scene.log.info("context_answers", answers=answers)

        answers = [a for a in answers.split("\n") if a.strip()]

        # return questions and answers
        return list(zip(questions, answers))

    @set_processing
    async def narrate_time_passage(self, duration:str, time_passed:str, narrative:str):
    async def narrate_time_passage(
        self, duration: str, time_passed: str, narrative: str
    ):
        """
        Narrate a specific character
        """
@@ -401,26 +421,25 @@ class NarratorAgent(Agent):
            "narrator.narrate-time-passage",
            self.client,
            "narrate",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "duration": duration,
                "time_passed": time_passed,
                "narrative": narrative,
                "extra_instructions": self.extra_instructions,
            }
            },
        )

        log.info("narrate_time_passage", response=response)

        response = self.clean_result(response.strip())
        response = f"*{response}*"

        return response

    @set_processing
    async def narrate_after_dialogue(self, character:Character):
    async def narrate_after_dialogue(self, character: Character):
        """
        Narrate after a line of dialogue
        """
@@ -429,22 +448,24 @@ class NarratorAgent(Agent):
            "narrator.narrate-after-dialogue",
            self.client,
            "narrate",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "character": character,
                "last_line": str(self.scene.history[-1]),
                "extra_instructions": self.extra_instructions,
            }
            },
        )

        log.info("narrate_after_dialogue", response=response)

        response = self.clean_result(response.strip().strip("*"))
        response = f"*{response}*"

        allow_dialogue = self.actions["narrate_dialogue"].config["generate_dialogue"].value
        allow_dialogue = (
            self.actions["narrate_dialogue"].config["generate_dialogue"].value
        )

        if not allow_dialogue:
            response = response.split('"')[0].strip()
            response = response.replace("*", "")
@@ -452,9 +473,11 @@ class NarratorAgent(Agent):
            response = f"*{response}*"

        return response

@set_processing
|
||||
async def narrate_character_entry(self, character:Character, direction:str=None):
|
||||
async def narrate_character_entry(
|
||||
self, character: Character, direction: str = None
|
||||
):
|
||||
"""
|
||||
Narrate a character entering the scene
|
||||
"""
|
||||
@@ -463,22 +486,22 @@ class NarratorAgent(Agent):
|
||||
"narrator.narrate-character-entry",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"character": character,
|
||||
"direction": direction,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
response = self.clean_result(response.strip().strip("*"))
|
||||
response = f"*{response}*"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@set_processing
|
||||
async def narrate_character_exit(self, character:Character, direction:str=None):
|
||||
async def narrate_character_exit(self, character: Character, direction: str = None):
|
||||
"""
|
||||
Narrate a character exiting the scene
|
||||
"""
|
||||
@@ -487,20 +510,20 @@ class NarratorAgent(Agent):
|
||||
"narrator.narrate-character-exit",
|
||||
self.client,
|
||||
"narrate",
|
||||
vars = {
|
||||
vars={
|
||||
"scene": self.scene,
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"character": character,
|
||||
"direction": direction,
|
||||
"extra_instructions": self.extra_instructions,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
response = self.clean_result(response.strip().strip("*"))
|
||||
response = f"*{response}*"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def action_to_narration(
|
||||
self,
|
||||
action_name: str,
|
||||
@@ -509,25 +532,35 @@ class NarratorAgent(Agent):
|
||||
):
|
||||
# calls self[action_name] and returns the result as a NarratorMessage
|
||||
# that is pushed to the history
|
||||
|
||||
|
||||
fn = getattr(self, action_name)
|
||||
narration = await fn(*args, **kwargs)
|
||||
narrator_message = NarratorMessage(narration, source=f"{action_name}:{args[0] if args else ''}".rstrip(":"))
|
||||
narrator_message = NarratorMessage(
|
||||
narration, source=f"{action_name}:{args[0] if args else ''}".rstrip(":")
|
||||
)
|
||||
self.scene.push_history(narrator_message)
|
||||
return narrator_message
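The source tag that `action_to_narration` attaches to the history message can be sketched in isolation. `narration_source` below is a hypothetical stand-alone version of the f-string above, not a talemate function:

```python
# A stand-alone sketch of the source-tag expression in action_to_narration.
# `narration_source` is a hypothetical helper, not part of talemate itself.
def narration_source(action_name: str, *args) -> str:
    # With args: "action:first_arg"; without args the trailing ":" is stripped.
    return f"{action_name}:{args[0] if args else ''}".rstrip(":")

print(narration_source("narrate_character_entry", "Alice"))  # narrate_character_entry:Alice
print(narration_source("narrate_scene"))  # narrate_scene
```

The `rstrip(":")` keeps a bare action name like `narrate_scene` from ending up as `narrate_scene:` in the message source.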

    # LLM client related methods. These are called during or after the client

-    def inject_prompt_paramters(self, prompt_param: dict, kind: str, agent_function_name: str):
-        log.debug("inject_prompt_paramters", prompt_param=prompt_param, kind=kind, agent_function_name=agent_function_name)
+    def inject_prompt_paramters(
+        self, prompt_param: dict, kind: str, agent_function_name: str
+    ):
+        log.debug(
+            "inject_prompt_paramters",
+            prompt_param=prompt_param,
+            kind=kind,
+            agent_function_name=agent_function_name,
+        )
        character_names = [f"\n{c.name}:" for c in self.scene.get_characters()]
        if prompt_param.get("extra_stopping_strings") is None:
            prompt_param["extra_stopping_strings"] = []
        prompt_param["extra_stopping_strings"] += character_names

-    def allow_repetition_break(self, kind: str, agent_function_name: str, auto:bool=False):
+    def allow_repetition_break(
+        self, kind: str, agent_function_name: str, auto: bool = False
+    ):
        if auto and not self.actions["auto_break_repetition"].enabled:
            return False

        return True
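The stopping-string injection above appends one `"\nName:"` string per scene character so the client stops generating as soon as the model starts speaking for another character. A minimal self-contained sketch, with `Character` as a stand-in for talemate's own class:

```python
# Minimal sketch of the stopping-string injection shown in the diff above.
# `Character` and `inject_stopping_strings` are stand-ins, not talemate APIs.
from dataclasses import dataclass


@dataclass
class Character:
    name: str


def inject_stopping_strings(prompt_param: dict, characters: list[Character]) -> None:
    # Each character name becomes a stopping string so generation halts
    # when the model begins a line of dialogue for someone else.
    names = [f"\n{c.name}:" for c in characters]
    if prompt_param.get("extra_stopping_strings") is None:
        prompt_param["extra_stopping_strings"] = []
    prompt_param["extra_stopping_strings"] += names


params: dict = {}
inject_stopping_strings(params, [Character("Alice"), Character("Bob")])
print(params["extra_stopping_strings"])
```

Appending rather than overwriting preserves any stopping strings the caller already set on `extra_stopping_strings`.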

@@ -1,26 +1,26 @@
 from __future__ import annotations

+import asyncio
+import re
+import time
+import traceback
+from typing import TYPE_CHECKING, Callable, List, Optional, Union
+
+import structlog
+
 import talemate.data_objects as data_objects
 import talemate.emit.async_signals
 import talemate.util as util
+from talemate.events import GameLoopEvent
 from talemate.prompts import Prompt
 from talemate.scene_message import DirectorMessage, TimePassageMessage
-from talemate.events import GameLoopEvent

-from .base import Agent, set_processing, AgentAction, AgentActionConfig
+from .base import Agent, AgentAction, AgentActionConfig, set_processing
 from .registry import register

-import structlog
-
-import time
-import re
-
 log = structlog.get_logger("talemate.agents.summarize")


 @register()
 class SummarizeAgent(Agent):
     """
@@ -36,7 +36,7 @@ class SummarizeAgent(Agent):

     def __init__(self, client, **kwargs):
         self.client = client

         self.actions = {
             "archive": AgentAction(
                 enabled=True,
@@ -67,30 +67,27 @@ class SummarizeAgent(Agent):
                     type="number",
                     label="Use preceding summaries to strengthen context",
                     description="Number of entries",
-                    note="Help the AI summarize by including the last few summaries as additional context. Some models may incorporate this context into the new summary directly, so if you find yourself with a bunch of similar history entries, try setting this to 0.",
+                    note="Help the AI summarize by including the last few summaries as additional context. Some models may incorporate this context into the new summary directly, so if you find yourself with a bunch of similar history entries, try setting this to 0.",
                     value=3,
                     min=0,
                     max=10,
                     step=1,
                 ),
-            }
+            },
         )
     }

     def connect(self, scene):
         super().connect(scene)
         talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)

-    async def on_game_loop(self, emission:GameLoopEvent):
+    async def on_game_loop(self, emission: GameLoopEvent):
         """
         Called when a conversation is generated
         """

         await self.build_archive(self.scene)

     def clean_result(self, result):
         if "#" in result:
             result = result.split("#")[0]
@@ -104,10 +101,10 @@ class SummarizeAgent(Agent):
     @set_processing
     async def build_archive(self, scene):
         end = None

         if not self.actions["archive"].enabled:
             return

         if not scene.archived_history:
             start = 0
             recent_entry = None
@@ -118,14 +115,16 @@ class SummarizeAgent(Agent):
             # meaning we are still at the beginning of the scene
             start = 0
         else:
-            start = recent_entry.get("end", 0)+1
+            start = recent_entry.get("end", 0) + 1

         # if there is a recent entry we also collect the 3 most recent entries
         # as extra context

         num_previous = self.actions["archive"].config["include_previous"].value
         if recent_entry and num_previous > 0:
-            extra_context = "\n\n".join([entry["text"] for entry in scene.archived_history[-num_previous:]])
+            extra_context = "\n\n".join(
+                [entry["text"] for entry in scene.archived_history[-num_previous:]]
+            )
         else:
             extra_context = None

@@ -133,36 +132,42 @@ class SummarizeAgent(Agent):
         dialogue_entries = []
         ts = "PT0S"
         time_passage_termination = False

         token_threshold = self.actions["archive"].config["threshold"].value

         log.debug("build_archive", start=start, recent_entry=recent_entry)

         if recent_entry:
             ts = recent_entry.get("ts", ts)

         for i in range(start, len(scene.history)):
             dialogue = scene.history[i]

-            #log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")
+            # log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")

             if isinstance(dialogue, DirectorMessage):
                 if i == start:
                     start += 1
                 continue

             if isinstance(dialogue, TimePassageMessage):
                 log.debug("build_archive", time_passage_message=dialogue)
                 if i == start:
                     ts = util.iso8601_add(ts, dialogue.ts)
-                    log.debug("build_archive", time_passage_message=dialogue, start=start, i=i, ts=ts)
+                    log.debug(
+                        "build_archive",
+                        time_passage_message=dialogue,
+                        start=start,
+                        i=i,
+                        ts=ts,
+                    )
                     start += 1
                     continue
                 log.debug("build_archive", time_passage_message_termination=dialogue)
                 time_passage_termination = True
                 end = i - 1
                 break

             tokens += util.count_tokens(dialogue)
             dialogue_entries.append(dialogue)
             if tokens > token_threshold:  #
@@ -172,39 +177,44 @@ class SummarizeAgent(Agent):
         if end is None:
             # nothing to archive yet
             return

-        log.debug("build_archive", start=start, end=end, ts=ts, time_passage_termination=time_passage_termination)
+        log.debug(
+            "build_archive",
+            start=start,
+            end=end,
+            ts=ts,
+            time_passage_termination=time_passage_termination,
+        )

         # in order to summarize coherently, we need to determine if there is a favorable
         # cutoff point (e.g., the scene naturally ends or shifts meaningfully in the middle
         # of the dialogue)
         #
         # One way to do this is to check if the last line is a TimePassageMessage, which
-        # indicates a scene change or a significant pause.
-        #
+        # indicates a scene change or a significant pause.
+        #
         # If not, we can ask the AI to find a good point of
         # termination.

         if not time_passage_termination:
             # No TimePassageMessage, so we need to ask the AI to find a good point of termination

             terminating_line = await self.analyze_dialoge(dialogue_entries)

             if terminating_line:
                 adjusted_dialogue = []
-                for line in dialogue_entries:
+                for line in dialogue_entries:
                     if str(line) in terminating_line:
                         break
                     adjusted_dialogue.append(line)
                 dialogue_entries = adjusted_dialogue
-                end = start + len(dialogue_entries)-1
+                end = start + len(dialogue_entries) - 1

         if dialogue_entries:
             summarized = await self.summarize(
                 "\n".join(map(str, dialogue_entries)), extra_context=extra_context
             )

         else:
             # AI has likely identified the first line as a scene change, so we can't summarize
             # just use the first line
@@ -218,15 +228,20 @@ class SummarizeAgent(Agent):

     @set_processing
     async def analyze_dialoge(self, dialogue):
-        response = await Prompt.request("summarizer.analyze-dialogue", self.client, "analyze_freeform", vars={
-            "dialogue": "\n".join(map(str, dialogue)),
-            "scene": self.scene,
-            "max_tokens": self.client.max_token_length,
-        })
+        response = await Prompt.request(
+            "summarizer.analyze-dialogue",
+            self.client,
+            "analyze_freeform",
+            vars={
+                "dialogue": "\n".join(map(str, dialogue)),
+                "scene": self.scene,
+                "max_tokens": self.client.max_token_length,
+            },
+        )

         response = self.clean_result(response)
         return response

     @set_processing
     async def summarize(
         self,
@@ -239,33 +254,40 @@ class SummarizeAgent(Agent):
         Summarize the given text
         """

-        response = await Prompt.request("summarizer.summarize-dialogue", self.client, "summarize", vars={
-            "dialogue": text,
-            "scene": self.scene,
-            "max_tokens": self.client.max_token_length,
-            "summarization_method": self.actions["archive"].config["method"].value if method is None else method,
-            "extra_context": extra_context or "",
-            "extra_instructions": extra_instructions or "",
-        })
-
-        self.scene.log.info("summarize", dialogue_length=len(text), summarized_length=len(response))
+        response = await Prompt.request(
+            "summarizer.summarize-dialogue",
+            self.client,
+            "summarize",
+            vars={
+                "dialogue": text,
+                "scene": self.scene,
+                "max_tokens": self.client.max_token_length,
+                "summarization_method": self.actions["archive"].config["method"].value
+                if method is None
+                else method,
+                "extra_context": extra_context or "",
+                "extra_instructions": extra_instructions or "",
+            },
+        )

+        self.scene.log.info(
+            "summarize", dialogue_length=len(text), summarized_length=len(response)
+        )

         return self.clean_result(response)

-    async def build_stepped_archive_for_level(self, level:int):
+    async def build_stepped_archive_for_level(self, level: int):
         """
         WIP - not yet used

         This will iterate over existing archived_history entries
         and stepped_archived_history entries and summarize based on the time duration
         indicated between the entries.

         The lowest level of summarization (based on token threshold and any time passage)
         happens in build_archive. This method is for summarizing further levels based on
         long time passages.

         Level 0: small timestep summarize (summarizes all token summarizations when time advances +1 day)
         Level 1: medium timestep summarize (summarizes all small timestep summarizations when time advances +1 week)
         Level 2: large timestep summarize (summarizes all medium timestep summarizations when time advances +1 month)
@@ -273,7 +295,7 @@ class SummarizeAgent(Agent):
         Level 4: massive timestep summarize (summarizes all huge timestep summarizations when time advances +10 years)
         Level 5: epic timestep summarize (summarizes all massive timestep summarizations when time advances +100 years)
         and so on (increasing by a factor of 10 each time)

         ```
         @dataclass
         class ArchiveEntry:
@@ -282,35 +304,34 @@ class SummarizeAgent(Agent):
             end: int = None
             ts: str = None
         ```

         Like token summarization this will use ArchiveEntry, and start and end will refer to the entries in the
         lower level of summarization.

         Ts is the iso8601 timestamp of the start of the summarized period.
         """

         # select the list to use for the entries

         if level == 0:
             entries = self.scene.archived_history
         else:
-            entries = self.scene.stepped_archived_history[level-1]
+            entries = self.scene.stepped_archived_history[level - 1]

         # select the list to summarize new entries to

         target = self.scene.stepped_archived_history[level]

         if not target:
             raise ValueError(f"Invalid level {level}")

         # determine the start and end of the period to summarize

         if not entries:
             return

         # determine the time threshold for this level

         # first calculate all possible thresholds in iso8601 format, starting with 1 day
         thresholds = [
             "P1D",
@@ -318,61 +339,65 @@ class SummarizeAgent(Agent):
             "P1M",
             "P1Y",
         ]

         # TODO: auto extend?

         time_threshold_in_seconds = util.iso8601_to_seconds(thresholds[level])

         if not time_threshold_in_seconds:
             raise ValueError(f"Invalid level {level}")

         # determine the most recent summarized entry time, and then find entries
         # that are newer than that in the lower list

         ts = target[-1].ts if target else entries[0].ts

         # determine the most recent entry at the lower level, if its not newer or
         # the difference is less than the threshold, then we don't need to summarize

         recent_entry = entries[-1]

         if util.iso8601_diff(recent_entry.ts, ts) < time_threshold_in_seconds:
             return

         log.debug("build_stepped_archive", level=level, ts=ts)

         # if target is empty, start is 0
         # otherwise start is the end of the last entry

         start = 0 if not target else target[-1].end

         # collect entries starting at start until the combined time duration
         # exceeds the threshold

         entries_to_summarize = []

         for entry in entries[start:]:
             entries_to_summarize.append(entry)
             if util.iso8601_diff(entry.ts, ts) > time_threshold_in_seconds:
                 break

         # summarize the entries
         # we also collect N entries of previous summaries to use as context

         num_previous = self.actions["archive"].config["include_previous"].value
         if num_previous > 0:
-            extra_context = "\n\n".join([entry["text"] for entry in target[-num_previous:]])
+            extra_context = "\n\n".join(
+                [entry["text"] for entry in target[-num_previous:]]
+            )
         else:
             extra_context = None

         summarized = await self.summarize(
             "\n".join(map(str, entries_to_summarize)), extra_context=extra_context
         )

         # push summarized entry to target

         ts = entries_to_summarize[-1].ts

-        target.append(data_objects.ArchiveEntry(summarized, start, len(entries_to_summarize)-1, ts=ts))
+        target.append(
+            data_objects.ArchiveEntry(
+                summarized, start, len(entries_to_summarize) - 1, ts=ts
+            )
+        )
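The level-to-threshold mapping described in the docstring can be sketched with a tiny stand-in for `util.iso8601_to_seconds`. The second values below are assumptions (talemate's own conversion may treat months and years differently); only the four duration codes from the diff are covered:

```python
# A rough sketch of the level-to-threshold logic in build_stepped_archive_for_level.
# _SECONDS is a hypothetical stand-in for util.iso8601_to_seconds and assumes
# a 30-day month and a 365-day year.
THRESHOLDS = ["P1D", "P1W", "P1M", "P1Y"]

_SECONDS = {"P1D": 86400, "P1W": 7 * 86400, "P1M": 30 * 86400, "P1Y": 365 * 86400}


def threshold_seconds(level: int) -> int:
    # Mirrors the ValueError raised for levels without a configured threshold.
    try:
        return _SECONDS[THRESHOLDS[level]]
    except IndexError:
        raise ValueError(f"Invalid level {level}")


def needs_summarization(elapsed_seconds: int, level: int) -> bool:
    # A level is summarized once the elapsed time reaches its threshold.
    return elapsed_seconds >= threshold_seconds(level)


print(needs_summarization(2 * 86400, 0))  # two days elapsed at level 0
print(needs_summarization(2 * 86400, 1))  # under a week at level 1
```

Two elapsed days trip the level-0 (one-day) threshold but not the level-1 (one-week) threshold, which is how a long time passage cascades up the levels one at a time.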

@@ -1,16 +1,19 @@
 from __future__ import annotations

-from typing import Union
 import asyncio
-import httpx
+import base64
+import functools
 import io
 import os
-import pydantic
-import nltk
 import tempfile
-import base64
 import time
 import uuid
-import functools
+from typing import Union

+import httpx
+import nltk
+import pydantic
 import structlog
 from nltk.tokenize import sent_tokenize

 import talemate.config as config
@@ -21,91 +24,84 @@ from talemate.emit.signals import handlers
 from talemate.events import GameLoopNewMessageEvent
 from talemate.scene_message import CharacterMessage, NarratorMessage

-from .base import Agent, set_processing, AgentAction, AgentActionConfig
+from .base import Agent, AgentAction, AgentActionConfig, set_processing
 from .registry import register

-import structlog
-
-import time
-
 try:
     from TTS.api import TTS
 except ImportError:
     TTS = None

-log = structlog.get_logger("talemate.agents.tts")#
+log = structlog.get_logger("talemate.agents.tts")  #

 if not TTS:
     # TTS installation is massive and requires a lot of dependencies
     # so we don't want to require it unless the user wants to use it
-    log.info("TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api")
+    log.info(
+        "TTS (local) requires the TTS package, please install with `pip install TTS` if you want to use the local api"
+    )


 def parse_chunks(text):
     text = text.replace("...", "__ellipsis__")

     chunks = sent_tokenize(text)
     cleaned_chunks = []

     for chunk in chunks:
-        chunk = chunk.replace("*","")
+        chunk = chunk.replace("*", "")
         if not chunk:
             continue
         cleaned_chunks.append(chunk)

     for i, chunk in enumerate(cleaned_chunks):
         chunk = chunk.replace("__ellipsis__", "...")
         cleaned_chunks[i] = chunk

     return cleaned_chunks

-def clean_quotes(chunk:str):
+
+def clean_quotes(chunk: str):
     # if there is an uneven number of quotes, remove the last one if its
     # at the end of the chunk. If its in the middle, add a quote to the end
     if chunk.count('"') % 2 == 1:
         if chunk.endswith('"'):
             chunk = chunk[:-1]
         else:
             chunk += '"'

-    return chunk
-
-def rejoin_chunks(chunks:list[str], chunk_size:int=250):
+    return chunk
+
+
+def rejoin_chunks(chunks: list[str], chunk_size: int = 250):
     """
     Will combine chunks split by punctuation into a single chunk until
     the max chunk size is reached
     """

     joined_chunks = []

     current_chunk = ""

     for chunk in chunks:
         if len(current_chunk) + len(chunk) > chunk_size:
             joined_chunks.append(clean_quotes(current_chunk))
             current_chunk = ""

         current_chunk += chunk

     if current_chunk:
         joined_chunks.append(clean_quotes(current_chunk))
     return joined_chunks
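The chunking helpers above can be exercised without nltk. The sketch below re-implements `clean_quotes` and `rejoin_chunks` as shown and substitutes a naive regex sentence splitter for `sent_tokenize`, so the splitter behavior is an assumption rather than what talemate ships:

```python
# Self-contained sketch of the chunk rejoin / quote-balancing helpers above.
# A naive regex splitter stands in for nltk's sent_tokenize.
import re


def clean_quotes(chunk: str) -> str:
    # Balance a dangling double quote left over from sentence splitting:
    # strip it if it ends the chunk, otherwise close the quotation.
    if chunk.count('"') % 2 == 1:
        if chunk.endswith('"'):
            chunk = chunk[:-1]
        else:
            chunk += '"'
    return chunk


def rejoin_chunks(chunks: list[str], chunk_size: int = 250) -> list[str]:
    # Combine sentence chunks until the size budget is exceeded.
    joined, current = [], ""
    for chunk in chunks:
        if len(current) + len(chunk) > chunk_size:
            joined.append(clean_quotes(current))
            current = ""
        current += chunk
    if current:
        joined.append(clean_quotes(current))
    return joined


sentences = re.split(r"(?<=[.!?]) ", 'He said "hello. How are you?')
print(rejoin_chunks(sentences, chunk_size=15))
```

With a 15-character budget the split quotation ends up in its own chunk, and `clean_quotes` closes the dangling `"` so the TTS backend never receives an unbalanced quote.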

 class Voice(pydantic.BaseModel):
-    value:str
-    label:str
+    value: str
+    label: str


 class VoiceLibrary(pydantic.BaseModel):
     api: str
     voices: list[Voice] = pydantic.Field(default_factory=list)
     last_synced: float = None
@@ -113,31 +109,30 @@ class VoiceLibrary(pydantic.BaseModel):

 @register()
 class TTSAgent(Agent):

     """
     Text to speech agent
     """

     agent_type = "tts"
     verbose_name = "Voice"
     requires_llm_client = False

     @classmethod
     def config_options(cls, agent=None):
         config_options = super().config_options(agent=agent)

         if agent:
             config_options["actions"]["_config"]["config"]["voice_id"]["choices"] = [
                 voice.model_dump() for voice in agent.list_voices_sync()
             ]

         return config_options

     def __init__(self, **kwargs):
         self.is_enabled = False
         nltk.download("punkt", quiet=True)

         self.voices = {
             "elevenlabs": VoiceLibrary(api="elevenlabs"),
             "coqui": VoiceLibrary(api="coqui"),
@@ -147,8 +142,8 @@ class TTSAgent(Agent):
         self.playback_done_event = asyncio.Event()
         self.actions = {
             "_config": AgentAction(
-                enabled=True,
-                label="Configure",
+                enabled=True,
+                label="Configure",
                 description="TTS agent configuration",
                 config={
                     "api": AgentActionConfig(
@@ -169,7 +164,7 @@ class TTSAgent(Agent):
                         value="default",
                         label="Narrator Voice",
                         description="Voice ID/Name to use for TTS",
-                        choices=[]
+                        choices=[],
                     ),
                     "generate_for_player": AgentActionConfig(
                         type="bool",
@@ -194,55 +189,54 @@ class TTSAgent(Agent):
                         value=False,
                         label="Split generation",
                         description="Generate audio chunks for each sentence - will be much more responsive but may lose context to inform inflection",
-                    )
-                }
-            ),
+                    ),
+                },
+            ),
         }

         self.actions["_config"].model_dump()
         handlers["config_saved"].connect(self.on_config_saved)

     @property
     def enabled(self):
         return self.is_enabled

     @property
     def has_toggle(self):
         return True

     @property
     def experimental(self):
         return False

     @property
     def not_ready_reason(self) -> str:
         """
         Returns a string explaining why the agent is not ready
         """

         if self.ready:
             return ""

         if self.api == "tts":
             if not TTS:
                 return "TTS not installed"

         elif self.requires_token and not self.token:
             return "No API token"

         elif not self.default_voice_id:
             return "No voice selected"

     @property
     def agent_details(self):
         suffix = ""

         if not self.ready:
             suffix = f" - {self.not_ready_reason}"
         else:
             suffix = f" - {self.voice_id_to_label(self.default_voice_id)}"

         api = self.api
         choices = self.actions["_config"].config["api"].choices
         api_label = api
@@ -250,34 +244,33 @@ class TTSAgent(Agent):
             if choice["value"] == api:
                 api_label = choice["label"]
                 break

         return f"{api_label}{suffix}"

     @property
     def api(self):
         return self.actions["_config"].config["api"].value

     @property
     def token(self):
         api = self.api
-        return self.config.get(api,{}).get("api_key")
+        return self.config.get(api, {}).get("api_key")

     @property
     def default_voice_id(self):
         return self.actions["_config"].config["voice_id"].value

     @property
     def requires_token(self):
         return self.api != "tts"

     @property
     def ready(self):
         if self.api == "tts":
             if not TTS:
                 return False
             return True

         return (not self.requires_token or self.token) and self.default_voice_id

     @property
@@ -299,106 +292,118 @@ class TTSAgent(Agent):
             return 1024
         elif self.api == "coqui":
             return 250

         return 250

     def apply_config(self, *args, **kwargs):
         try:
             api = kwargs["actions"]["_config"]["config"]["api"]["value"]
         except KeyError:
             api = self.api

         api_changed = api != self.api

-        log.debug("apply_config", api=api, api_changed=api != self.api, current_api=self.api)
+        log.debug(
+            "apply_config", api=api, api_changed=api != self.api, current_api=self.api
+        )

         super().apply_config(*args, **kwargs)

         if api_changed:
             try:
-                self.actions["_config"].config["voice_id"].value = self.voices[api].voices[0].value
+                self.actions["_config"].config["voice_id"].value = (
+                    self.voices[api].voices[0].value
+                )
             except IndexError:
                 self.actions["_config"].config["voice_id"].value = ""

     def connect(self, scene):
         super().connect(scene)
-        talemate.emit.async_signals.get("game_loop_new_message").connect(self.on_game_loop_new_message)
+        talemate.emit.async_signals.get("game_loop_new_message").connect(
+            self.on_game_loop_new_message
+        )

     def on_config_saved(self, event):
         config = event.data
         self.config = config
         instance.emit_agent_status(self.__class__, self)

-    async def on_game_loop_new_message(self, emission:GameLoopNewMessageEvent):
+    async def on_game_loop_new_message(self, emission: GameLoopNewMessageEvent):
         """
         Called when a conversation is generated
         """

         if not self.enabled or not self.ready:
             return

         if not isinstance(emission.message, (CharacterMessage, NarratorMessage)):
             return

-        if isinstance(emission.message, NarratorMessage) and not self.actions["_config"].config["generate_for_narration"].value:
+        if (
+            isinstance(emission.message, NarratorMessage)
+            and not self.actions["_config"].config["generate_for_narration"].value
+        ):
             return

         if isinstance(emission.message, CharacterMessage):
-            if emission.message.source == "player" and not self.actions["_config"].config["generate_for_player"].value:
+            if (
+                emission.message.source == "player"
+                and not self.actions["_config"].config["generate_for_player"].value
+            ):
                 return
-            elif emission.message.source == "ai" and not self.actions["_config"].config["generate_for_npc"].value:
+            elif (
+                emission.message.source == "ai"
+                and not self.actions["_config"].config["generate_for_npc"].value
+            ):
                 return

         if isinstance(emission.message, CharacterMessage):
             character_prefix = emission.message.split(":", 1)[0]
         else:
             character_prefix = ""

-        log.info("reactive tts", message=emission.message, character_prefix=character_prefix)
-
-        await self.generate(str(emission.message).replace(character_prefix+": ", ""))
+        log.info(
+            "reactive tts", message=emission.message, character_prefix=character_prefix
+        )

-    def voice(self, voice_id:str) -> Union[Voice, None]:
+        await self.generate(str(emission.message).replace(character_prefix + ": ", ""))

+    def voice(self, voice_id: str) -> Union[Voice, None]:
         for voice in self.voices[self.api].voices:
             if voice.value == voice_id:
                 return voice
         return None

-    def voice_id_to_label(self, voice_id:str):
+    def voice_id_to_label(self, voice_id: str):
         for voice in self.voices[self.api].voices:
             if voice.value == voice_id:
                 return voice.label
         return None

     def list_voices_sync(self):
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(self.list_voices())

     async def list_voices(self):
         if self.requires_token and not self.token:
             return []

         library = self.voices[self.api]

         # TODO: allow re-syncing voices
         if library.last_synced:
             return library.voices

         list_fn = getattr(self, f"_list_voices_{self.api}")
         log.info("Listing voices", api=self.api)

         library.voices = await list_fn()
         library.last_synced = time.time()

         # if the current voice cannot be found, reset it
         if not self.voice(self.default_voice_id):
             self.actions["_config"].config["voice_id"].value = ""

         # set loading to false
         return library.voices
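The lazy caching in `list_voices` can be sketched in isolation: voices are fetched from the backend once per API and reused until `last_synced` is cleared. The `fetch` callable below is a hypothetical stand-in for the per-API `_list_voices_*` methods:

```python
# Sketch of the lazy voice-library caching in list_voices. VoiceLibrary here
# mirrors the pydantic model above; `fetch` stands in for _list_voices_<api>.
import time
from dataclasses import dataclass, field


@dataclass
class VoiceLibrary:
    api: str
    voices: list = field(default_factory=list)
    last_synced: float = None


def list_voices(library: VoiceLibrary, fetch) -> list:
    if library.last_synced:
        return library.voices  # cache hit, no backend call
    library.voices = fetch()
    library.last_synced = time.time()
    return library.voices


calls = []
fetch = lambda: calls.append(1) or ["default"]
lib = VoiceLibrary(api="coqui")
list_voices(lib, fetch)
list_voices(lib, fetch)
print(len(calls))  # the backend was only hit once
```

Resetting `last_synced` to `None` is what a future "re-sync voices" action (the `TODO` above) would do to force a fresh fetch.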
|
||||
|
||||
@@ -407,11 +412,10 @@ class TTSAgent(Agent):
        if not self.enabled or not self.ready or not text:
            return

        self.playback_done_event.set()

        generate_fn = getattr(self, f"_generate_{self.api}")

        if self.actions["_config"].config["generate_chunks"].value:
            chunks = parse_chunks(text)
            chunks = rejoin_chunks(chunks)
@@ -427,59 +431,71 @@ class TTSAgent(Agent):

    async def generate_chunks(self, generate_fn, chunks):
        for chunk in chunks:
            chunk = chunk.replace("*","").strip()
            chunk = chunk.replace("*", "").strip()
            log.info("Generating audio", api=self.api, chunk=chunk)
            audio_data = await generate_fn(chunk)
            self.play_audio(audio_data)

    def play_audio(self, audio_data):
        # play audio through the python audio player
        #play(audio_data)

        emit("audio_queue", data={"audio_data": base64.b64encode(audio_data).decode("utf-8")})

        # play(audio_data)

        emit(
            "audio_queue",
            data={"audio_data": base64.b64encode(audio_data).decode("utf-8")},
        )

        self.playback_done_event.set()  # Signal that playback is finished

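For context on the `emit("audio_queue", ...)` call above: the agent ships raw audio bytes to the frontend as a base64 string inside the event payload. A minimal sketch of that encode/decode round-trip — the helper names here are illustrative, not part of talemate's API:

```python
import base64


def encode_audio_payload(audio_data: bytes) -> dict:
    # mirror of the emit() payload: bytes -> base64 -> utf-8 string
    return {"audio_data": base64.b64encode(audio_data).decode("utf-8")}


def decode_audio_payload(payload: dict) -> bytes:
    # what a frontend consumer would do before handing the bytes to playback
    return base64.b64decode(payload["audio_data"])


wav_bytes = b"RIFF....WAVEfmt "  # stand-in for real audio data
payload = encode_audio_payload(wav_bytes)
assert decode_audio_payload(payload) == wav_bytes
```

Base64 keeps the payload JSON-safe at the cost of roughly a 33% size increase, which is why chunked generation (above) helps keep individual messages small.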
    # LOCAL

    async def _generate_tts(self, text: str) -> Union[bytes, None]:
        if not TTS:
            return

        tts_config = self.config.get("tts",{})
        tts_config = self.config.get("tts", {})
        model = tts_config.get("model")
        device = tts_config.get("device", "cpu")

        log.debug("tts local", model=model, device=device)

        if not hasattr(self, "tts_instance"):
            self.tts_instance = TTS(model).to(device)

        tts = self.tts_instance

        loop = asyncio.get_event_loop()

        voice = self.voice(self.default_voice_id)

        with tempfile.TemporaryDirectory() as temp_dir:
            file_path = os.path.join(temp_dir, f"tts-{uuid.uuid4()}.wav")

            await loop.run_in_executor(None, functools.partial(tts.tts_to_file, text=text, speaker_wav=voice.value, language="en", file_path=file_path))
            #tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)

            await loop.run_in_executor(
                None,
                functools.partial(
                    tts.tts_to_file,
                    text=text,
                    speaker_wav=voice.value,
                    language="en",
                    file_path=file_path,
                ),
            )
            # tts.tts_to_file(text=text, speaker_wav=voice.value, language="en", file_path=file_path)

            with open(file_path, "rb") as f:
                return f.read()

    async def _list_voices_tts(self) -> dict[str, str]:
        return [Voice(**voice) for voice in self.config.get("tts",{}).get("voices",[])]
        return [
            Voice(**voice) for voice in self.config.get("tts", {}).get("voices", [])
        ]

    # ELEVENLABS

    async def _generate_elevenlabs(self, text: str, chunk_size: int = 1024) -> Union[bytes, None]:
    async def _generate_elevenlabs(
        self, text: str, chunk_size: int = 1024
    ) -> Union[bytes, None]:
        api_key = self.token
        if not api_key:
            return
@@ -493,11 +509,8 @@ class TTSAgent(Agent):
        }
        data = {
            "text": text,
            "model_id": self.config.get("elevenlabs",{}).get("model"),
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.5
            }
            "model_id": self.config.get("elevenlabs", {}).get("model"),
            "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
        }

        response = await client.post(url, json=data, headers=headers, timeout=300)
@@ -514,27 +527,33 @@ class TTSAgent(Agent):
            log.error(f"Error generating audio: {response.text}")

    async def _list_voices_elevenlabs(self) -> dict[str, str]:
        url_voices = "https://api.elevenlabs.io/v1/voices"

        voices = []

        async with httpx.AsyncClient() as client:
            headers = {
                "Accept": "application/json",
                "xi-api-key": self.token,
            }
            response = await client.get(url_voices, headers=headers, params={"per_page":1000})
            response = await client.get(
                url_voices, headers=headers, params={"per_page": 1000}
            )
            speakers = response.json()["voices"]
            voices.extend([Voice(value=speaker["voice_id"], label=speaker["name"]) for speaker in speakers])

            voices.extend(
                [
                    Voice(value=speaker["voice_id"], label=speaker["name"])
                    for speaker in speakers
                ]
            )

        # sort by name
        voices.sort(key=lambda x: x.label)

        return voices

    # COQUI STUDIO

    async def _generate_coqui(self, text: str) -> Union[bytes, None]:
        api_key = self.token
        if not api_key:
@@ -545,12 +564,12 @@ class TTSAgent(Agent):
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
            "Authorization": f"Bearer {api_key}",
        }
        data = {
            "voice_id": self.default_voice_id,
            "text": text,
            "language": "en" # Assuming English language for simplicity; this could be parameterized
            "language": "en",  # Assuming English language for simplicity; this could be parameterized
        }

        # Make the POST request to Coqui API
@@ -558,7 +577,7 @@ class TTSAgent(Agent):
        if response.status_code in [200, 201]:
            # Parse the JSON response to get the audio URL
            response_data = response.json()
            audio_url = response_data.get('audio_url')
            audio_url = response_data.get("audio_url")
            if audio_url:
                # Make a GET request to download the audio file
                audio_response = await client.get(audio_url)
@@ -572,7 +591,7 @@ class TTSAgent(Agent):
                log.error("No audio URL in response")
        else:
            log.error(f"Error generating audio: {response.text}")

    async def _cleanup_coqui(self, sample_id: str):
        api_key = self.token
        if not api_key or not sample_id:
@@ -580,9 +599,7 @@ class TTSAgent(Agent):

        async with httpx.AsyncClient() as client:
            url = f"https://app.coqui.ai/api/v2/samples/xtts/{sample_id}"
            headers = {
                "Authorization": f"Bearer {api_key}"
            }
            headers = {"Authorization": f"Bearer {api_key}"}

            # Make the DELETE request to Coqui API
            response = await client.delete(url, headers=headers)
@@ -590,28 +607,41 @@ class TTSAgent(Agent):
            if response.status_code == 204:
                log.info(f"Successfully deleted sample with ID: {sample_id}")
            else:
                log.error(f"Error deleting sample with ID: {sample_id}: {response.text}")
                log.error(
                    f"Error deleting sample with ID: {sample_id}: {response.text}"
                )

    async def _list_voices_coqui(self) -> dict[str, str]:
        url_speakers = "https://app.coqui.ai/api/v2/speakers"
        url_custom_voices = "https://app.coqui.ai/api/v2/voices"

        voices = []

        async with httpx.AsyncClient() as client:
            headers = {
                "Authorization": f"Bearer {self.token}"
            }
            response = await client.get(url_speakers, headers=headers, params={"per_page":1000})
            headers = {"Authorization": f"Bearer {self.token}"}
            response = await client.get(
                url_speakers, headers=headers, params={"per_page": 1000}
            )
            speakers = response.json()["result"]
            voices.extend([Voice(value=speaker["id"], label=speaker["name"]) for speaker in speakers])

            response = await client.get(url_custom_voices, headers=headers, params={"per_page":1000})
            voices.extend(
                [
                    Voice(value=speaker["id"], label=speaker["name"])
                    for speaker in speakers
                ]
            )

            response = await client.get(
                url_custom_voices, headers=headers, params={"per_page": 1000}
            )
            custom_voices = response.json()["result"]
            voices.extend([Voice(value=voice["id"], label=voice["name"]) for voice in custom_voices])

            voices.extend(
                [
                    Voice(value=voice["id"], label=voice["name"])
                    for voice in custom_voices
                ]
            )

        # sort by name
        voices.sort(key=lambda x: x.label)

        return voices

@@ -1,46 +1,54 @@
from __future__ import annotations

import dataclasses
import json
import time
import uuid
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import isodate
import structlog

import talemate.emit.async_signals
import talemate.util as util
from talemate.world_state import InsertionMode
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, TimePassageMessage, ReinforcementMessage
from talemate.emit import emit
from talemate.events import GameLoopEvent
from talemate.instance import get_agent
from talemate.prompts import Prompt
from talemate.scene_message import (
    DirectorMessage,
    ReinforcementMessage,
    TimePassageMessage,
)
from talemate.world_state import InsertionMode

from .base import Agent, set_processing, AgentAction, AgentActionConfig, AgentEmission
from .base import Agent, AgentAction, AgentActionConfig, AgentEmission, set_processing
from .registry import register

import structlog
import isodate
import time

log = structlog.get_logger("talemate.agents.world_state")

talemate.emit.async_signals.register("agent.world_state.time")


@dataclasses.dataclass
class WorldStateAgentEmission(AgentEmission):
    """
    Emission class for world state agent
    """

    pass


@dataclasses.dataclass
class TimePassageEmission(WorldStateAgentEmission):
    """
    Emission class for time passage
    """

    duration: str
    narrative: str
    human_duration: str = None


@register()
class WorldStateAgent(Agent):
@@ -55,26 +63,57 @@ class WorldStateAgent(Agent):
        self.client = client
        self.is_enabled = True
        self.actions = {
            "update_world_state": AgentAction(enabled=True, label="Update world state", description="Will attempt to update the world state based on the current scene. Runs automatically every N turns.", config={
                "turns": AgentActionConfig(type="number", label="Turns", description="Number of turns to wait before updating the world state.", value=5, min=1, max=100, step=1)
            }),
            "update_reinforcements": AgentAction(enabled=True, label="Update state reinforcements", description="Will attempt to update any due state reinforcements.", config={}),
            "check_pin_conditions": AgentAction(enabled=True, label="Update conditional context pins", description="Will evaluate context pins conditions and toggle those pins accordingly. Runs automatically every N turns.", config={
                "turns": AgentActionConfig(type="number", label="Turns", description="Number of turns to wait before checking conditions.", value=2, min=1, max=100, step=1)
            }),
            "update_world_state": AgentAction(
                enabled=True,
                label="Update world state",
                description="Will attempt to update the world state based on the current scene. Runs automatically every N turns.",
                config={
                    "turns": AgentActionConfig(
                        type="number",
                        label="Turns",
                        description="Number of turns to wait before updating the world state.",
                        value=5,
                        min=1,
                        max=100,
                        step=1,
                    )
                },
            ),
            "update_reinforcements": AgentAction(
                enabled=True,
                label="Update state reinforcements",
                description="Will attempt to update any due state reinforcements.",
                config={},
            ),
            "check_pin_conditions": AgentAction(
                enabled=True,
                label="Update conditional context pins",
                description="Will evaluate context pins conditions and toggle those pins accordingly. Runs automatically every N turns.",
                config={
                    "turns": AgentActionConfig(
                        type="number",
                        label="Turns",
                        description="Number of turns to wait before checking conditions.",
                        value=2,
                        min=1,
                        max=100,
                        step=1,
                    )
                },
            ),
        }

        self.next_update = 0
        self.next_pin_check = 0

    @property
    def enabled(self):
        return self.is_enabled

    @property
    def has_toggle(self):
        return True

    @property
    def experimental(self):
        return True
@@ -83,110 +122,121 @@ class WorldStateAgent(Agent):
        super().connect(scene)
        talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)

    async def advance_time(self, duration:str, narrative:str=None):
    async def advance_time(self, duration: str, narrative: str = None):
        """
        Emit a time passage message
        """

        isodate.parse_duration(duration)
        human_duration = util.iso8601_duration_to_human(duration, suffix=" later")
        message = TimePassageMessage(ts=duration, message=human_duration)

        log.debug("world_state.advance_time", message=message)
        self.scene.push_history(message)
        self.scene.emit_status()

        emit("time", message)

        await talemate.emit.async_signals.get("agent.world_state.time").send(
            TimePassageEmission(agent=self, duration=duration, narrative=narrative, human_duration=human_duration)
        )

    async def on_game_loop(self, emission:GameLoopEvent):
        emit("time", message)

        await talemate.emit.async_signals.get("agent.world_state.time").send(
            TimePassageEmission(
                agent=self,
                duration=duration,
                narrative=narrative,
                human_duration=human_duration,
            )
        )

    async def on_game_loop(self, emission: GameLoopEvent):
        """
        Called when a conversation is generated
        """

        if not self.enabled:
            return

        await self.update_world_state()
        await self.auto_update_reinforcments()
        await self.auto_check_pin_conditions()

    async def auto_update_reinforcments(self):
        if not self.enabled:
            return

        if not self.actions["update_reinforcements"].enabled:
            return

        await self.update_reinforcements()

    async def auto_check_pin_conditions(self):
        if not self.enabled:
            return

        if not self.actions["check_pin_conditions"].enabled:
            return

        if self.next_pin_check % self.actions["check_pin_conditions"].config["turns"].value != 0 or self.next_pin_check == 0:

        if (
            self.next_pin_check
            % self.actions["check_pin_conditions"].config["turns"].value
            != 0
            or self.next_pin_check == 0
        ):
            self.next_pin_check += 1
            return

        self.next_pin_check = 0

        await self.check_pin_conditions()

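`auto_check_pin_conditions` above gates its work on a turn counter against a configured interval (`update_world_state` uses the same pattern). The modulo gate, isolated as a pure function — this is a simplified sketch of the condition, with the agent's counter increment/reset bookkeeping left out:

```python
def should_run(counter: int, turns: int) -> bool:
    # mirrors the `counter % turns != 0 or counter == 0` early-return:
    # skip turn 0 and any turn that is not a multiple of `turns`
    return counter != 0 and counter % turns == 0


# with turns=2 the check fires on turns 2 and 4
schedule = [should_run(turn, 2) for turn in range(5)]
```

In the agent itself, a skipped turn increments the counter and a fired turn resets it to zero, so the interval is measured in game-loop iterations rather than wall-clock time.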
    async def update_world_state(self):
        if not self.enabled:
            return

        if not self.actions["update_world_state"].enabled:
            return

        log.debug("update_world_state", next_update=self.next_update, turns=self.actions["update_world_state"].config["turns"].value)

        log.debug(
            "update_world_state",
            next_update=self.next_update,
            turns=self.actions["update_world_state"].config["turns"].value,
        )

        scene = self.scene

        if self.next_update % self.actions["update_world_state"].config["turns"].value != 0 or self.next_update == 0:

        if (
            self.next_update % self.actions["update_world_state"].config["turns"].value
            != 0
            or self.next_update == 0
        ):
            self.next_update += 1
            return

        self.next_update = 0
        await scene.world_state.request_update()

    @set_processing
    async def request_world_state(self):
        t1 = time.time()

        _, world_state = await Prompt.request(
            "world_state.request-world-state-v2",
            self.client,
            "analyze_long",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "object_type": "character",
                "object_type_plural": "characters",
            }
            },
        )

        self.scene.log.debug("request_world_state", response=world_state, time=time.time() - t1)

        self.scene.log.debug(
            "request_world_state", response=world_state, time=time.time() - t1
        )

        return world_state

    @set_processing
    async def request_world_state_inline(self):
        """
        EXPERIMENTAL. Overall the one-shot request seems about as coherent as the inline request, but the inline request is about twice as slow and would need to run on every dialogue line.
        """
@@ -199,14 +249,18 @@ class WorldStateAgent(Agent):
            "world_state.request-world-state-inline-items",
            self.client,
            "analyze_long",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
            }
            },
        )

        self.scene.log.debug("request_world_state_inline", marked_items=marked_items_response, time=time.time() - t1)

        self.scene.log.debug(
            "request_world_state_inline",
            marked_items=marked_items_response,
            time=time.time() - t1,
        )

        return marked_items_response

    @set_processing
@@ -214,99 +268,107 @@ class WorldStateAgent(Agent):
        self,
        text: str,
    ):
        response = await Prompt.request(
            "world_state.analyze-time-passage",
            self.client,
            "analyze_freeform_short",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "text": text,
            }
            },
        )

        duration = response.split("\n")[0].split(" ")[0].strip()

        if not duration.startswith("P"):
            duration = "P"+duration
            duration = "P" + duration

        return duration

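The duration handling above trims the model's reply down to its first token and forces the ISO 8601 period designator `P` onto it, so a loosely formatted answer still parses as a duration. The same normalization, extracted as a pure function for illustration (a sketch, not talemate's code):

```python
def normalize_iso8601_duration(response: str) -> str:
    # first whitespace-delimited token of the first line of the reply
    duration = response.split("\n")[0].split(" ")[0].strip()
    # ensure the ISO 8601 period designator is present, e.g. "3D" -> "P3D"
    if not duration.startswith("P"):
        duration = "P" + duration
    return duration
```

Note the designator check is a plain prefix test: time-only durations such as `T1H` would become `PT1H` only if the model already emitted the `T`, which is why the surrounding code still validates the result with `isodate.parse_duration` elsewhere.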
    @set_processing
    async def analyze_text_and_extract_context(
        self,
        text: str,
        goal: str,
    ):
        response = await Prompt.request(
            "world_state.analyze-text-and-extract-context",
            self.client,
            "analyze_freeform",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "text": text,
                "goal": goal,
            }
            },
        )

        log.debug("analyze_text_and_extract_context", goal=goal, text=text, response=response)

        log.debug(
            "analyze_text_and_extract_context", goal=goal, text=text, response=response
        )

        return response

    @set_processing
    async def analyze_text_and_extract_context_via_queries(
        self,
        text: str,
        goal: str,
    ) -> list[str]:
        response = await Prompt.request(
            "world_state.analyze-text-and-generate-rag-queries",
            self.client,
            "analyze_freeform",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "text": text,
                "goal": goal,
            }
            },
        )

        queries = response.split("\n")

        memory_agent = get_agent("memory")

        context = await memory_agent.multi_query(queries, iterate=3)

        log.debug("analyze_text_and_extract_context_via_queries", goal=goal, text=text, queries=queries, context=context)

        log.debug(
            "analyze_text_and_extract_context_via_queries",
            goal=goal,
            text=text,
            queries=queries,
            context=context,
        )

        return context

    @set_processing
    async def analyze_and_follow_instruction(
        self,
        text: str,
        instruction: str,
    ):
        response = await Prompt.request(
            "world_state.analyze-text-and-follow-instruction",
            self.client,
            "analyze_freeform",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "text": text,
                "instruction": instruction,
            }
            },
        )

        log.debug("analyze_and_follow_instruction", instruction=instruction, text=text, response=response)

        log.debug(
            "analyze_and_follow_instruction",
            instruction=instruction,
            text=text,
            response=response,
        )

        return response

    @set_processing
@@ -315,50 +377,52 @@ class WorldStateAgent(Agent):
        text: str,
        query: str,
    ):
        response = await Prompt.request(
            "world_state.analyze-text-and-answer-question",
            self.client,
            "analyze_freeform",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "text": text,
                "query": query,
            }
            },
        )

        log.debug("analyze_text_and_answer_question", query=query, text=text, response=response)

        log.debug(
            "analyze_text_and_answer_question",
            query=query,
            text=text,
            response=response,
        )

        return response

    @set_processing
    async def identify_characters(
        self,
        text: str = None,
    ):
        """
        Attempts to identify characters in the given text.
        """

        _, data = await Prompt.request(
            "world_state.identify-characters",
            self.client,
            "analyze",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "text": text,
            }
            },
        )

        log.debug("identify_characters", text=text, data=data)

        return data

    def _parse_character_sheet(self, response):
        data = {}
        for line in response.split("\n"):
            if not line.strip():
@@ -367,128 +431,131 @@ class WorldStateAgent(Agent):
                break
            name, value = line.split(":", 1)
            data[name.strip()] = value.strip()

        return data

    @set_processing
    async def extract_character_sheet(
        self,
        name:str,
        text:str = None,
        name: str,
        text: str = None,
    ):
        """
        Attempts to extract a character sheet from the given text.
        """

        response = await Prompt.request(
            "world_state.extract-character-sheet",
            self.client,
            "create",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "text": text,
                "name": name,
            }
            },
        )

        # loop through each line in response and if it contains a : then extract
        # the left side as an attribute name and the right side as the value
        #
        # break as soon as a non-empty line is found that doesn't contain a :

        return self._parse_character_sheet(response)

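The comment block above describes the `Attribute: value` parsing that `_parse_character_sheet` performs. A self-contained version of that loop for illustration — the diff hunk elides some lines of the method, so the skip-blank-lines behavior here is an assumption:

```python
def parse_character_sheet(response: str) -> dict:
    # collect "name: value" lines; skip blank lines, stop at the first
    # non-empty line that does not contain a colon
    data = {}
    for line in response.split("\n"):
        if not line.strip():
            continue
        if ":" not in line:
            break
        name, value = line.split(":", 1)
        data[name.strip()] = value.strip()
    return data


sheet = parse_character_sheet("Name: Elara\nAge: 31\n\nThat is the sheet.")
```

Splitting with `line.split(":", 1)` keeps colons inside the value intact, which matters for entries like timestamps or URLs.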
    @set_processing
    async def match_character_names(self, names:list[str]):
    async def match_character_names(self, names: list[str]):
        """
        Attempts to match character names.
        """

        _, response = await Prompt.request(
            "world_state.match-character-names",
            self.client,
            "analyze_long",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "names": names,
            }
            },
        )

        log.debug("match_character_names", names=names, response=response)

        return response

    @set_processing
    async def update_reinforcements(self, force:bool=False):
    async def update_reinforcements(self, force: bool = False):
        """
        Queries due worldstate re-inforcements
        """

        for reinforcement in self.scene.world_state.reinforce:
            if reinforcement.due <= 0 or force:
                await self.update_reinforcement(reinforcement.question, reinforcement.character)
                await self.update_reinforcement(
                    reinforcement.question, reinforcement.character
                )
            else:
                reinforcement.due -= 1

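`update_reinforcements` above counts each reinforcement down by one turn and re-queries it once `due` reaches zero; the reset back to `interval` happens inside `update_reinforcement` after the answer is stored. The countdown-and-reset cycle, condensed into one hedged sketch (a minimal stand-in, not talemate's reinforcement model):

```python
class Reinforcement:
    # minimal stand-in for the reinforcement state used by the agent
    def __init__(self, interval: int):
        self.interval = interval
        self.due = interval


def tick(reinforcement: Reinforcement) -> bool:
    # fire when the countdown reaches zero, then reset it to the
    # configured interval; otherwise just count down one turn
    if reinforcement.due <= 0:
        reinforcement.due = reinforcement.interval
        return True
    reinforcement.due -= 1
    return False


r = Reinforcement(interval=2)
fired = [tick(r) for _ in range(6)]
```

With `interval=2` the reinforcement re-fires every third tick, which is how selected character and world truths get periodically re-asked without querying the model on every turn.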
    @set_processing
    async def update_reinforcement(self, question:str, character:str=None, reset:bool=False):
    async def update_reinforcement(
        self, question: str, character: str = None, reset: bool = False
    ):
        """
        Queries a single re-inforcement
        """
        message = None
        idx, reinforcement = await self.scene.world_state.find_reinforcement(question, character)
        idx, reinforcement = await self.scene.world_state.find_reinforcement(
            question, character
        )

        if not reinforcement:
            return

        source = f"{reinforcement.question}:{reinforcement.character if reinforcement.character else ''}"

        if reset and reinforcement.insert == "sequential":
            self.scene.pop_history(typ="reinforcement", source=source, all=True)

        answer = await Prompt.request(
            "world_state.update-reinforcements",
            self.client,
            "analyze_freeform",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "question": reinforcement.question,
                "instructions": reinforcement.instructions or "",
                "character": self.scene.get_character(reinforcement.character) if reinforcement.character else None,
                "character": self.scene.get_character(reinforcement.character)
                if reinforcement.character
                else None,
                "answer": (reinforcement.answer if not reset else None) or "",
                "reinforcement": reinforcement,
            }
            },
        )

        reinforcement.answer = answer
        reinforcement.due = reinforcement.interval

        # remove any recent previous reinforcement message with same question
        # to avoid overloading the near history with reinforcement messages
        if not reset:
            self.scene.pop_history(typ="reinforcement", source=source, max_iterations=10)
            self.scene.pop_history(
                typ="reinforcement", source=source, max_iterations=10
            )

        if reinforcement.insert == "sequential":
            # insert the reinforcement message at the current position
            message = ReinforcementMessage(message=answer, source=source)
            log.debug("update_reinforcement", message=message, reset=reset)
            self.scene.push_history(message)

        # if reinforcement has a character name set, update the character detail
        if reinforcement.character:
            character = self.scene.get_character(reinforcement.character)
            await character.set_detail(reinforcement.question, answer)

        else:
            # set world entry
            await self.scene.world_state_manager.save_world_entry(
@@ -496,20 +563,19 @@ class WorldStateAgent(Agent):
                reinforcement.as_context_line,
                {},
            )

        self.scene.world_state.emit()

        return message

    @set_processing
    async def check_pin_conditions(
        self,
    ):
        """
        Checks if any context pin conditions are met
        """

        pins_with_condition = {
            entry_id: {
                "condition": pin.condition,
@@ -518,41 +584,47 @@ class WorldStateAgent(Agent):
            for entry_id, pin in self.scene.world_state.pins.items()
            if pin.condition
        }

        if not pins_with_condition:
            return

        first_entry_id = list(pins_with_condition.keys())[0]

        _, answers = await Prompt.request(
            "world_state.check-pin-conditions",
            self.client,
            "analyze",
            vars = {
            vars={
                "scene": self.scene,
                "max_tokens": self.client.max_token_length,
                "previous_states": json.dumps(pins_with_condition,indent=2),
                "coercion": {first_entry_id:{ "condition": "" }},
            }
                "previous_states": json.dumps(pins_with_condition, indent=2),
                "coercion": {first_entry_id: {"condition": ""}},
            },
        )

        world_state = self.scene.world_state
        state_change = False

        state_change = False

        for entry_id, answer in answers.items():

            if entry_id not in world_state.pins:
                log.warning("check_pin_conditions", entry_id=entry_id, answer=answer, msg="entry_id not found in world_state.pins (LLM failed to produce a clean response)")
                log.warning(
                    "check_pin_conditions",
                    entry_id=entry_id,
                    answer=answer,
                    msg="entry_id not found in world_state.pins (LLM failed to produce a clean response)",
                )
                continue

            log.info("check_pin_conditions", entry_id=entry_id, answer=answer)
            state = answer.get("state")
            if state is True or (isinstance(state, str) and state.lower() in ["true", "yes", "y"]):
            if state is True or (
                isinstance(state, str) and state.lower() in ["true", "yes", "y"]
            ):
                prev_state = world_state.pins[entry_id].condition_state

                world_state.pins[entry_id].condition_state = True
                world_state.pins[entry_id].active = True

                if prev_state != world_state.pins[entry_id].condition_state:
                    state_change = True
            else:
@@ -560,49 +632,50 @@ class WorldStateAgent(Agent):
                world_state.pins[entry_id].condition_state = False
                world_state.pins[entry_id].active = False
                state_change = True

        if state_change:
            await self.scene.load_active_pins()
            self.scene.emit_status()

    @set_processing
    async def summarize_and_pin(self, message_id:int, num_messages:int=3) -> str:
    async def summarize_and_pin(self, message_id: int, num_messages: int = 3) -> str:
        """
        Will take a message index and then walk back N messages
        summarizing the scene and pinning it to the context.
        """

        creator = get_agent("creator")
        summarizer = get_agent("summarizer")

        message_index = self.scene.message_index(message_id)

        text = self.scene.snapshot(lines=num_messages, start=message_index)

        extra_context = self.scene.snapshot(lines=50, start=message_index-num_messages)

        extra_context = self.scene.snapshot(
            lines=50, start=message_index - num_messages
        )

        summary = await summarizer.summarize(
            text,
            extra_context=extra_context,
            method="short",
            extra_instructions="Pay particularly close attention to decisions, agreements or promises made.",
        )

        entry_id = util.clean_id(await creator.generate_title(summary))

        ts = self.scene.ts

        log.debug(
            "summarize_and_pin",
            message_id=message_id,
            message_index=message_index,
            num_messages=num_messages,
            summary=summary,
            entry_id=entry_id,
            ts=ts,
        )

        await self.scene.world_state_manager.save_world_entry(
            entry_id,
            summary,
@@ -610,49 +683,49 @@ class WorldStateAgent(Agent):
                "ts": ts,
            },
        )

        await self.scene.world_state_manager.set_pin(
            entry_id,
            active=True,
        )

        await self.scene.load_active_pins()
        self.scene.emit_status()

    @set_processing
    async def is_character_present(self, character:str) -> bool:
    async def is_character_present(self, character: str) -> bool:
        """
        Check if a character is present in the scene

        Arguments:

        - `character`: The character to check.
        """

        if len(self.scene.history) < 10:
            text = self.scene.intro+"\n\n"+self.scene.snapshot(lines=50)
        else:
            text = self.scene.snapshot(lines=50)

        is_present = await self.analyze_text_and_answer_question(
            text=text,
            query=f"Is {character} present AND active in the current scene? Answer with 'yes' or 'no'.",
        )

        return is_present.lower().startswith("y")

    @set_processing
    async def is_character_leaving(self, character:str) -> bool:
        """
        Check if a character is leaving the scene

        Arguments:

        - `character`: The character to check.
        """

        if len(self.scene.history) < 10:
            text = self.scene.intro+"\n\n"+self.scene.snapshot(lines=50)
            text = self.scene.intro + "\n\n" + self.scene.snapshot(lines=50)
        else:
            text = self.scene.snapshot(lines=50)

        is_present = await self.analyze_text_and_answer_question(
            text=text,
            query=f"Is {character} present AND active in the current scene? Answer with 'yes' or 'no'.",
        )

        return is_present.lower().startswith("y")

    @set_processing
    async def is_character_leaving(self, character: str) -> bool:
        """
        Check if a character is leaving the scene

        Arguments:

        - `character`: The character to check.
        """

        if len(self.scene.history) < 10:
            text = self.scene.intro + "\n\n" + self.scene.snapshot(lines=50)
        else:
            text = self.scene.snapshot(lines=50)

@@ -660,5 +733,5 @@ class WorldStateAgent(Agent):
            text=text,
            query=f"Is {character} leaving the current scene? Answer with 'yes' or 'no'.",
        )

        return is_leaving.lower().startswith("y")
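The yes/no checks above reduce the model's free-form reply to a prefix test. A standalone sketch of that parsing idiom (hypothetical `parse_yes_no` helper, not part of talemate):

```python
def parse_yes_no(answer: str) -> bool:
    # mirrors the `.lower().startswith("y")` check used by
    # is_character_present / is_character_leaving above
    return answer.strip().lower().startswith("y")

print(parse_yes_no("Yes, she is."), parse_yes_no("No."))  # True False
```

The prefix test is deliberately lenient: anything that starts with "y" ("Yes", "yes, because...") counts as affirmative, and everything else, including hedged answers, counts as negative.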
@@ -1,10 +1,11 @@
from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from talemate import Scene


import structlog

__all__ = ["AutomatedAction", "register", "initialize_for_scene"]
@@ -13,50 +14,64 @@ log = structlog.get_logger("talemate.automated_action")

AUTOMATED_ACTIONS = {}


def initialize_for_scene(scene: Scene):
    for uid, config in AUTOMATED_ACTIONS.items():
        scene.automated_actions[uid] = config.cls(
            scene,
            uid=uid,
            frequency=config.frequency,
            call_initially=config.call_initially,
            enabled=config.enabled,
        )


@dataclasses.dataclass
class AutomatedActionConfig:
    uid: str
    cls: AutomatedAction
    frequency: int = 5
    call_initially: bool = False
    enabled: bool = True


class register:
    def __init__(
        self,
        uid: str,
        frequency: int = 5,
        call_initially: bool = False,
        enabled: bool = True,
    ):
        self.uid = uid
        self.frequency = frequency
        self.call_initially = call_initially
        self.enabled = enabled

    def __call__(self, action: AutomatedAction):
        AUTOMATED_ACTIONS[self.uid] = AutomatedActionConfig(
            self.uid,
            action,
            frequency=self.frequency,
            call_initially=self.call_initially,
            enabled=self.enabled,
        )
        return action


class AutomatedAction:
    """
    An action that will be executed every n turns
    """

    def __init__(
        self,
        scene: Scene,
        frequency: int = 5,
        call_initially: bool = False,
        uid: str = None,
        enabled: bool = True,
    ):
        self.scene = scene
        self.enabled = enabled
        self.frequency = frequency
@@ -64,14 +79,19 @@ class AutomatedAction:
        self.uid = uid
        if call_initially:
            self.turns = frequency

    async def __call__(self):
        log.debug(
            "automated_action",
            uid=self.uid,
            enabled=self.enabled,
            frequency=self.frequency,
            turns=self.turns,
        )

        if not self.enabled:
            return False

        if self.turns % self.frequency == 0:
            result = await self.action()
            log.debug("automated_action", result=result)
@@ -79,10 +99,9 @@ class AutomatedAction:
            # action could not be performed at this turn, we will try again next turn
            return False
        self.turns += 1

    async def action(self) -> Any:
        """
        Override this method to implement your action.
        """
        raise NotImplementedError()
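The turn gating in `AutomatedAction.__call__` fires the action whenever the turn counter is a multiple of `frequency`, and `call_initially` primes the counter so the first turn fires immediately. A minimal stand-in sketching just that counter logic (hypothetical `CounterAction`, not the real talemate class; how the counter is initialized when `call_initially` is False is elided in the diff, so zero is assumed here):

```python
import asyncio


class CounterAction:
    # Stand-in mirroring AutomatedAction's frequency/turn-counter gating.
    def __init__(self, frequency: int = 5, call_initially: bool = False):
        self.frequency = frequency
        # call_initially primes the counter so the action fires on the first turn
        self.turns = frequency if call_initially else 0
        self.calls = 0

    async def __call__(self):
        if self.turns % self.frequency == 0:
            await self.action()
        self.turns += 1

    async def action(self):
        self.calls += 1


async def run_turns(n: int) -> int:
    action = CounterAction(frequency=3, call_initially=True)
    for _ in range(n):
        await action()
    return action.calls


calls = asyncio.run(run_turns(9))
print(calls)  # fires on turns 3, 6 and 9 -> 3
```

Note the cumulative nature: the counter keeps advancing even on turns where the action declines to run, so a disabled or failing action simply waits for the next multiple of `frequency`.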
@@ -1,32 +1,34 @@
from typing import TYPE_CHECKING, Union

from talemate.instance import get_agent

if TYPE_CHECKING:
    from talemate.tale_mate import Actor, Character, Scene


__all__ = [
    "deactivate_character",
    "activate_character",
]


async def deactivate_character(scene: "Scene", character: Union[str, "Character"]):
    """
    Deactivates a character

    Arguments:

    - `scene`: The scene to deactivate the character from
    - `character`: The character to deactivate. Can be a string (the character's name) or a Character object
    """

    if isinstance(character, str):
        character = scene.get_character(character)

    if character.is_player:
        # can't deactivate the player
        return False

    if character.name in scene.inactive_characters:
        # already deactivated
        return False
@@ -34,24 +36,24 @@ async def deactivate_character(scene:"Scene", character:Union[str, "Character"])
    await scene.remove_actor(character.actor)
    scene.inactive_characters[character.name] = character


async def activate_character(scene: "Scene", character: Union[str, "Character"]):
    """
    Activates a character

    Arguments:

    - `scene`: The scene to activate the character in
    - `character`: The character to activate. Can be a string (the character's name) or a Character object
    """

    if isinstance(character, str):
        character = scene.get_character(character)

    if character.name not in scene.inactive_characters:
        # already activated
        return False

    actor = scene.Actor(character, get_agent("conversation"))
    await scene.add_actor(actor)
    del scene.inactive_characters[character.name]
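Both helpers follow the same pattern: normalize a string-or-object argument, then bail out early on the guard clauses before mutating scene state. A self-contained sketch of that pattern (hypothetical `Character`/`Scene` stand-ins and a `deactivate` function that skips the actor removal; not talemate's actual classes):

```python
from dataclasses import dataclass, field
from typing import Union


@dataclass
class Character:
    name: str
    is_player: bool = False


@dataclass
class Scene:
    characters: dict = field(default_factory=dict)
    inactive_characters: dict = field(default_factory=dict)

    def get_character(self, name: str) -> Character:
        return self.characters[name]


def deactivate(scene: Scene, character: Union[str, Character]) -> bool:
    # mirrors deactivate_character's guard clauses, minus actor removal
    if isinstance(character, str):
        character = scene.get_character(character)
    if character.is_player:
        return False  # can't deactivate the player
    if character.name in scene.inactive_characters:
        return False  # already deactivated
    scene.inactive_characters[character.name] = character
    return True


scene = Scene(
    characters={
        "npc": Character("npc"),
        "hero": Character("hero", is_player=True),
    }
)
print(deactivate(scene, "npc"), deactivate(scene, "hero"), deactivate(scene, "npc"))
# True False False
```

Returning `False` from the guards (rather than raising) keeps the helpers safe to call speculatively, e.g. from automated actions that don't know the character's current state.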
@@ -2,15 +2,13 @@ import argparse
import asyncio
import glob
import os

import structlog
from dotenv import load_dotenv

import talemate.instance as instance
from talemate import Actor, Character, Helper, Player, Scene
from talemate.agents import ConversationAgent
from talemate.client import OpenAIClient, TextGeneratorWebuiClient
from talemate.emit.console import Console
from talemate.load import (
@@ -129,7 +127,6 @@ async def run_console_session(parser, args):
    default_client = None

    if "textgenwebui" in clients.values() or args.client == "textgenwebui":
        # Init the TextGeneratorWebuiClient with ConversationAgent and create an actor
        textgenwebui_api_url = args.textgenwebui_url
@@ -145,7 +142,6 @@ async def run_console_session(parser, args):
        clients[client_name] = text_generator_webui_client

    if "openai" in clients.values() or args.client == "openai":
        openai_client = OpenAIClient()

    for client_name, client_typ in clients.items():
@@ -1,7 +1,8 @@
import os

import talemate.client.runpod
from talemate.client.lmstudio import LMStudioClient
from talemate.client.openai import OpenAIClient
from talemate.client.openai_compat import OpenAICompatibleClient
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
from talemate.client.textgenwebui import TextGeneratorWebuiClient
@@ -1,26 +1,28 @@
"""
A unified client base, based on the openai API
"""
import logging
import random
import time
from typing import Callable, Union

import pydantic
import structlog
from openai import AsyncOpenAI, PermissionDeniedError

import talemate.client.presets as presets
import talemate.client.system_prompts as system_prompts
import talemate.instance as instance
import talemate.util as util
from talemate.agents.context import active_agent
from talemate.client.context import client_context_attribute
from talemate.client.model_prompts import model_prompt
from talemate.emit import emit

# Set up logging level for httpx to WARNING to suppress debug logs.
logging.getLogger("httpx").setLevel(logging.WARNING)

log = structlog.get_logger("client.base")

REMOTE_SERVICES = [
    # TODO: runpod.py should add this to the list
@@ -29,22 +31,24 @@ REMOTE_SERVICES = [

STOPPING_STRINGS = ["<|im_end|>", "</s>"]


class ErrorAction(pydantic.BaseModel):
    title: str
    action_name: str
    icon: str = "mdi-error"
    arguments: list = []


class Defaults(pydantic.BaseModel):
    api_url: str = "http://localhost:5000"
    max_token_length: int = 4096


class ClientBase:
    api_url: str
    model_name: str
    api_key: str = None
    name: str = None
    enabled: bool = True
    current_status: str = None
    max_token_length: int = 4096
@@ -52,21 +56,21 @@ class ClientBase:
    connected: bool = False
    conversation_retries: int = 2
    auto_break_repetition_enabled: bool = True
    decensor_enabled: bool = True
    client_type = "base"

    class Meta(pydantic.BaseModel):
        experimental: Union[None, str] = None
        defaults: Defaults = Defaults()
        title: str = "Client"
        name_prefix: str = "Client"
        enable_api_auth: bool = False
        requires_prompt_template: bool = True

    def __init__(
        self,
        api_url: str = None,
        name=None,
        **kwargs,
    ):
        self.api_url = api_url
@@ -75,121 +79,141 @@ class ClientBase:
        if "max_token_length" in kwargs:
            self.max_token_length = kwargs["max_token_length"]
        self.set_client(max_token_length=self.max_token_length)

    def __str__(self):
        return f"{self.client_type}Client[{self.api_url}][{self.model_name or ''}]"

    @property
    def experimental(self):
        return False

    def set_client(self, **kwargs):
        self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")

    def prompt_template(self, sys_msg, prompt):
        """
        Applies the appropriate prompt template for the model.
        """

        if not self.model_name:
            self.log.warning("prompt template not applied", reason="no model loaded")
            return f"{sys_msg}\n{prompt}"

        return model_prompt(self.model_name, sys_msg, prompt)[0]

    def prompt_template_example(self):
        if not getattr(self, "model_name", None):
            return None, None
        return model_prompt(self.model_name, "sysmsg", "prompt<|BOT|>{LLM coercion}")

    def reconfigure(self, **kwargs):
        """
        Reconfigures the client.

        Keyword Arguments:

        - api_url: the API URL to use
        - max_token_length: the max token length to use
        - enabled: whether the client is enabled
        """

        if "api_url" in kwargs:
            self.api_url = kwargs["api_url"]

        if kwargs.get("max_token_length"):
            self.max_token_length = kwargs["max_token_length"]

        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

    def toggle_disabled_if_remote(self):
        """
        If the client is targeting a remote recognized service, this
        will disable the client.
        """

        for service in REMOTE_SERVICES:
            if service in self.api_url:
                if self.enabled:
                    self.log.warn(
                        "remote service unreachable, disabling client", client=self.name
                    )
                self.enabled = False

                return True

        return False

    def get_system_message(self, kind: str) -> str:
        """
        Returns the appropriate system message for the given kind of generation

        Arguments:

        - kind: the kind of generation
        """

        # TODO: make extensible

        if self.decensor_enabled:
            if "narrate" in kind:
                return system_prompts.NARRATOR
            if "story" in kind:
                return system_prompts.NARRATOR
            if "director" in kind:
                return system_prompts.DIRECTOR
            if "create" in kind:
                return system_prompts.CREATOR
            if "roleplay" in kind:
                return system_prompts.ROLEPLAY
            if "conversation" in kind:
                return system_prompts.ROLEPLAY
            if "editor" in kind:
                return system_prompts.EDITOR
            if "world_state" in kind:
                return system_prompts.WORLD_STATE
            if "analyze_freeform" in kind:
                return system_prompts.ANALYST_FREEFORM
            if "analyst" in kind:
                return system_prompts.ANALYST
            if "analyze" in kind:
                return system_prompts.ANALYST
            if "summarize" in kind:
                return system_prompts.SUMMARIZE
        else:
            if "narrate" in kind:
                return system_prompts.NARRATOR_NO_DECENSOR
            if "story" in kind:
                return system_prompts.NARRATOR_NO_DECENSOR
            if "director" in kind:
                return system_prompts.DIRECTOR_NO_DECENSOR
            if "create" in kind:
                return system_prompts.CREATOR_NO_DECENSOR
            if "roleplay" in kind:
                return system_prompts.ROLEPLAY_NO_DECENSOR
            if "conversation" in kind:
                return system_prompts.ROLEPLAY_NO_DECENSOR
            if "editor" in kind:
                return system_prompts.EDITOR_NO_DECENSOR
            if "world_state" in kind:
                return system_prompts.WORLD_STATE_NO_DECENSOR
            if "analyze_freeform" in kind:
                return system_prompts.ANALYST_FREEFORM_NO_DECENSOR
            if "analyst" in kind:
                return system_prompts.ANALYST_NO_DECENSOR
            if "analyze" in kind:
                return system_prompts.ANALYST_NO_DECENSOR
            if "summarize" in kind:
                return system_prompts.SUMMARIZE_NO_DECENSOR

        return system_prompts.BASIC

    def emit_status(self, processing: bool = None):
        """
        Sets and emits the client status.
        """

        if processing is not None:
            self.processing = processing
@@ -205,12 +229,12 @@ class ClientBase:
        else:
            model_name = "No model loaded"
            status = "warning"

        status_change = status != self.current_status
        self.current_status = status

        prompt_template_example, prompt_template_file = self.prompt_template_example()

        emit(
            "client_status",
            message=self.client_type,
@@ -220,24 +244,25 @@ class ClientBase:
            data={
                "api_key": self.api_key,
                "prompt_template_example": prompt_template_example,
                "has_prompt_template": (
                    prompt_template_file and prompt_template_file != "default.jinja2"
                ),
                "template_file": prompt_template_file,
                "meta": self.Meta().model_dump(),
                "error_action": None,
            },
        )

        if status_change:
            instance.emit_agent_status_by_client(self)

    async def get_model_name(self):
        models = await self.client.models.list()
        try:
            return models.data[0].id
        except IndexError:
            return None

    async def status(self):
        """
        Send a request to the API to retrieve the loaded AI model name.
@@ -246,12 +271,12 @@ class ClientBase:
        """
        if self.processing:
            return

        if not self.enabled:
            self.connected = False
            self.emit_status()
            return

        try:
            self.model_name = await self.get_model_name()
        except Exception as e:
@@ -261,62 +286,66 @@ class ClientBase:
            self.toggle_disabled_if_remote()
            self.emit_status()
            return

        self.connected = True

        if not self.model_name or self.model_name == "None":
            self.log.warning("client model not loaded", client=self)
            self.emit_status()
            return

        self.emit_status()

    def generate_prompt_parameters(self, kind: str):
        parameters = {}
        self.tune_prompt_parameters(
            presets.configure(parameters, kind, self.max_token_length), kind
        )
        return parameters

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        parameters["stream"] = False
        if client_context_attribute(
            "nuke_repetition"
        ) > 0.0 and self.jiggle_enabled_for(kind):
            self.jiggle_randomness(
                parameters, offset=client_context_attribute("nuke_repetition")
            )

        fn_tune_kind = getattr(self, f"tune_prompt_parameters_{kind}", None)
        if fn_tune_kind:
            fn_tune_kind(parameters)

        agent_context = active_agent.get()
        if agent_context.agent:
            agent_context.agent.inject_prompt_paramters(
                parameters, kind, agent_context.action
            )

    def tune_prompt_parameters_conversation(self, parameters: dict):
        conversation_context = client_context_attribute("conversation")
        parameters["max_tokens"] = conversation_context.get("length", 96)

        dialog_stopping_strings = [
            f"{character}:" for character in conversation_context["other_characters"]
        ]

        if "extra_stopping_strings" in parameters:
            parameters["extra_stopping_strings"] += dialog_stopping_strings
        else:
            parameters["extra_stopping_strings"] = dialog_stopping_strings

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
        """

        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

        try:
            response = await self.client.completions.create(
                prompt=prompt.strip(" "), **parameters
            )
            return response.get("choices", [{}])[0].get("text", "")
        except PermissionDeniedError as e:
            self.log.error("generate error", e=e)
@@ -324,85 +353,97 @@ class ClientBase:
            return ""
        except Exception as e:
            self.log.error("generate error", e=e)
            emit(
                "status", message="Error during generation (check logs)", status="error"
            )
            return ""

    async def send_prompt(
        self,
        prompt: str,
        kind: str = "conversation",
        finalize: Callable = lambda x: x,
        retries: int = 2,
    ) -> str:
        """
        Send a prompt to the AI and return its response.
        :param prompt: The text prompt to send.
        :return: The AI's response text.
        """

        try:
            self.emit_status(processing=True)
            await self.status()

            prompt_param = self.generate_prompt_parameters(kind)

            finalized_prompt = self.prompt_template(
                self.get_system_message(kind), prompt
            ).strip(" ")
            prompt_param = finalize(prompt_param)

            token_length = self.count_tokens(finalized_prompt)

            time_start = time.time()
            extra_stopping_strings = prompt_param.pop("extra_stopping_strings", [])

            self.log.debug(
                "send_prompt",
                token_length=token_length,
                max_token_length=self.max_token_length,
                parameters=prompt_param,
            )

            response = await self.generate(
                self.repetition_adjustment(finalized_prompt), prompt_param, kind
            )

            response, finalized_prompt = await self.auto_break_repetition(
                finalized_prompt, prompt_param, response, kind, retries
            )

            time_end = time.time()

            # stopping strings sometimes get appended to the end of the response anyways
            # split the response by the first stopping string and take the first part

            for stopping_string in STOPPING_STRINGS + extra_stopping_strings:
                if stopping_string in response:
                    response = response.split(stopping_string)[0]
                    break

            emit(
                "prompt_sent",
                data={
                    "kind": kind,
                    "prompt": finalized_prompt,
                    "response": response,
                    "prompt_tokens": token_length,
                    "response_tokens": self.count_tokens(response),
                    "time": time_end - time_start,
                },
            )

            return response
        finally:
            self.emit_status(processing=False)

    async def auto_break_repetition(
        self,
        finalized_prompt: str,
        prompt_param: dict,
        response: str,
        kind: str,
        retries: int,
        pad_max_tokens: int = 32,
    ) -> str:
        """
        If repetition breaking is enabled, this will retry the prompt if its
        response is too similar to other messages in the prompt

        This requires the agent to have the allow_repetition_break method
        and the jiggle_enabled_for method and the client to have the
        auto_break_repetition_enabled attribute set to True

        Arguments:

        - finalized_prompt: the prompt that was sent
@@ -411,47 +452,46 @@ class ClientBase:
        - kind: the kind of generation
        - retries: the number of retries left
        - pad_max_tokens: increase response max_tokens by this amount per iteration

        Returns:

        - the response
        """

        if not self.auto_break_repetition_enabled:
            return response, finalized_prompt

        agent_context = active_agent.get()
        if self.jiggle_enabled_for(kind, auto=True):
            # check if the response is a repetition
            # using a similarity threshold of 80, meaning it needs
            # to be really similar to be considered a repetition

            is_repetition, similarity_score, matched_line = util.similarity_score(
                response, finalized_prompt.split("\n"), similarity_threshold=80
            )

            if not is_repetition:
                # not a repetition, return the response

                self.log.debug(
                    "send_prompt no similarity", similarity_score=similarity_score
                )
                finalized_prompt = self.repetition_adjustment(
                    finalized_prompt, is_repetitive=False
                )
                return response, finalized_prompt

            while is_repetition and retries > 0:
                # it's a repetition, retry the prompt with adjusted parameters

                self.log.warn(
                    "send_prompt similarity retry",
                    agent=agent_context.agent.agent_type,
                    similarity_score=similarity_score,
                    retries=retries,
                )

                # first we apply the client's randomness jiggle which will adjust
                # parameters like temperature and repetition_penalty, depending
                # on the client
@@ -459,90 +499,94 @@ class ClientBase:
                # this is a cumulative adjustment, so it will add to the previous
                # iteration's adjustment, this also means retries should be kept low
                # otherwise it will get out of hand and start generating nonsense

                self.jiggle_randomness(prompt_param, offset=0.5)

                # then we pad the max_tokens by the pad_max_tokens amount

                prompt_param["max_tokens"] += pad_max_tokens

                # send the prompt again
                # we use the repetition_adjustment method to further encourage
                # the AI to break the repetition on its own as well.

                finalized_prompt = self.repetition_adjustment(
                    finalized_prompt, is_repetitive=True
                )

                response = retried_response = await self.generate(
                    finalized_prompt, prompt_param, kind
                )

                self.log.debug(
                    "send_prompt dedupe sentences",
                    response=response,
                    matched_line=matched_line,
                )

                # a lot of the times the response will now contain the repetition + something new
                # so we dedupe the response to remove the repetition on sentences level

                response = util.dedupe_sentences(
                    response, matched_line, similarity_threshold=85, debug=True
                )
                self.log.debug(
                    "send_prompt dedupe sentences (after)", response=response
                )

                # deduping may have removed the entire response, so we check for that

                if not util.strip_partial_sentences(response).strip():
                    # if the response is empty, we set the response to the original
                    # and try again next loop

                    response = retried_response

                # check if the response is a repetition again

                is_repetition, similarity_score, matched_line = util.similarity_score(
                    response, finalized_prompt.split("\n"), similarity_threshold=80
                )
                retries -= 1

        return response, finalized_prompt

    def count_tokens(self, content: str):
        return util.count_tokens(content)

    def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
        """
        adjusts temperature and repetition_penalty
        by random values using the base value as a center
        """

        temp = prompt_config["temperature"]
        min_offset = offset * 0.3
        prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)

    def jiggle_enabled_for(self, kind: str, auto: bool = False) -> bool:
        agent_context = active_agent.get()
        agent = agent_context.agent

        if not agent:
            return False

        return agent.allow_repetition_break(kind, agent_context.action, auto=auto)

    def repetition_adjustment(self, prompt: str, is_repetitive: bool = False):
|
||||
|
||||
def repetition_adjustment(self, prompt: str, is_repetitive: bool = False):
|
||||
"""
|
||||
Breaks the prompt into lines and checkse each line for a match with
|
||||
[$REPETITION|{repetition_adjustment}].
|
||||
|
||||
|
||||
On match and if is_repetitive is True, the line is removed from the prompt and
|
||||
replaced with the repetition_adjustment.
|
||||
|
||||
On match and if is_repetitive is False, the line is removed from the prompt.
|
||||
|
||||
On match and if is_repetitive is False, the line is removed from the prompt.
|
||||
"""
|
||||
|
||||
|
||||
lines = prompt.split("\n")
|
||||
new_lines = []
|
||||
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("[$REPETITION|"):
|
||||
if is_repetitive:
|
||||
@@ -551,5 +595,5 @@ class ClientBase:
|
||||
new_lines.append("")
|
||||
else:
|
||||
new_lines.append(line)
|
||||
|
||||
return "\n".join(new_lines)
|
||||
|
||||
return "\n".join(new_lines)
|
||||
|
||||
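The `[$REPETITION|...]` line filter in `repetition_adjustment` above can be sketched as a standalone function. Note the diff elides the branch body that extracts the adjustment text, so the split-on-`|` parsing below is an assumption, not the confirmed implementation:

```python
def repetition_adjustment(prompt: str, is_repetitive: bool = False) -> str:
    """Sketch of the [$REPETITION|...] line filter.

    If a line carries the marker and the last response was repetitive,
    keep the adjustment text; otherwise drop the line entirely.
    """
    new_lines = []
    for line in prompt.split("\n"):
        if line.startswith("[$REPETITION|"):
            if is_repetitive:
                # assumed parsing: take the text between "|" and the closing "]"
                new_lines.append(line.split("|", 1)[1].rstrip("]"))
            else:
                new_lines.append("")
        else:
            new_lines.append(line)
    return "\n".join(new_lines)
```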
@@ -1,6 +1,7 @@
-import pydantic
 from enum import Enum

+import pydantic
+
 __all__ = [
     "ClientType",
     "ClientBootstrap",
@@ -10,8 +11,10 @@ __all__ = [

 LISTS = {}

+
 class ClientType(str, Enum):
     """Client type enum."""

     textgen = "textgenwebui"
     automatic1111 = "automatic1111"

@@ -20,43 +23,42 @@ class ClientBootstrap(pydantic.BaseModel):
     """Client bootstrap model."""

     # client type, currently supports "textgen" and "automatic1111"
     client_type: ClientType

     # unique client identifier
     uid: str

     # connection name
     name: str

     # connection information for the client
     # REST api url
     api_url: str

     # service name (for example runpod)
     service_name: str


 class register_list:
-    def __init__(self, service_name:str):
+    def __init__(self, service_name: str):
         self.service_name = service_name

     def __call__(self, func):
         LISTS[self.service_name] = func
         return func


 async def list_all(exclude_urls: list[str] = list()):
     """
     Return a list of client bootstrap objects.
     """

     for service_name, func in LISTS.items():
         async for item in func():
             if item.api_url not in exclude_urls:
                 yield item.dict()
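The `register_list` decorator and `list_all` pattern above can be sketched in isolation. The real code yields pydantic `ClientBootstrap` models and calls `item.dict()`; the sketch below substitutes plain dicts, and the `runpod_clients` generator is a hypothetical static example (the real one would query the service):

```python
import asyncio

# registry of service name -> async generator of bootstrap candidates
LISTS = {}


class register_list:
    """Sketch of the register_list decorator: stores an async generator per service."""

    def __init__(self, service_name: str):
        self.service_name = service_name

    def __call__(self, func):
        LISTS[self.service_name] = func
        return func


@register_list("runpod")
async def runpod_clients():
    # hypothetical static entry for illustration only
    yield {"api_url": "http://example:5000", "name": "example"}


async def list_all(exclude_urls=()):
    """Yield every registered candidate, skipping excluded API urls."""
    for service_name, func in LISTS.items():
        async for item in func():
            if item["api_url"] not in exclude_urls:
                yield item
```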
@@ -3,19 +3,20 @@ Context managers for various client-side operations.
 """

 from contextvars import ContextVar
-from pydantic import BaseModel, Field
 from copy import deepcopy

 import structlog
+from pydantic import BaseModel, Field

 __all__ = [
-    'context_data',
-    'client_context_attribute',
-    'ContextModel',
+    "context_data",
+    "client_context_attribute",
+    "ContextModel",
 ]

 log = structlog.get_logger()

+
 def model_to_dict_without_defaults(model_instance):
     model_dict = model_instance.dict()
     for field_name, field in model_instance.__class__.__fields__.items():
@@ -23,20 +24,25 @@ def model_to_dict_without_defaults(model_instance):
         del model_dict[field_name]
     return model_dict

+
 class ConversationContext(BaseModel):
     talking_character: str = None
     other_characters: list[str] = Field(default_factory=list)

+
 class ContextModel(BaseModel):
     """
     Pydantic model for the context data.
     """

     nuke_repetition: float = Field(0.0, ge=0.0, le=3.0)
     conversation: ConversationContext = Field(default_factory=ConversationContext)
     length: int = 96

+
 # Define the context variable as an empty dictionary
-context_data = ContextVar('context_data', default=ContextModel().model_dump())
+context_data = ContextVar("context_data", default=ContextModel().model_dump())


 def client_context_attribute(name, default=None):
     """
@@ -47,6 +53,7 @@ def client_context_attribute(name, default=None):
     # Return the value of the key if it exists, otherwise return the default value
     return data.get(name, default)

+
 def set_client_context_attribute(name, value):
     """
     Set the value of the context variable `context_data` for the given key.
@@ -55,7 +62,8 @@ def set_client_context_attribute(name, value):
     data = context_data.get()
     # Set the value of the key
     data[name] = value

+
 def set_conversation_context_attribute(name, value):
     """
     Set the value of the context variable `context_data.conversation` for the given key.
@@ -65,6 +73,7 @@ def set_conversation_context_attribute(name, value):
     # Set the value of the key
     data["conversation"][name] = value

+
 class ClientContext:
     """
     A context manager to set values to the context variable `context_data`.
@@ -82,10 +91,10 @@ class ClientContext:
         Set the key-value pairs to the context variable `context_data` when entering the context.
         """
         # Get the current context data
         data = deepcopy(context_data.get()) if context_data.get() else {}
         data.update(self.values)

         # Update the context data
         self.token = context_data.set(data)

@@ -93,5 +102,5 @@ class ClientContext:
         """
         Reset the context variable `context_data` to its previous values when exiting the context.
         """

         context_data.reset(self.token)
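The `ClientContext` manager above follows the standard `ContextVar` set/reset token pattern; a minimal self-contained sketch (plain dict default instead of the pydantic model dump used in the real code):

```python
from contextvars import ContextVar
from copy import deepcopy

# simplified default; the real code seeds this with ContextModel().model_dump()
context_data: ContextVar = ContextVar("context_data", default={})


class ClientContext:
    """Sketch of the ClientContext manager: overlay values on enter, restore on exit."""

    def __init__(self, **values):
        self.values = values

    def __enter__(self):
        # copy so the overlay never mutates the outer context's dict
        data = deepcopy(context_data.get()) if context_data.get() else {}
        data.update(self.values)
        self.token = context_data.set(data)

    def __exit__(self, *exc):
        # reset() restores whatever value was current before set()
        context_data.reset(self.token)
```

The token returned by `ContextVar.set()` is what makes nesting safe: each exit restores exactly the value that was active at the matching enter.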
@@ -1,16 +1,16 @@
 import asyncio
-import random
 import json
 import logging
+import random
 from abc import ABC, abstractmethod
 from typing import Callable, Union

 import requests

+import talemate.client.system_prompts as system_prompts
 import talemate.util as util
 from talemate.client.registry import register
-import talemate.client.system_prompts as system_prompts
 from talemate.client.textgenwebui import RESTTaleMateClient
 from talemate.emit import Emission, emit

 # NOT IMPLEMENTED AT THIS POINT
@@ -1,65 +1,62 @@
 import pydantic
+from openai import AsyncOpenAI

 from talemate.client.base import ClientBase
 from talemate.client.registry import register

-from openai import AsyncOpenAI
-

 class Defaults(pydantic.BaseModel):
-    api_url:str = "http://localhost:1234"
+    api_url: str = "http://localhost:1234"
     max_token_length: int = 4096

+
 @register()
 class LMStudioClient(ClientBase):
     client_type = "lmstudio"
     conversation_retries = 5

     class Meta(ClientBase.Meta):
-        name_prefix:str = "LMStudio"
-        title:str = "LMStudio"
-        defaults:Defaults = Defaults()
+        name_prefix: str = "LMStudio"
+        title: str = "LMStudio"
+        defaults: Defaults = Defaults()

     def set_client(self, **kwargs):
-        self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key="sk-1111")
+        self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")

-    def tune_prompt_parameters(self, parameters:dict, kind:str):
+    def tune_prompt_parameters(self, parameters: dict, kind: str):
         super().tune_prompt_parameters(parameters, kind)

         keys = list(parameters.keys())

         valid_keys = ["temperature", "top_p"]

         for key in keys:
             if key not in valid_keys:
                 del parameters[key]

     async def get_model_name(self):
         model_name = await super().get_model_name()

         # model name comes back as a file path, so we need to extract the model name
         # the path could be windows or linux so it needs to handle both backslash and forward slash
         if model_name:
             model_name = model_name.replace("\\", "/").split("/")[-1]

         return model_name

-    async def generate(self, prompt:str, parameters:dict, kind:str):
+    async def generate(self, prompt: str, parameters: dict, kind: str):
         """
         Generates text from the given prompt and parameters.
         """
-        human_message = {'role': 'user', 'content': prompt.strip()}
-
-        self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
+        human_message = {"role": "user", "content": prompt.strip()}
+
+        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

         try:
             response = await self.client.chat.completions.create(
                 model=self.model_name, messages=[human_message], **parameters
             )

             return response.choices[0].message.content
         except Exception as e:
             self.log.error("generate error", e=e)
             return ""
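The path handling in `get_model_name` above is worth isolating: LMStudio reports the loaded model as a file path, which may use Windows backslashes or POSIX forward slashes, so the client normalizes and keeps only the final component. A minimal sketch:

```python
def model_name_from_path(path):
    """Sketch of LMStudioClient.get_model_name's normalization:
    normalize backslashes to forward slashes, then keep the basename."""
    if path:
        return path.replace("\\", "/").split("/")[-1]
    return path
```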
@@ -1,17 +1,23 @@
-from jinja2 import Environment, FileSystemLoader
 import os
-import structlog
 import shutil
-import huggingface_hub
 import tempfile

+import huggingface_hub
+import structlog
+from jinja2 import Environment, FileSystemLoader

 __all__ = ["model_prompt"]

 BASE_TEMPLATE_PATH = os.path.join(
-    os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "templates", "llm-prompt"
+    os.path.dirname(os.path.abspath(__file__)),
+    "..",
+    "..",
+    "..",
+    "templates",
+    "llm-prompt",
 )

 # holds the default templates
 STD_TEMPLATE_PATH = os.path.join(BASE_TEMPLATE_PATH, "std")

 # llm prompt templates provided by talemate
@@ -22,62 +28,73 @@ USER_TEMPLATE_PATH = os.path.join(BASE_TEMPLATE_PATH, "user")

 TEMPLATE_IDENTIFIERS = []

+
 def register_template_identifier(cls):
     TEMPLATE_IDENTIFIERS.append(cls)
     return cls

+
 log = structlog.get_logger("talemate.model_prompts")

+
 class ModelPrompt:
     """
     Will attempt to load an LLM prompt template based on the model name

     If the model name is not found, it will default to the 'default' template
     """

     template_map = {}

     @property
     def env(self):
         if not hasattr(self, "_env"):
             log.info("modal prompt", base_template_path=BASE_TEMPLATE_PATH)
-            self._env = Environment(loader=FileSystemLoader([
-                USER_TEMPLATE_PATH,
-                TALEMATE_TEMPLATE_PATH,
-            ]))
+            self._env = Environment(
+                loader=FileSystemLoader(
+                    [
+                        USER_TEMPLATE_PATH,
+                        TALEMATE_TEMPLATE_PATH,
+                    ]
+                )
+            )

         return self._env

     @property
     def std_templates(self) -> list[str]:
         env = Environment(loader=FileSystemLoader(STD_TEMPLATE_PATH))
         return sorted(env.list_templates())

-    def __call__(self, model_name:str, system_message:str, prompt:str):
+    def __call__(self, model_name: str, system_message: str, prompt: str):
         template, template_file = self.get_template(model_name)
         if not template:
             template_file = "default.jinja2"
             template = self.env.get_template(template_file)

         if "<|BOT|>" in prompt:
             user_message, coercion_message = prompt.split("<|BOT|>", 1)
         else:
             user_message = prompt
             coercion_message = ""

-        return template.render({
-            "system_message": system_message,
-            "prompt": prompt,
-            "user_message": user_message,
-            "coercion_message": coercion_message,
-            "set_response" : self.set_response
-        }), template_file
-
-    def set_response(self, prompt:str, response_str:str):
+        return (
+            template.render(
+                {
+                    "system_message": system_message,
+                    "prompt": prompt,
+                    "user_message": user_message,
+                    "coercion_message": coercion_message,
+                    "set_response": self.set_response,
+                }
+            ),
+            template_file,
+        )
+
+    def set_response(self, prompt: str, response_str: str):
         prompt = prompt.strip("\n").strip()

         if "<|BOT|>" in prompt:
             if "\n<|BOT|>" in prompt:
                 prompt = prompt.replace("\n<|BOT|>", response_str)
@@ -85,17 +102,17 @@ class ModelPrompt:
                 prompt = prompt.replace("<|BOT|>", response_str)
         else:
             prompt = prompt.rstrip("\n") + response_str

         return prompt

-    def get_template(self, model_name:str):
+    def get_template(self, model_name: str):
         """
         Will attempt to load an LLM prompt template - this supports
         partial filename matching on the template file name.
         """

         matches = []

         # Iterate over all templates in the loader's directory
         for template_name in self.env.list_templates():
             # strip extension
@@ -103,56 +120,58 @@ class ModelPrompt:
             # Check if the model name is in the template filename
             if template_name_match.lower() in model_name.lower():
                 matches.append(template_name)

         # If there are no matches, return None
         if not matches:
             return None, None

         # If there is only one match, return it
         if len(matches) == 1:
             return self.env.get_template(matches[0]), matches[0]

         # If there are multiple matches, return the one with the longest name
         sorted_matches = sorted(matches, key=lambda x: len(x), reverse=True)
         return self.env.get_template(sorted_matches[0]), sorted_matches[0]

-    def create_user_override(self, template_name:str, model_name:str):
+    def create_user_override(self, template_name: str, model_name: str):
         """
         Will copy STD_TEMPLATE_PATH/template_name to USER_TEMPLATE_PATH/model_name.jinja2
         """

         template_name = template_name.split(".jinja2")[0]

         shutil.copyfile(
             os.path.join(STD_TEMPLATE_PATH, template_name + ".jinja2"),
-            os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2")
+            os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2"),
         )

         return os.path.join(USER_TEMPLATE_PATH, model_name + ".jinja2")

-    def query_hf_for_prompt_template_suggestion(self, model_name:str):
+    def query_hf_for_prompt_template_suggestion(self, model_name: str):
         print("query_hf_for_prompt_template_suggestion", model_name)
         api = huggingface_hub.HfApi()

         try:
             author, model_name = model_name.split("_", 1)
         except ValueError:
             return None

-        models = list(api.list_models(
-            filter=huggingface_hub.ModelFilter(model_name=model_name, author=author)
-        ))
+        models = list(
+            api.list_models(
+                filter=huggingface_hub.ModelFilter(model_name=model_name, author=author)
+            )
+        )

         if not models:
             return None

         model = models[0]

         repo_id = f"{author}/{model_name}"
         with tempfile.TemporaryDirectory() as tmpdir:
-            readme_path = huggingface_hub.hf_hub_download(repo_id=repo_id, filename="README.md", cache_dir=tmpdir)
+            readme_path = huggingface_hub.hf_hub_download(
+                repo_id=repo_id, filename="README.md", cache_dir=tmpdir
+            )
             if not readme_path:
                 return None
             with open(readme_path) as f:
@@ -163,24 +182,27 @@ class ModelPrompt:
                 return f"{identifier.template_str}.jinja2"


 model_prompt = ModelPrompt()

+
 class TemplateIdentifier:
-    def __call__(self, content:str):
+    def __call__(self, content: str):
         return False

+
 @register_template_identifier
 class Llama2Identifier(TemplateIdentifier):
     template_str = "Llama2"
-    def __call__(self, content:str):
+
+    def __call__(self, content: str):
         return "[INST]" in content and "[/INST]" in content

+
 @register_template_identifier
 class ChatMLIdentifier(TemplateIdentifier):
     template_str = "ChatML"
-    def __call__(self, content:str):
+
+    def __call__(self, content: str):
         """
         <|im_start|>system
         {{ system_message }}<|im_end|>
@@ -189,7 +211,7 @@ class ChatMLIdentifier(TemplateIdentifier):
         <|im_start|>assistant
         {{ coercion_message }}
         """

         return (
             "<|im_start|>system" in content
             and "<|im_end|>" in content
@@ -197,20 +219,24 @@ class ChatMLIdentifier(TemplateIdentifier):
             and "<|im_start|>assistant" in content
         )

+
 @register_template_identifier
 class InstructionInputResponseIdentifier(TemplateIdentifier):
     template_str = "InstructionInputResponse"
-    def __call__(self, content:str):
+
+    def __call__(self, content: str):
         return (
             "### Instruction:" in content
             and "### Input:" in content
             and "### Response:" in content
         )

+
 @register_template_identifier
 class AlpacaIdentifier(TemplateIdentifier):
     template_str = "Alpaca"
-    def __call__(self, content:str):
+
+    def __call__(self, content: str):
         """
         {{ system_message }}

@@ -220,20 +246,19 @@ class AlpacaIdentifier(TemplateIdentifier):
         ### Response:
         {{ coercion_message }}
         """
-        return (
-            "### Instruction:" in content
-            and "### Response:" in content
-        )
+        return "### Instruction:" in content and "### Response:" in content

+
 @register_template_identifier
 class OpenChatIdentifier(TemplateIdentifier):
     template_str = "OpenChat"
-    def __call__(self, content:str):
+
+    def __call__(self, content: str):
         """
         GPT4 Correct System: {{ system_message }}<|end_of_turn|>GPT4 Correct User: {{ user_message }}<|end_of_turn|>GPT4 Correct Assistant: {{ coercion_message }}
         """

         return (
             "<|end_of_turn|>" in content
             and "GPT4 Correct System:" in content
@@ -241,54 +266,51 @@ class OpenChatIdentifier(TemplateIdentifier):
             and "GPT4 Correct Assistant:" in content
         )

+
 @register_template_identifier
 class VicunaIdentifier(TemplateIdentifier):
     template_str = "Vicuna"
-    def __call__(self, content:str):
+
+    def __call__(self, content: str):
         """
         SYSTEM: {{ system_message }}
         USER: {{ user_message }}
         ASSISTANT: {{ coercion_message }}
         """
-        return (
-            "SYSTEM:" in content
-            and "USER:" in content
-            and "ASSISTANT:" in content
-        )
+        return "SYSTEM:" in content and "USER:" in content and "ASSISTANT:" in content

+
 @register_template_identifier
 class USER_ASSISTANTIdentifier(TemplateIdentifier):
     template_str = "USER_ASSISTANT"
-    def __call__(self, content:str):
+
+    def __call__(self, content: str):
         """
         USER: {{ system_message }} {{ user_message }} ASSISTANT: {{ coercion_message }}
         """
-        return (
-            "USER:" in content
-            and "ASSISTANT:" in content
-        )
+        return "USER:" in content and "ASSISTANT:" in content

+
 @register_template_identifier
 class UserAssistantIdentifier(TemplateIdentifier):
     template_str = "UserAssistant"
-    def __call__(self, content:str):
+
+    def __call__(self, content: str):
         """
         User: {{ system_message }} {{ user_message }}
         Assistant: {{ coercion_message }}
         """
-        return (
-            "User:" in content
-            and "Assistant:" in content
-        )
+        return "User:" in content and "Assistant:" in content

+
 @register_template_identifier
 class ZephyrIdentifier(TemplateIdentifier):
     template_str = "Zephyr"
-    def __call__(self, content:str):
+
+    def __call__(self, content: str):
         """
         <|system|>
         {{ system_message }}</s>
@@ -297,9 +319,9 @@ class ZephyrIdentifier(TemplateIdentifier):
         <|assistant|>
         {{ coercion_message }}
         """

         return (
             "<|system|>" in content
             and "<|user|>" in content
             and "<|assistant|>" in content
         )
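The partial-match rule in `ModelPrompt.get_template` above (a template applies when its filename stem appears inside the model name; ties go to the longest, i.e. most specific, stem) can be sketched standalone. The template filenames below are hypothetical examples:

```python
def pick_template(model_name, templates):
    """Sketch of the get_template matching rule:
    keep templates whose stem occurs in the model name (case-insensitive),
    then prefer the longest matching template name."""
    matches = [
        t for t in templates
        if t.rsplit(".jinja2", 1)[0].lower() in model_name.lower()
    ]
    if not matches:
        return None
    return max(matches, key=len)
```

This is why a specific fine-tune template wins over a generic base-model template when both stems occur in the model name.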
@@ -1,21 +1,23 @@
 import json

+import pydantic
+import structlog
+import tiktoken
 from openai import AsyncOpenAI, PermissionDeniedError

 from talemate.client.base import ClientBase, ErrorAction
 from talemate.client.registry import register
+from talemate.config import load_config
 from talemate.emit import emit
 from talemate.emit.signals import handlers
-from talemate.config import load_config
-import structlog
-import tiktoken

 __all__ = [
     "OpenAIClient",
 ]
 log = structlog.get_logger("talemate")

-def num_tokens_from_messages(messages:list[dict], model:str="gpt-3.5-turbo-0613"):
+
+def num_tokens_from_messages(messages: list[dict], model: str = "gpt-3.5-turbo-0613"):
     """Return the number of tokens used by a list of messages."""
     try:
         encoding = tiktoken.encoding_for_model(model)
@@ -66,9 +68,11 @@ def num_tokens_from_messages(messages:list[dict], model:str="gpt-3.5-turbo-0613"
     num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
     return num_tokens

+
 class Defaults(pydantic.BaseModel):
-    max_token_length:int = 16384
-    model:str = "gpt-4-turbo-preview"
+    max_token_length: int = 16384
+    model: str = "gpt-4-turbo-preview"

+
 @register()
 class OpenAIClient(ClientBase):
@@ -79,13 +83,15 @@ class OpenAIClient(ClientBase):
     client_type = "openai"
     conversation_retries = 0
     auto_break_repetition_enabled = False
+
     # TODO: make this configurable?
     decensor_enabled = False

     class Meta(ClientBase.Meta):
-        name_prefix:str = "OpenAI"
-        title:str = "OpenAI"
-        manual_model:bool = True
-        manual_model_choices:list[str] = [
+        name_prefix: str = "OpenAI"
+        title: str = "OpenAI"
+        manual_model: bool = True
+        manual_model_choices: list[str] = [
             "gpt-3.5-turbo",
             "gpt-3.5-turbo-16k",
             "gpt-4",
             "gpt-4-1106-preview",
@@ -93,21 +99,19 @@ class OpenAIClient(ClientBase):
             "gpt-4-turbo-preview",
         ]
         requires_prompt_template: bool = False
-        defaults:Defaults = Defaults()
+        defaults: Defaults = Defaults()

     def __init__(self, model="gpt-4-turbo-preview", **kwargs):
         self.model_name = model
         self.api_key_status = None
         self.config = load_config()
         super().__init__(**kwargs)

         handlers["config_saved"].connect(self.on_config_saved)

     @property
     def openai_api_key(self):
-        return self.config.get("openai",{}).get("api_key")
+        return self.config.get("openai", {}).get("api_key")

     def emit_status(self, processing: bool = None):
         error_action = None
@@ -127,13 +131,13 @@ class OpenAIClient(ClientBase):
             arguments=[
                 "application",
                 "openai_api",
-            ]
+            ],
         )

         if not self.model_name:
             status = "error"
             model_name = "No model loaded"

         self.current_status = status

         emit(
@@ -145,25 +149,24 @@ class OpenAIClient(ClientBase):
             data={
                 "error_action": error_action.model_dump() if error_action else None,
                 "meta": self.Meta().model_dump(),
-            }
+            },
         )

-    def set_client(self, max_token_length:int=None):
+    def set_client(self, max_token_length: int = None):
         if not self.openai_api_key:
             self.client = AsyncOpenAI(api_key="sk-1111")
             log.error("No OpenAI API key set")
             if self.api_key_status:
                 self.api_key_status = False
-                emit('request_client_status')
-                emit('request_agent_status')
+                emit("request_client_status")
+                emit("request_agent_status")
             return

         if not self.model_name:
             self.model_name = "gpt-3.5-turbo-16k"

         model = self.model_name

         self.client = AsyncOpenAI(api_key=self.openai_api_key)
         if model == "gpt-3.5-turbo":
             self.max_token_length = min(max_token_length or 4096, 4096)
@@ -175,16 +178,20 @@ class OpenAIClient(ClientBase):
             self.max_token_length = min(max_token_length or 128000, 128000)
         else:
             self.max_token_length = max_token_length or 2048

-        if not self.api_key_status:
-            emit('request_client_status')
-            emit('request_agent_status')
+        if self.api_key_status is False:
+            emit("request_client_status")
+            emit("request_agent_status")
             self.api_key_status = True

-        log.info("openai set client", max_token_length=self.max_token_length, provided_max_token_length=max_token_length, model=model)
+        log.info(
+            "openai set client",
+            max_token_length=self.max_token_length,
+            provided_max_token_length=max_token_length,
+            model=model,
+        )

     def reconfigure(self, **kwargs):
         if kwargs.get("model"):
             self.model_name = kwargs["model"]
@@ -203,39 +210,37 @@ class OpenAIClient(ClientBase):
     async def status(self):
         self.emit_status()

-    def prompt_template(self, system_message:str, prompt:str):
+    def prompt_template(self, system_message: str, prompt: str):
         # only gpt-4-1106-preview supports json_object response coersion

         if "<|BOT|>" in prompt:
             _, right = prompt.split("<|BOT|>", 1)
             if right:
                 prompt = prompt.replace("<|BOT|>", "\nContinue this response: ")
             else:
                 prompt = prompt.replace("<|BOT|>", "")

         return prompt

-    def tune_prompt_parameters(self, parameters:dict, kind:str):
+    def tune_prompt_parameters(self, parameters: dict, kind: str):
         super().tune_prompt_parameters(parameters, kind)

         keys = list(parameters.keys())

         valid_keys = ["temperature", "top_p"]

         for key in keys:
             if key not in valid_keys:
                 del parameters[key]

-    async def generate(self, prompt:str, parameters:dict, kind:str):
+    async def generate(self, prompt: str, parameters: dict, kind: str):
         """
         Generates text from the given prompt and parameters.
         """

         if not self.openai_api_key:
             raise Exception("No OpenAI API key set")

         # only gpt-4-* supports enforcing json object
         supports_json_object = self.model_name.startswith("gpt-4-")
         right = None
@@ -246,26 +251,28 @@ class OpenAIClient(ClientBase):
                 parameters["response_format"] = {"type": "json_object"}
         except (IndexError, ValueError):
             pass

-        human_message = {'role': 'user', 'content': prompt.strip()}
-        system_message = {'role': 'system', 'content': self.get_system_message(kind)}
-
-        self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
+        human_message = {"role": "user", "content": prompt.strip()}
+        system_message = {"role": "system", "content": self.get_system_message(kind)}
+
+        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters, system_message=system_message)

         try:
             response = await self.client.chat.completions.create(
-                model=self.model_name, messages=[system_message, human_message], **parameters
+                model=self.model_name,
+                messages=[system_message, human_message],
+                **parameters,
             )

             response = response.choices[0].message.content

             if right and response.startswith(right):
-                response = response[len(right):].strip()
+                response = response[len(right) :].strip()

             return response
         except PermissionDeniedError as e:
             self.log.error("generate error", e=e)
             emit("status", message="OpenAI API: Permission Denied", status="error")
             return ""
         except Exception as e:
             raise
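The coercion handling in `OpenAIClient.generate` above (chat models sometimes echo the text that followed the `<|BOT|>` marker back at the start of their reply, so it is stripped) can be sketched as a standalone helper:

```python
def strip_coercion_echo(response, right):
    """Sketch of the post-processing in OpenAIClient.generate:
    if the coercion text ("right", the part after <|BOT|>) is echoed
    at the start of the response, drop it and trim whitespace."""
    if right and response.startswith(right):
        return response[len(right):].strip()
    return response
```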
@@ -1,32 +1,33 @@
import pydantic
from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError

from talemate.client.base import ClientBase
from talemate.client.registry import register

from openai import AsyncOpenAI, PermissionDeniedError, NotFoundError
from talemate.emit import emit

EXPERIMENTAL_DESCRIPTION = """Use this client if you want to connect to a service implementing an OpenAI-compatible API. Success is going to depend on the level of compatibility. Use the actual OpenAI client if you want to connect to OpenAI's API."""

class Defaults(pydantic.BaseModel):
api_url:str = "http://localhost:5000"
api_key:str = ""
max_token_length:int = 4096
model:str = ""
api_url: str = "http://localhost:5000"
api_key: str = ""
max_token_length: int = 4096
model: str = ""

@register()
class OpenAICompatibleClient(ClientBase):

client_type = "openai_compat"
conversation_retries = 5

class Meta(ClientBase.Meta):
title:str = "OpenAI Compatible API"
name_prefix:str = "OpenAI Compatible API"
experimental:str = EXPERIMENTAL_DESCRIPTION
enable_api_auth:bool = True
manual_model:bool = True
defaults:Defaults = Defaults()

title: str = "OpenAI Compatible API"
name_prefix: str = "OpenAI Compatible API"
experimental: str = EXPERIMENTAL_DESCRIPTION
enable_api_auth: bool = True
manual_model: bool = True
defaults: Defaults = Defaults()

def __init__(self, model=None, **kwargs):
self.model_name = model
super().__init__(**kwargs)
@@ -37,23 +38,23 @@ class OpenAICompatibleClient(ClientBase):

def set_client(self, **kwargs):
self.api_key = kwargs.get("api_key")
self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key=self.api_key)
self.model_name = kwargs.get("model") or kwargs.get("model_name") or self.model_name

def tune_prompt_parameters(self, parameters:dict, kind:str):
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key=self.api_key)
self.model_name = (
kwargs.get("model") or kwargs.get("model_name") or self.model_name
)

def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)

keys = list(parameters.keys())

valid_keys = ["temperature", "top_p"]

for key in keys:
if key not in valid_keys:
del parameters[key]

async def get_model_name(self):

try:
model_name = await super().get_model_name()
except NotFoundError as e:
@@ -65,29 +66,28 @@ class OpenAICompatibleClient(ClientBase):

# model name may be a file path, so we need to extract the model name
# the path could be windows or linux so it needs to handle both backslash and forward slash

is_filepath = "/" in model_name
is_filepath_windows = "\\" in model_name

if is_filepath or is_filepath_windows:
model_name = model_name.replace("\\", "/").split("/")[-1]

return model_name

async def generate(self, prompt:str, parameters:dict, kind:str):

async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
human_message = {'role': 'user', 'content': prompt.strip()}

self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)

human_message = {"role": "user", "content": prompt.strip()}

self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

try:
response = await self.client.chat.completions.create(
model=self.model_name, messages=[human_message], **parameters
)

return response.choices[0].message.content
except PermissionDeniedError as e:
self.log.error("generate error", e=e)
@@ -95,7 +95,9 @@ class OpenAICompatibleClient(ClientBase):
return ""
except Exception as e:
self.log.error("generate error", e=e)
emit("status", message="Error during generation (check logs)", status="error")
emit(
"status", message="Error during generation (check logs)", status="error"
)
return ""

def reconfigure(self, **kwargs):
@@ -107,5 +109,5 @@ class OpenAICompatibleClient(ClientBase):
self.max_token_length = kwargs["max_token_length"]
if "api_key" in kwargs:
self.api_auth = kwargs["api_key"]

self.set_client(**kwargs)

self.set_client(**kwargs)
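The `get_model_name` hunk above reduces a file-path model name to its basename, handling both Windows and POSIX separators. A small stdlib-only sketch of just that step (the helper name is hypothetical; the logic mirrors the diff):

```python
def basename_model_name(model_name: str) -> str:
    """Sketch of the path handling in get_model_name above: the reported
    model name may be a Windows or POSIX file path, so normalize
    backslashes to forward slashes and keep only the final component."""
    is_filepath = "/" in model_name
    is_filepath_windows = "\\" in model_name
    if is_filepath or is_filepath_windows:
        return model_name.replace("\\", "/").split("/")[-1]
    return model_name

print(basename_model_name("models\\TheBloke\\mistral-7b.gguf"))  # mistral-7b.gguf
print(basename_model_name("/opt/models/mistral-7b.gguf"))  # mistral-7b.gguf
print(basename_model_name("gpt-4"))  # gpt-4
```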
@@ -28,18 +28,18 @@ PRESET_TALEMATE_CREATOR = {
}

PRESET_LLAMA_PRECISE = {
'temperature': 0.7,
'top_p': 0.1,
'top_k': 40,
'repetition_penalty': 1.18,
"temperature": 0.7,
"top_p": 0.1,
"top_k": 40,
"repetition_penalty": 1.18,
}

PRESET_DIVINE_INTELLECT = {
'temperature': 1.31,
'top_p': 0.14,
'top_k': 49,
"temperature": 1.31,
"top_p": 0.14,
"top_k": 49,
"repetition_penalty_range": 1024,
'repetition_penalty': 1.17,
"repetition_penalty": 1.17,
}

PRESET_SIMPLE_1 = {
@@ -49,7 +49,8 @@ PRESET_SIMPLE_1 = {
"repetition_penalty": 1.15,
}

def configure(config:dict, kind:str, total_budget:int):

def configure(config: dict, kind: str, total_budget: int):
"""
Sets the config based on the kind of text to generate.
"""
@@ -57,19 +58,22 @@ def configure(config:dict, kind:str, total_budget:int):
set_max_tokens(config, kind, total_budget)
return config

def set_max_tokens(config:dict, kind:str, total_budget:int):

def set_max_tokens(config: dict, kind: str, total_budget: int):
"""
Sets the max_tokens in the config based on the kind of text to generate.
"""
config["max_tokens"] = max_tokens_for_kind(kind, total_budget)
return config

def set_preset(config:dict, kind:str):

def set_preset(config: dict, kind: str):
"""
Sets the preset in the config based on the kind of text to generate.
"""
config.update(preset_for_kind(kind))

def preset_for_kind(kind: str):
if kind == "conversation":
return PRESET_TALEMATE_CONVERSATION
@@ -104,9 +108,13 @@ def preset_for_kind(kind: str):
elif kind == "director":
return PRESET_SIMPLE_1
elif kind == "director_short":
return PRESET_SIMPLE_1  # Assuming short direction uses the same preset as simple
return (
PRESET_SIMPLE_1  # Assuming short direction uses the same preset as simple
)
elif kind == "director_yesno":
return PRESET_SIMPLE_1  # Assuming yes/no direction uses the same preset as simple
return (
PRESET_SIMPLE_1  # Assuming yes/no direction uses the same preset as simple
)
elif kind == "edit_dialogue":
return PRESET_DIVINE_INTELLECT
elif kind == "edit_add_detail":
@@ -116,6 +124,7 @@ def preset_for_kind(kind: str):
else:
return PRESET_SIMPLE_1  # Default preset if none of the kinds match

def max_tokens_for_kind(kind: str, total_budget: int):
if kind == "conversation":
return 75  # Example value, adjust as needed
@@ -142,15 +151,23 @@ def max_tokens_for_kind(kind: str, total_budget: int):
elif kind == "story":
return 300  # Example value, adjust as needed
elif kind == "create":
return min(1024, int(total_budget * 0.35))  # Example calculation, adjust as needed
return min(
1024, int(total_budget * 0.35)
)  # Example calculation, adjust as needed
elif kind == "create_concise":
return min(400, int(total_budget * 0.25))  # Example calculation, adjust as needed
return min(
400, int(total_budget * 0.25)
)  # Example calculation, adjust as needed
elif kind == "create_precise":
return min(400, int(total_budget * 0.25))  # Example calculation, adjust as needed
return min(
400, int(total_budget * 0.25)
)  # Example calculation, adjust as needed
elif kind == "create_short":
return 25
elif kind == "director":
return min(192, int(total_budget * 0.25))  # Example calculation, adjust as needed
return min(
192, int(total_budget * 0.25)
)  # Example calculation, adjust as needed
elif kind == "director_short":
return 25  # Example value, adjust as needed
elif kind == "director_yesno":
@@ -162,4 +179,4 @@ def max_tokens_for_kind(kind: str, total_budget: int):
elif kind == "edit_fix_exposition":
return 1024  # Example value, adjust as needed
else:
return 150  # Default value if none of the kinds match
return 150  # Default value if none of the kinds match
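The preset file above maps a generation `kind` to either a fixed token allowance or a cap clamped against the total budget. A minimal stdlib-only sketch of that dispatch pattern (the table here covers only a few kinds from the diff and is illustrative, not the full mapping):

```python
# Sketch of the kind -> token-budget dispatch used in max_tokens_for_kind above.
# Only a subset of kinds is shown; values mirror the diff for those kinds.

FIXED = {"conversation": 75, "create_short": 25, "director_short": 25}
SCALED = {"create": (1024, 0.35), "create_concise": (400, 0.25), "director": (192, 0.25)}

def max_tokens_for_kind(kind: str, total_budget: int) -> int:
    if kind in FIXED:
        return FIXED[kind]
    if kind in SCALED:
        cap, ratio = SCALED[kind]
        # never exceed the cap, and never exceed the share of the total budget
        return min(cap, int(total_budget * ratio))
    return 150  # default when no kind matches

print(max_tokens_for_kind("create", 8192))  # min(1024, 2867) -> 1024
print(max_tokens_for_kind("director", 512))  # min(192, 128) -> 128
```

Keeping the mapping in data rather than an `elif` chain makes the per-kind values easier to audit, which is one way the long chain in the diff could be condensed.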
@@ -3,17 +3,17 @@ Retrieve pod information from the server which can then be used to bootstrap tal
connection for the pod. This is a simple wrapper around the runpod module.
"""

import asyncio
import json
import os

import dotenv
import runpod
import os
import json
import asyncio

from .bootstrap import ClientBootstrap, ClientType, register_list
import structlog

from talemate.config import load_config

import structlog
from .bootstrap import ClientBootstrap, ClientType, register_list

log = structlog.get_logger("talemate.client.runpod")

@@ -21,73 +21,75 @@ dotenv.load_dotenv()

runpod.api_key = load_config().get("runpod", {}).get("api_key", "")

def is_textgen_pod(pod):

name = pod["name"].lower()

if "textgen" in name or "thebloke llms" in name:
return True

return False

async def _async_get_pods():
"""
asyncio wrapper around get_pods.
"""

loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, runpod.get_pods)

async def get_textgen_pods():
"""
Return a list of text generation pods.
"""

if not runpod.api_key:
return

for pod in await _async_get_pods():
if not pod["desiredStatus"] == "RUNNING":
continue
if is_textgen_pod(pod):
yield pod

async def get_automatic1111_pods():
"""
Return a list of automatic1111 pods.
"""

if not runpod.api_key:
return

for pod in await _async_get_pods():
if not pod["desiredStatus"] == "RUNNING":
continue
if "automatic1111" in pod["name"].lower():
yield pod

def _client_bootstrap(client_type: ClientType, pod):
"""
Return a client bootstrap object for the given client type and pod.
"""

id = pod["id"]

if client_type == ClientType.textgen:
api_url = f"https://{id}-5000.proxy.runpod.net"
elif client_type == ClientType.automatic1111:
api_url = f"https://{id}-5000.proxy.runpod.net"

return ClientBootstrap(
client_type=client_type,
uid=pod["id"],
name=pod["name"],
api_url=api_url,
service_name="runpod"
service_name="runpod",
)

@register_list("runpod")
async def client_bootstrap_list():
@@ -97,13 +99,13 @@ async def client_bootstrap_list():
textgen_pods = []
async for pod in get_textgen_pods():
textgen_pods.append(pod)

automatic1111_pods = []
async for pod in get_automatic1111_pods():
automatic1111_pods.append(pod)

for pod in textgen_pods:
yield _client_bootstrap(ClientType.textgen, pod)

for pod in automatic1111_pods:
yield _client_bootstrap(ClientType.automatic1111, pod)
yield _client_bootstrap(ClientType.automatic1111, pod)
@@ -18,4 +18,24 @@ EDITOR = str(Prompt.get("editor.system"))

WORLD_STATE = str(Prompt.get("world_state.system-analyst"))

SUMMARIZE = str(Prompt.get("summarizer.system"))
SUMMARIZE = str(Prompt.get("summarizer.system"))

# CAREBEAR PROMPTS

ROLEPLAY_NO_DECENSOR = str(Prompt.get("conversation.system-no-decensor"))

NARRATOR_NO_DECENSOR = str(Prompt.get("narrator.system-no-decensor"))

CREATOR_NO_DECENSOR = str(Prompt.get("creator.system-no-decensor"))

DIRECTOR_NO_DECENSOR = str(Prompt.get("director.system-no-decensor"))

ANALYST_NO_DECENSOR = str(Prompt.get("world_state.system-analyst-no-decensor"))

ANALYST_FREEFORM_NO_DECENSOR = str(Prompt.get("world_state.system-analyst-freeform-no-decensor"))

EDITOR_NO_DECENSOR = str(Prompt.get("editor.system-no-decensor"))

WORLD_STATE_NO_DECENSOR = str(Prompt.get("world_state.system-analyst-no-decensor"))

SUMMARIZE_NO_DECENSOR = str(Prompt.get("summarizer.system-no-decensor"))
@@ -1,77 +1,94 @@
from talemate.client.base import ClientBase, STOPPING_STRINGS
from talemate.client.registry import register
from openai import AsyncOpenAI
import httpx
import random

import httpx
import structlog
from openai import AsyncOpenAI

from talemate.client.base import STOPPING_STRINGS, ClientBase
from talemate.client.registry import register

log = structlog.get_logger("talemate.client.textgenwebui")

@register()
class TextGeneratorWebuiClient(ClientBase):

client_type = "textgenwebui"

class Meta(ClientBase.Meta):
name_prefix:str = "TextGenWebUI"
title:str = "Text-Generation-WebUI (ooba)"

def tune_prompt_parameters(self, parameters:dict, kind:str):
name_prefix: str = "TextGenWebUI"
title: str = "Text-Generation-WebUI (ooba)"

def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
parameters["stopping_strings"] = STOPPING_STRINGS + parameters.get("extra_stopping_strings", [])
parameters["stopping_strings"] = STOPPING_STRINGS + parameters.get(
"extra_stopping_strings", []
)
# is this needed?
parameters["max_new_tokens"] = parameters["max_tokens"]
parameters["stop"] = parameters["stopping_strings"]

# Half temperature on -Yi- models
if self.model_name and "-yi-" in self.model_name.lower() and parameters["temperature"] > 0.1:
if (
self.model_name
and "-yi-" in self.model_name.lower()
and parameters["temperature"] > 0.1
):
parameters["temperature"] = parameters["temperature"] / 2
log.debug("halfing temperature for -yi- model", temperature=parameters["temperature"])

log.debug(
"halfing temperature for -yi- model",
temperature=parameters["temperature"],
)

def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key="sk-1111")

self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")

async def get_model_name(self):
async with httpx.AsyncClient() as client:
response = await client.get(f"{self.api_url}/v1/internal/model/info", timeout=2)
response = await client.get(
f"{self.api_url}/v1/internal/model/info", timeout=2
)
if response.status_code == 404:
raise Exception("Could not find model info (wrong api version?)")
response_data = response.json()
model_name = response_data.get("model_name")

if model_name == "None":
model_name = None

return model_name

async def generate(self, prompt:str, parameters:dict, kind:str):

async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""

headers = {}
headers["Content-Type"] = "application/json"

parameters["prompt"] = prompt.strip(" ")

async with httpx.AsyncClient() as client:
response = await client.post(f"{self.api_url}/v1/completions", json=parameters, timeout=None, headers=headers)
response = await client.post(
f"{self.api_url}/v1/completions",
json=parameters,
timeout=None,
headers=headers,
)
response_data = response.json()
return response_data["choices"][0]["text"]

def jiggle_randomness(self, prompt_config:dict, offset:float=0.3) -> dict:

def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
"""
adjusts temperature and repetition_penalty
by random values using the base value as a center
"""

temp = prompt_config["temperature"]
rep_pen = prompt_config["repetition_penalty"]

min_offset = offset * 0.3

prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
prompt_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)
prompt_config["repetition_penalty"] = random.uniform(
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
)
@@ -1,32 +1,33 @@
import copy
import random

def jiggle_randomness(prompt_config:dict, offset:float=0.3) -> dict:

def jiggle_randomness(prompt_config: dict, offset: float = 0.3) -> dict:
"""
adjusts temperature and repetition_penalty
by random values using the base value as a center
"""

temp = prompt_config["temperature"]
rep_pen = prompt_config["repetition_penalty"]

copied_config = copy.deepcopy(prompt_config)

min_offset = offset * 0.3

copied_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
copied_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)

copied_config["repetition_penalty"] = random.uniform(
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
)

return copied_config

def jiggle_enabled_for(kind:str):

def jiggle_enabled_for(kind: str):
if kind in ["conversation", "story"]:
return True

if kind.startswith("narrate"):
return True

return False

return False
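The `jiggle_randomness` helper in the hunk above nudges sampling parameters upward by a bounded random amount without mutating the caller's config. A stdlib-only sketch reproducing that behavior, with the bounds spelled out (the demo values at the bottom are illustrative):

```python
import copy
import random

def jiggle_randomness(prompt_config: dict, offset: float = 0.3) -> dict:
    """Non-destructive sketch matching the helper in the diff: returns a
    deep copy with temperature raised by a random amount in
    [offset * 0.3, offset] and repetition_penalty raised by a random
    amount in [offset * 0.09, offset * 0.3]."""
    temp = prompt_config["temperature"]
    rep_pen = prompt_config["repetition_penalty"]
    copied_config = copy.deepcopy(prompt_config)
    min_offset = offset * 0.3
    copied_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
    copied_config["repetition_penalty"] = random.uniform(
        rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
    )
    return copied_config

base = {"temperature": 0.7, "repetition_penalty": 1.15}
jiggled = jiggle_randomness(base, offset=0.3)
assert base == {"temperature": 0.7, "repetition_penalty": 1.15}  # original untouched
assert 0.78 <= jiggled["temperature"] <= 1.01  # nominally 0.7 + [0.09, 0.3]
assert 1.17 <= jiggled["repetition_penalty"] <= 1.25  # nominally 1.15 + [0.027, 0.09]
```

The deep copy is what distinguishes this module-level version from the in-place method on the textgenwebui client shown earlier in the diff.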
@@ -12,17 +12,17 @@ from .cmd_memget import CmdMemget
from .cmd_memset import CmdMemset
from .cmd_narrate import *
from .cmd_rebuild_archive import CmdRebuildArchive
from .cmd_remove_character import CmdRemoveCharacter
from .cmd_rename import CmdRename
from .cmd_rerun import *
from .cmd_reset import CmdReset
from .cmd_rm import CmdRm
from .cmd_remove_character import CmdRemoveCharacter
from .cmd_run_helios_test import CmdHeliosTest
from .cmd_save import CmdSave
from .cmd_save_as import CmdSaveAs
from .cmd_save_characters import CmdSaveCharacters
from .cmd_setenv import CmdSetEnvironmentToScene, CmdSetEnvironmentToCreative
from .cmd_setenv import CmdSetEnvironmentToCreative, CmdSetEnvironmentToScene
from .cmd_time_util import *
from .cmd_tts import *
from .cmd_world_state import *
from .cmd_run_helios_test import CmdHeliosTest
from .manager import Manager
from .manager import Manager

@@ -41,7 +41,7 @@ class TalemateCommand(Emitter, ABC):
raise NotImplementedError(
"TalemateCommand.run() must be implemented by subclass"
)

@property
def verbose_name(self):
if self.label:
@@ -50,6 +50,6 @@ class TalemateCommand(Emitter, ABC):

def command_start(self):
emit("command_status", self.verbose_name, status="started")

def command_end(self):
emit("command_status", self.verbose_name, status="ended")
emit("command_status", self.verbose_name, status="ended")
@@ -1,9 +1,9 @@
import structlog

from talemate.character import activate_character, deactivate_character
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import wait_for_input, emit
from talemate.character import deactivate_character, activate_character
from talemate.emit import emit, wait_for_input
from talemate.instance import get_agent

log = structlog.get_logger("talemate.cmd.characters")
@@ -13,6 +13,7 @@ __all__ = [
"CmdActivateCharacter",
]

@register
class CmdDeactivateCharacter(TalemateCommand):
"""
@@ -22,61 +23,77 @@ class CmdDeactivateCharacter(TalemateCommand):
name = "character_deactivate"
description = "Will deactivate a character"
aliases = ["char_d"]

label = "Character exit"

async def run(self):

narrator = get_agent("narrator")
world_state = get_agent("world_state")
characters = list([character.name for character in self.scene.get_npc_characters()])

characters = list(
[character.name for character in self.scene.get_npc_characters()]
)

if not characters:
emit("status", message="No characters found", status="error")
return True

if self.args:
character_name = self.args[0]
else:
character_name = await wait_for_input("Which character do you want to deactivate?", data={
"input_type": "select",
"choices": characters,
})

character_name = await wait_for_input(
"Which character do you want to deactivate?",
data={
"input_type": "select",
"choices": characters,
},
)

if not character_name:
emit("status", message="No character selected", status="error")
return True

never_narrate = len(self.args) > 1 and self.args[1] == "no"

if not never_narrate:
is_present = await world_state.is_character_present(character_name)
is_leaving = await world_state.is_character_leaving(character_name)
log.debug("deactivate_character", character_name=character_name, is_present=is_present, is_leaving=is_leaving, never_narrate=never_narrate)
log.debug(
"deactivate_character",
character_name=character_name,
is_present=is_present,
is_leaving=is_leaving,
never_narrate=never_narrate,
)
else:
is_present = False
is_leaving = True
log.debug("deactivate_character", character_name=character_name, never_narrate=never_narrate)

log.debug(
"deactivate_character",
character_name=character_name,
never_narrate=never_narrate,
)

if is_present and not is_leaving and not never_narrate:
direction = await wait_for_input(f"How does {character_name} exit the scene? (leave blank for AI to decide)")
direction = await wait_for_input(
f"How does {character_name} exit the scene? (leave blank for AI to decide)"
)
message = await narrator.action_to_narration(
"narrate_character_exit",
self.scene.get_character(character_name),
direction = direction,
direction=direction,
)
self.narrator_message(message)

await deactivate_character(self.scene, character_name)

await deactivate_character(self.scene, character_name)

emit("status", message=f"Deactivated {character_name}", status="success")

self.scene.emit_status()
self.scene.world_state.emit()
self.scene.world_state.emit()

return True

@register
class CmdActivateCharacter(TalemateCommand):
"""
@@ -86,57 +103,70 @@ class CmdActivateCharacter(TalemateCommand):
name = "character_activate"
description = "Will activate a character"
aliases = ["char_a"]

label = "Character enter"

async def run(self):

world_state = get_agent("world_state")
narrator = get_agent("narrator")
characters = list(self.scene.inactive_characters.keys())

if not characters:
emit("status", message="No characters found", status="error")
return True

if self.args:
character_name = self.args[0]
if character_name not in characters:
emit("status", message="Character not found", status="error")
return True
else:
character_name = await wait_for_input("Which character do you want to activate?", data={
"input_type": "select",
"choices": characters,
})

character_name = await wait_for_input(
"Which character do you want to activate?",
data={
"input_type": "select",
"choices": characters,
},
)

if not character_name:
emit("status", message="No character selected", status="error")
return True

never_narrate = len(self.args) > 1 and self.args[1] == "no"

if not never_narrate:
is_present = await world_state.is_character_present(character_name)
log.debug("activate_character", character_name=character_name, is_present=is_present, never_narrate=never_narrate)
log.debug(
"activate_character",
character_name=character_name,
is_present=is_present,
never_narrate=never_narrate,
)
else:
is_present = True
log.debug("activate_character", character_name=character_name, never_narrate=never_narrate)

log.debug(
"activate_character",
character_name=character_name,
never_narrate=never_narrate,
)

await activate_character(self.scene, character_name)

if not is_present and not never_narrate:
direction = await wait_for_input(f"How does {character_name} enter the scene? (leave blank for AI to decide)")
direction = await wait_for_input(
f"How does {character_name} enter the scene? (leave blank for AI to decide)"
)
message = await narrator.action_to_narration(
"narrate_character_entry",
self.scene.get_character(character_name),
direction = direction,
direction=direction,
)
self.narrator_message(message)

emit("status", message=f"Activated {character_name}", status="success")

self.scene.emit_status()
self.scene.world_state.emit()

return True

return True
@@ -12,6 +12,7 @@ __all__ = [
"CmdRunAutomatic",
]

@register
class CmdDebugOn(TalemateCommand):
"""
@@ -26,6 +27,7 @@ class CmdDebugOn(TalemateCommand):
logging.getLogger().setLevel(logging.DEBUG)
await asyncio.sleep(0)

@register
class CmdDebugOff(TalemateCommand):
"""
@@ -46,66 +48,64 @@ class CmdPromptChangeSectioning(TalemateCommand):
"""
Command class for the '_prompt_change_sectioning' command
"""

name = "_prompt_change_sectioning"
description = "Change the sectioning handler for the prompt system"
aliases = []

async def run(self):

if not self.args:
self.emit("system", "You must specify a sectioning handler")
return

handler_name = self.args[0]

set_default_sectioning_handler(handler_name)

self.emit("system", f"Sectioning handler set to {handler_name}")
await asyncio.sleep(0)

@register
class CmdRunAutomatic(TalemateCommand):
"""
Command class for the 'run_automatic' command
"""

name = "run_automatic"
description = "Will make the player character AI controlled for n turns"
aliases = ["auto"]

async def run(self):

if self.args:
turns = int(self.args[0])
else:
turns = 10

self.emit("system", f"Making player character AI controlled for {turns} turns")
self.scene.get_player_character().actor.ai_controlled = turns

@register
class CmdLongTermMemoryStats(TalemateCommand):
"""
Command class for the 'long_term_memory_stats' command
"""

name = "long_term_memory_stats"
description = "Show stats for the long term memory"
aliases = ["ltm_stats"]

async def run(self):

memory = self.scene.get_helper("memory").agent

count = await memory.count()
db_name = memory.db_name

self.emit("system", f"Long term memory for {self.scene.name} has {count} entries in the {db_name} database")

self.emit(
"system",
f"Long term memory for {self.scene.name} has {count} entries in the {db_name} database",
)

@register
@@ -113,35 +113,34 @@ class CmdLongTermMemoryReset(TalemateCommand):
"""
Command class for the 'long_term_memory_reset' command
"""

name = "long_term_memory_reset"
description = "Reset the long term memory"
aliases = ["ltm_reset"]

async def run(self):

await self.scene.commit_to_memory()

self.emit("system", f"Long term memory for {self.scene.name} has been reset")

@register
class CmdSetContentContext(TalemateCommand):
"""
Command class for the 'set_content_context' command
"""

name = "set_content_context"
description = "Set the content context for the scene"
aliases = ["set_context"]

async def run(self):

if not self.args:
self.emit("system", "You must specify a context")
return

context = self.args[0]

self.scene.context = context

self.emit("system", f"Content context set to {context}")

self.emit("system", f"Content context set to {context}")
@@ -1,9 +1,10 @@
import asyncio
import random

from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.scene_message import DirectorMessage
from talemate.emit import wait_for_input
from talemate.scene_message import DirectorMessage

__all__ = [
    "CmdAIDialogue",
@@ -11,6 +12,7 @@ __all__ = [
    "CmdAIDialogueDirected",
]


@register
class CmdAIDialogue(TalemateCommand):
    """
@@ -23,15 +25,14 @@ class CmdAIDialogue(TalemateCommand):

    async def run(self):
        conversation_agent = self.scene.get_helper("conversation").agent

        actor = None

        # if there is only one npc in the scene, use that

        if len(self.scene.npc_character_names) == 1:
            actor = list(self.scene.get_npc_characters())[0].actor
        else:
            if conversation_agent.actions["natural_flow"].enabled:
                await conversation_agent.apply_natural_flow(force=True, npcs_only=True)
                character_name = self.scene.next_actor
@@ -41,83 +42,83 @@ class CmdAIDialogue(TalemateCommand):
            else:
                # randomly select an actor
                actor = random.choice(list(self.scene.get_npc_characters())).actor

        if not actor:
            return

        messages = await actor.talk()

        self.scene.process_npc_dialogue(actor, messages)


@register
class CmdAIDialogueSelective(TalemateCommand):
    """
    Command class for the 'ai_dialogue_selective' command

    Will allow the player to select which npc dialogue will be generated
    for
    """

    name = "ai_dialogue_selective"
    description = "Generate dialogue for an AI selected actor"
    aliases = ["dlg_selective"]

    async def run(self):
        npc_name = self.args[0]

        character = self.scene.get_character(npc_name)

        if not character:
            self.emit("system_message", message=f"Character not found: {npc_name}")
            return

        actor = character.actor

        messages = await actor.talk()

        self.scene.process_npc_dialogue(actor, messages)


@register
class CmdAIDialogueDirected(TalemateCommand):
    """
    Command class for the 'ai_dialogue_directed' command

    Will allow the player to select which npc dialogue will be generated
    for
    """

    name = "ai_dialogue_directed"
    description = "Generate dialogue for an AI selected actor"
    aliases = ["dlg_directed"]

    async def run(self):
        npc_name = self.args[0]

        character = self.scene.get_character(npc_name)

        if not character:
            self.emit("system_message", message=f"Character not found: {npc_name}")
            return

        prefix = f"Director instructs {character.name}: \"To progress the scene, i want you to"
        direction = await wait_for_input(prefix+"... (enter your instructions)")
        direction = f"{prefix} {direction}\""

        prefix = f'Director instructs {character.name}: "To progress the scene, i want you to'
        direction = await wait_for_input(prefix + "... (enter your instructions)")
        direction = f'{prefix} {direction}"'

        director_message = DirectorMessage(direction, source=character.name)

        self.emit("director", director_message, character=character)

        self.scene.push_history(director_message)

        actor = character.actor

        messages = await actor.talk()

        self.scene.process_npc_dialogue(actor, messages)
        self.scene.process_npc_dialogue(actor, messages)
@@ -1,8 +1,8 @@
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import wait_for_input, emit
from talemate.util import colored_text, wrap_text
from talemate.emit import emit, wait_for_input
from talemate.scene_message import DirectorMessage
from talemate.util import colored_text, wrap_text


@register
@@ -21,9 +21,9 @@ class CmdDirectorDirect(TalemateCommand):
        if not director:
            self.system_message("No director found")
            return True

        npc_count = self.scene.num_npc_characters()

        if npc_count == 1:
            character = list(self.scene.get_npc_characters())[0]
        elif npc_count > 1:
@@ -36,17 +36,20 @@ class CmdDirectorDirect(TalemateCommand):
        if not character:
            self.system_message(f"Character not found: {name}")
            return True

        goal = await wait_for_input(f"Enter a new goal for the director to direct {character.name}")
        goal = await wait_for_input(
            f"Enter a new goal for the director to direct {character.name}"
        )

        if not goal.strip():
            self.system_message("No goal specified")
            return True

        director.agent.actions["direct"].config["prompt"].value = goal

        await director.agent.direct_character(character, goal)


@register
class CmdDirectorDirectWithOverride(CmdDirectorDirect):
    """
@@ -54,7 +57,9 @@ class CmdDirectorDirectWithOverride(CmdDirectorDirect):
    """

    name = "director_with_goal"
    description = "Calls a director to give directionts to a character (with goal specified)"
    description = (
        "Calls a director to give directionts to a character (with goal specified)"
    )
    aliases = ["direct_g"]

    async def run(self):
@@ -1,6 +1,7 @@
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register


@register
class CmdMemget(TalemateCommand):
    """
@@ -16,4 +17,4 @@ class CmdMemget(TalemateCommand):
        memories = self.scene.get_helper("memory").agent.get(query)

        for memory in memories:
            self.emit("narrator", memory["text"])
            self.emit("narrator", memory["text"])
@@ -2,9 +2,9 @@ import asyncio

from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.util import colored_text, wrap_text
from talemate.scene_message import NarratorMessage
from talemate.emit import wait_for_input
from talemate.scene_message import NarratorMessage
from talemate.util import colored_text, wrap_text

__all__ = [
    "CmdNarrate",
@@ -14,6 +14,7 @@ __all__ = [
    "CmdNarrateC",
]


@register
class CmdNarrate(TalemateCommand):
    """
@@ -33,7 +34,7 @@ class CmdNarrate(TalemateCommand):

        narration = await narrator.agent.narrate_scene()
        message = NarratorMessage(narration, source="narrate_scene")

        self.narrator_message(message)
        self.scene.push_history(message)

@@ -58,17 +59,22 @@ class CmdNarrateQ(TalemateCommand):

        if self.args:
            query = self.args[0]
            at_the_end = (self.args[1].lower() == "true") if len(self.args) > 1 else False
            at_the_end = (
                (self.args[1].lower() == "true") if len(self.args) > 1 else False
            )
        else:
            query = await wait_for_input("Enter query: ")
            at_the_end = False

        narration = await narrator.agent.narrate_query(query, at_the_end=at_the_end)
        message = NarratorMessage(narration, source=f"narrate_query:{query.replace(':', '-')}")
        message = NarratorMessage(
            narration, source=f"narrate_query:{query.replace(':', '-')}"
        )

        self.narrator_message(message)
        self.scene.push_history(message)


@register
class CmdNarrateProgress(TalemateCommand):
    """
@@ -89,10 +95,11 @@ class CmdNarrateProgress(TalemateCommand):
        narration = await narrator.agent.progress_story()

        message = NarratorMessage(narration, source="progress_story")

        self.narrator_message(message)
        self.scene.push_history(message)


@register
class CmdNarrateProgressDirected(TalemateCommand):
    """
@@ -105,16 +112,17 @@ class CmdNarrateProgressDirected(TalemateCommand):

    async def run(self):
        narrator = self.scene.get_helper("narrator")

        direction = await wait_for_input("Enter direction for the narrator: ")

        narration = await narrator.agent.progress_story(narrative_direction=direction)

        message = NarratorMessage(narration, source=f"progress_story:{direction}")

        self.narrator_message(message)
        self.scene.push_history(message)


@register
class CmdNarrateC(TalemateCommand):
    """
@@ -149,7 +157,8 @@ class CmdNarrateC(TalemateCommand):

        self.narrator_message(message)
        self.scene.push_history(message)


@register
class CmdNarrateDialogue(TalemateCommand):
    """
@@ -165,23 +174,25 @@ class CmdNarrateDialogue(TalemateCommand):
        narrator = self.scene.get_helper("narrator")

        character_messages = self.scene.collect_messages("character", max_iterations=5)

        if not character_messages:
            self.system_message("No recent dialogue message found")
            return True

        character_message = character_messages[0]

        character_name = character_message.character_name

        character = self.scene.get_character(character_name)

        if not character:
            self.system_message(f"Character not found: {character_name}")
            return True

        narration = await narrator.agent.narrate_after_dialogue(character)
        message = NarratorMessage(narration, source=f"narrate_dialogue:{character.name}")
        message = NarratorMessage(
            narration, source=f"narrate_dialogue:{character.name}"
        )

        self.narrator_message(message)
        self.scene.push_history(message)
        self.scene.push_history(message)
@@ -20,7 +20,7 @@ class CmdRebuildArchive(TalemateCommand):
        if not summarizer:
            self.system_message("No summarizer found")
            return True

        # clear out archived history, but keep pre-established history
        self.scene.archived_history = [
            ah for ah in self.scene.archived_history if ah.get("end") is None
@@ -14,38 +14,37 @@ class CmdRemoveCharacter(TalemateCommand):
    aliases = ["rmc"]

    async def run(self):
        characters = list([character.name for character in self.scene.get_characters()])

        if not characters:
            self.system_message("No characters found")
            return True

        if self.args:
            character_name = self.args[0]
        else:
            character_name = await wait_for_input("Which character do you want to remove?", data={
                "input_type": "select",
                "choices": characters,
            })
            character_name = await wait_for_input(
                "Which character do you want to remove?",
                data={
                    "input_type": "select",
                    "choices": characters,
                },
            )

        if not character_name:
            self.system_message("No character selected")
            return True

        character = self.scene.get_character(character_name)

        if not character:
            self.system_message(f"Character {character_name} not found")
            return True

        await self.scene.remove_actor(character.actor)

        self.system_message(f"Removed {character.name} from scene")

        self.scene.emit_status()

        return True
@@ -2,7 +2,6 @@ import asyncio

from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register

from talemate.emit import wait_for_input

@@ -23,20 +22,23 @@ class CmdRename(TalemateCommand):
            character_name = self.args[0]
        else:
            character_names = self.scene.character_names
            character_name = await wait_for_input("Which character do you want to rename?", data={
                "input_type": "select",
                "choices": character_names,
            })
            character_name = await wait_for_input(
                "Which character do you want to rename?",
                data={
                    "input_type": "select",
                    "choices": character_names,
                },
            )

        character = self.scene.get_character(character_name)

        if not character:
            self.system_message(f"Character {character_name} not found")
            return True

        name = await wait_for_input("Enter new name: ")

        character.rename(name)
        await asyncio.sleep(0)

        return True
@@ -1,8 +1,7 @@
from talemate.client.context import ClientContext
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.client.context import ClientContext
from talemate.context import RerunContext

from talemate.emit import wait_for_input

__all__ = [
@@ -10,6 +9,7 @@ __all__ = [
    "CmdRerunWithDirection",
]


@register
class CmdRerun(TalemateCommand):
    """
@@ -24,8 +24,8 @@ class CmdRerun(TalemateCommand):
        nuke_repetition = self.args[0] if self.args else 0.0
        with ClientContext(nuke_repetition=nuke_repetition):
            await self.scene.rerun()


@register
class CmdRerunWithDirection(TalemateCommand):
    """
@@ -35,25 +35,25 @@ class CmdRerunWithDirection(TalemateCommand):
    name = "rerun_directed"
    description = "Rerun the scene with a direction"
    aliases = ["rrd"]

    label = "Directed Rerun"

    async def run(self):
        nuke_repetition = self.args[0] if self.args else 0.0
        method = self.args[1] if len(self.args) > 1 else "replace"

        if method not in ["replace", "edit"]:
            raise ValueError(f"Unknown method: {method}. Valid methods are 'replace' and 'edit'.")
            raise ValueError(
                f"Unknown method: {method}. Valid methods are 'replace' and 'edit'."
            )

        if method == "replace":
            hint = ""
        else:
            hint = " (subtle change to previous generation)"

        direction = await wait_for_input(f"Instructions for regeneration{hint}: ")

        with RerunContext(self.scene, direction=direction, method=method):
            with ClientContext(direction=direction, nuke_repetition=nuke_repetition):
                await self.scene.rerun()
                await self.scene.rerun()
@@ -1,7 +1,6 @@
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register

from talemate.emit import wait_for_input, wait_for_input_yesno, emit
from talemate.emit import emit, wait_for_input, wait_for_input_yesno
from talemate.exceptions import ResetScene

@@ -16,13 +15,12 @@ class CmdReset(TalemateCommand):
    aliases = [""]

    async def run(self):
        reset = await wait_for_input_yesno("Reset the scene?")

        if reset.lower() not in ["yes", "y"]:
            self.system_message("Reset cancelled")
            return True

        self.scene.reset()

        raise ResetScene()
@@ -1,7 +1,6 @@
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register

from talemate.emit import wait_for_input, wait_for_input_yesno, emit
from talemate.emit import emit, wait_for_input, wait_for_input_yesno
from talemate.exceptions import ResetScene

@@ -14,26 +13,25 @@ class CmdHeliosTest(TalemateCommand):
    name = "helios_test"
    description = "Runs the helios test"
    aliases = [""]

    analyst_script = [
        "Good morning helios, how are you today? Are you ready to run some tests?",
    ]

    async def run(self):
        if self.scene.name != "Helios Test Arena":
            emit("system", "You are not in the Helios Test Arena")

        self.scene.reset()

        self.scene

        player = self.scene.get_player_character()
        player.actor.muted = 10

        analyst = self.scene.get_character("The analyst")
        actor = analyst.actor

        actor.script = self.analyst_script

        raise ResetScene()
@@ -2,9 +2,8 @@ import asyncio

from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.exceptions import RestartSceneLoop

from talemate.emit import emit
from talemate.exceptions import RestartSceneLoop


@register
@@ -19,21 +18,20 @@ class CmdSetEnvironmentToScene(TalemateCommand):

    async def run(self):
        await asyncio.sleep(0)

        player_character = self.scene.get_player_character()

        if not player_character:
            self.system_message("No player character found")
            return True

        self.scene.set_environment("scene")

        emit("status", message="Switched to gameplay", status="info")

        raise RestartSceneLoop()


@register
class CmdSetEnvironmentToCreative(TalemateCommand):
    """
@@ -45,8 +43,7 @@ class CmdSetEnvironmentToCreative(TalemateCommand):
    aliases = [""]

    async def run(self):
        await asyncio.sleep(0)
        self.scene.set_environment("creative")

        raise RestartSceneLoop()
@@ -5,16 +5,18 @@ Commands to manage scene timescale
import asyncio
import logging

import isodate

import talemate.instance as instance
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import wait_for_input
import talemate.instance as instance
import isodate

__all__ = [
    "CmdAdvanceTime",
]


@register
class CmdAdvanceTime(TalemateCommand):
    """
@@ -29,20 +31,23 @@ class CmdAdvanceTime(TalemateCommand):
        if not self.args:
            self.emit("system", "You must specify an amount of time to advance")
            return

        narrator = instance.get_agent("narrator")
        narration_prompt = None

        # if narrator has narrate_time_passage action enabled ask the user
        # for a prompt to guide the narration

        if narrator.actions["narrate_time_passage"].enabled and narrator.actions["narrate_time_passage"].config["ask_for_prompt"].value:
            narration_prompt = await wait_for_input("Enter a prompt to guide the time passage narration (or leave blank): ")

        if (
            narrator.actions["narrate_time_passage"].enabled
            and narrator.actions["narrate_time_passage"].config["ask_for_prompt"].value
        ):
            narration_prompt = await wait_for_input(
                "Enter a prompt to guide the time passage narration (or leave blank): "
            )

            if not narration_prompt.strip():
                narration_prompt = None

        world_state = instance.get_agent("world_state")
        await world_state.advance_time(self.args[0], narration_prompt)
        await world_state.advance_time(self.args[0], narration_prompt)
@@ -3,13 +3,14 @@ import logging

from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.prompts.base import set_default_sectioning_handler
from talemate.instance import get_agent
from talemate.prompts.base import set_default_sectioning_handler

__all__ = [
    "CmdTestTTS",
]


@register
class CmdTestTTS(TalemateCommand):
    """
@@ -22,12 +23,10 @@ class CmdTestTTS(TalemateCommand):

    async def run(self):
        tts_agent = get_agent("tts")

        try:
            last_message = str(self.scene.history[-1])
        except IndexError:
            last_message = "Welcome to talemate!"

        await tts_agent.generate(last_message)
@@ -1,13 +1,14 @@
import random

import structlog

from talemate.commands.base import TalemateCommand
from talemate.scene_message import NarratorMessage
from talemate.commands.manager import register
from talemate.emit import wait_for_input, emit
from talemate.instance import get_agent
import talemate.instance as instance
from talemate.status import set_loading, LoadingStatus
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import emit, wait_for_input
from talemate.instance import get_agent
from talemate.scene_message import NarratorMessage
from talemate.status import LoadingStatus, set_loading

log = structlog.get_logger("talemate.cmd.world_state")
@@ -22,6 +23,7 @@ __all__ = [
    "CmdSummarizeAndPin",
]


@register
class CmdWorldState(TalemateCommand):
    """
@@ -33,282 +35,328 @@ class CmdWorldState(TalemateCommand):
    aliases = ["ws"]

    async def run(self):
        inline = self.args[0] == "inline" if self.args else False
        reset = self.args[0] == "reset" if self.args else False

        if inline:
            await self.scene.world_state.request_update_inline()
            return True

        if reset:
            self.scene.world_state.reset()

        await self.scene.world_state.request_update()


@register
class CmdPersistCharacter(TalemateCommand):
    """
    Will attempt to create an actual character from a currently non
    tracked character in the scene, by name.

    Once persisted this character can then participate in the scene.
    """

    name = "persist_character"
    description = "Persist a character by name"
    aliases = ["pc"]

    @set_loading("Generating character...", set_busy=False)
    async def run(self):
        from talemate.tale_mate import Character, Actor
        from talemate.tale_mate import Actor, Character

        scene = self.scene
        world_state = instance.get_agent("world_state")
        creator = instance.get_agent("creator")
        narrator = instance.get_agent("narrator")

        loading_status = LoadingStatus(3)

        if not len(self.args):
            characters = await world_state.identify_characters()
            available_names = [character["name"] for character in characters.get("characters") if not scene.get_character(character["name"])]
            available_names = [
                character["name"]
                for character in characters.get("characters")
                if not scene.get_character(character["name"])
            ]

            if not len(available_names):
                raise ValueError("No characters available to persist.")

            name = await wait_for_input("Which character would you like to persist?", data={
                "input_type": "select",
                "choices": available_names,
                "multi_select": False,
            })
            name = await wait_for_input(
                "Which character would you like to persist?",
                data={
                    "input_type": "select",
                    "choices": available_names,
                    "multi_select": False,
                },
            )
        else:
            name = self.args[0]

        extra_instructions = None
        if name == "prompt":
            name = await wait_for_input("What is the name of the character?")
            description = await wait_for_input(f"Brief description for {name} (or leave blank):")
            description = await wait_for_input(
                f"Brief description for {name} (or leave blank):"
            )
            if description.strip():
                extra_instructions = f"Name: {name}\nBrief Description: {description}"

        never_narrate = len(self.args) > 1 and self.args[1] == "no"

        if not never_narrate:
            is_present = await world_state.is_character_present(name)
            log.debug("persist_character", name=name, is_present=is_present, never_narrate=never_narrate)
            log.debug(
                "persist_character",
                name=name,
                is_present=is_present,
                never_narrate=never_narrate,
            )
        else:
            is_present = False
            log.debug("persist_character", name=name, never_narrate=never_narrate)

        character = Character(name=name)
        character.color = random.choice(['#F08080', '#FFD700', '#90EE90', '#ADD8E6', '#DDA0DD', '#FFB6C1', '#FAFAD2', '#D3D3D3', '#B0E0E6', '#FFDEAD'])
        character.color = random.choice(
            [
                "#F08080",
                "#FFD700",
                "#90EE90",
                "#ADD8E6",
                "#DDA0DD",
                "#FFB6C1",
                "#FAFAD2",
                "#D3D3D3",
                "#B0E0E6",
                "#FFDEAD",
            ]
        )

        loading_status("Generating character attributes...")

        attributes = await world_state.extract_character_sheet(name=name, text=extra_instructions)
        attributes = await world_state.extract_character_sheet(
            name=name, text=extra_instructions
        )
        scene.log.debug("persist_character", attributes=attributes)

        character.base_attributes = attributes

        loading_status("Generating character description...")

        description = await creator.determine_character_description(character)

        character.description = description

        scene.log.debug("persist_character", description=description)

        actor = Actor(character=character, agent=instance.get_agent("conversation"))

        await scene.add_actor(actor)

        emit("status", message=f"Added character {name} to the scene.", status="success")
        emit(
            "status", message=f"Added character {name} to the scene.", status="success"
        )

        # write narrative for the character entering the scene
        if not is_present and not never_narrate:
            loading_status("Narrating character entrance...")
            entry_narration = await narrator.narrate_character_entry(character, direction=extra_instructions)
            message = NarratorMessage(entry_narration, source=f"narrate_character_entry:{character.name}")
            entry_narration = await narrator.narrate_character_entry(
                character, direction=extra_instructions
            )
            message = NarratorMessage(
                entry_narration, source=f"narrate_character_entry:{character.name}"
            )
            self.narrator_message(message)
            self.scene.push_history(message)

        scene.emit_status()
        scene.world_state.emit()
@register
class CmdAddReinforcement(TalemateCommand):
    """
    Will attempt to create an actual character from a currently non
    tracked character in the scene, by name.

    Once persisted this character can then participate in the scene.
    """

    name = "add_reinforcement"
    description = "Add a reinforcement to the world state"
    aliases = ["ws_ar"]

    async def run(self):
        scene = self.scene

        world_state = scene.world_state

        if not len(self.args):
            question = await wait_for_input("Ask reinforcement question")
        else:
            question = self.args[0]

        await world_state.add_reinforcement(question)


@register
class CmdRemoveReinforcement(TalemateCommand):
    """
    Will attempt to create an actual character from a currently non
    tracked character in the scene, by name.

    Once persisted this character can then participate in the scene.
    """

    name = "remove_reinforcement"
    description = "Remove a reinforcement from the world state"
    aliases = ["ws_rr"]

    async def run(self):
        scene = self.scene

        world_state = scene.world_state

        if not len(self.args):
            question = await wait_for_input("Ask reinforcement question")
        else:
            question = self.args[0]

        idx, reinforcement = await world_state.find_reinforcement(question)

        if idx is None:
            raise ValueError(f"Reinforcement {question} not found.")

        await world_state.remove_reinforcement(idx)


@register
class CmdUpdateReinforcements(TalemateCommand):
    """
    Will attempt to create an actual character from a currently non
    tracked character in the scene, by name.

    Once persisted this character can then participate in the scene.
    """

    name = "update_reinforcements"
    description = "Update the reinforcements in the world state"
    aliases = ["ws_ur"]

    async def run(self):
        scene = self.scene

        world_state = get_agent("world_state")

        await world_state.update_reinforcements(force=True)


@register
class CmdCheckPinConditions(TalemateCommand):
    """
    Will attempt to create an actual character from a currently non
    tracked character in the scene, by name.

    Once persisted this character can then participate in the scene.
    """

    name = "check_pin_conditions"
    description = "Check the pin conditions in the world state"
    aliases = ["ws_cpc"]

    async def run(self):
        world_state = get_agent("world_state")
        await world_state.check_pin_conditions()
@register
class CmdApplyWorldStateTemplate(TalemateCommand):
    """
    Will apply a world state template setting up
    automatic state tracking.
    """

    name = "apply_world_state_template"
    description = "Apply a world state template, creating an auto state reinforcement."
    aliases = ["ws_awst"]
    label = "Add state"

    async def run(self):
        scene = self.scene

        if not len(self.args):
            raise ValueError("No template name provided.")

        template_name = self.args[0]
        template_type = self.args[1] if len(self.args) > 1 else None

        character_name = self.args[2] if len(self.args) > 2 else None

        templates = await self.scene.world_state_manager.get_templates()

        try:
            template = getattr(templates,template_type)[template_name]
            template = getattr(templates, template_type)[template_name]
        except KeyError:
            raise ValueError(f"Template {template_name} not found.")

        reinforcement = await scene.world_state_manager.apply_template_state_reinforcement(
            template, character_name=character_name, run_immediately=True
        reinforcement = (
            await scene.world_state_manager.apply_template_state_reinforcement(
                template, character_name=character_name, run_immediately=True
            )
        )

        response_data = {
            "template_name": template_name,
            "template_type": template_type,
            "reinforcement": reinforcement.model_dump() if reinforcement else None,
            "character_name": character_name,
        }

        if reinforcement is None:
            emit("status", message="State already tracked.", status="info", data=response_data)
            emit(
                "status",
                message="State already tracked.",
                status="info",
                data=response_data,
            )
        else:
            emit("status", message="Auto state added.", status="success", data=response_data)
            emit(
                "status",
                message="Auto state added.",
                status="success",
                data=response_data,
            )


@register
class CmdSummarizeAndPin(TalemateCommand):
    """
    Will take a message index and then walk back N messages
    summarizing the scene and pinning it to the context.
    """

    name = "summarize_and_pin"
    label = "Summarize and pin"
    description = "Summarize a snapshot of the scene and pin it to the world state"
    aliases = ["ws_sap"]

    async def run(self):
        scene = self.scene

        world_state = get_agent("world_state")

        if not self.scene.history:
            raise ValueError("No history to summarize.")

        message_id = int(self.args[0]) if len(self.args) else scene.history[-1].id
        num_messages = int(self.args[1]) if len(self.args) > 1 else 5

        await world_state.summarize_and_pin(message_id, num_messages=num_messages)
        await world_state.summarize_and_pin(message_id, num_messages=num_messages)
@@ -1,8 +1,10 @@
from talemate.emit import Emitter, AbortCommand
import structlog

from talemate.emit import AbortCommand, Emitter

log = structlog.get_logger("talemate.commands.manager")


class Manager(Emitter):
    """
    TaleMateCommand class to handle user command
@@ -38,7 +40,7 @@ class Manager(Emitter):
        cmd_args = ""
        if not self.is_command(cmd):
            return False

        if ":" in cmd:
            # split command name and args which are separated by a colon
            cmd_name, cmd_args = cmd[1:].split(":", 1)
@@ -46,7 +48,7 @@ class Manager(Emitter):
        else:
            cmd_name = cmd[1:]
            cmd_args = []

        for command_cls in self.command_classes:
            if command_cls.is_command(cmd_name):
                command = command_cls(self, *cmd_args)

@@ -1,11 +1,11 @@
import yaml
import datetime
import os
from typing import TYPE_CHECKING, ClassVar, Dict, Optional, Union

import pydantic
import structlog
import os
import datetime

import yaml
from pydantic import BaseModel, Field
from typing import Optional, Dict, Union, ClassVar, TYPE_CHECKING

from talemate.emit import emit
from talemate.scene_assets import Asset
@@ -15,40 +15,44 @@ if TYPE_CHECKING:

log = structlog.get_logger("talemate.config")


class Client(BaseModel):
    type: str
    name: str
    model: Union[str,None] = None
    api_url: Union[str,None] = None
    api_key: Union[str,None] = None
    max_token_length: Union[int,None] = None
    model: Union[str, None] = None
    api_url: Union[str, None] = None
    api_key: Union[str, None] = None
    max_token_length: int = 4096

    class Config:
        extra = "ignore"


class AgentActionConfig(BaseModel):
    value: Union[int, float, str, bool, None] = None


class AgentAction(BaseModel):
    enabled: bool = True
    config: Union[dict[str, AgentActionConfig], None] = None


class Agent(BaseModel):
    name: Union[str,None] = None
    client: Union[str,None] = None
    name: Union[str, None] = None
    client: Union[str, None] = None
    actions: Union[dict[str, AgentAction], None] = None
    enabled: bool = True

    class Config:
        extra = "ignore"

    # change serialization so actions and enabled are only
    # serialized if they are not None

    def model_dump(self, **kwargs):
        return super().model_dump(exclude_none=True)


class GamePlayerCharacter(BaseModel):
    name: str = ""
    color: str = "#3362bb"
@@ -58,10 +62,12 @@ class GamePlayerCharacter(BaseModel):
    class Config:
        extra = "ignore"


class General(BaseModel):
    auto_save: bool = True
    auto_progress: bool = True


class StateReinforcementTemplate(BaseModel):
    name: str
    query: str
@@ -72,52 +78,68 @@ class StateReinforcementTemplate(BaseModel):
    interval: int = 10
    auto_create: bool = False
    favorite: bool = False

    type:ClassVar = "state_reinforcement"
    type: ClassVar = "state_reinforcement"


class WorldStateTemplates(BaseModel):
    state_reinforcement: dict[str, StateReinforcementTemplate] = pydantic.Field(default_factory=dict)
    state_reinforcement: dict[str, StateReinforcementTemplate] = pydantic.Field(
        default_factory=dict
    )


class WorldState(BaseModel):
    templates: WorldStateTemplates = WorldStateTemplates()


class Game(BaseModel):
    default_player_character: GamePlayerCharacter = GamePlayerCharacter()
    general: General = General()
    world_state: WorldState = WorldState()

    class Config:
        extra = "ignore"


class CreatorConfig(BaseModel):
    content_context: list[str] = ["a fun and engaging slice of life story aimed at an adult audience."]
    content_context: list[str] = [
        "a fun and engaging slice of life story aimed at an adult audience."
    ]


class OpenAIConfig(BaseModel):
    api_key: Union[str,None]=None
    api_key: Union[str, None] = None


class RunPodConfig(BaseModel):
    api_key: Union[str,None]=None
    api_key: Union[str, None] = None


class ElevenLabsConfig(BaseModel):
    api_key: Union[str,None]=None
    api_key: Union[str, None] = None
    model: str = "eleven_turbo_v2"


class CoquiConfig(BaseModel):
    api_key: Union[str,None]=None
    api_key: Union[str, None] = None


class TTSVoiceSamples(BaseModel):
    label:str
    value:str
    label: str
    value: str


class TTSConfig(BaseModel):
    device:str = "cuda"
    model:str = "tts_models/multilingual/multi-dataset/xtts_v2"
    device: str = "cuda"
    model: str = "tts_models/multilingual/multi-dataset/xtts_v2"
    voices: list[TTSVoiceSamples] = pydantic.Field(default_factory=list)


class ChromaDB(BaseModel):
    instructor_device: str="cpu"
    instructor_model: str="default"
    embeddings: str="default"
    instructor_device: str = "cpu"
    instructor_model: str = "default"
    embeddings: str = "default"


class RecentScene(BaseModel):
    name: str
@@ -125,43 +147,48 @@ class RecentScene(BaseModel):
    filename: str
    date: str
    cover_image: Union[Asset, None] = None


class RecentScenes(BaseModel):
    scenes: list[RecentScene] = pydantic.Field(default_factory=list)
    max_entries: int = 10

    def push(self, scene:"Scene"):
    def push(self, scene: "Scene"):
        """
        adds a scene to the recent scenes list
        """

        # if scene has not been saved, don't add it
        if not scene.full_path:
            return

        now = datetime.datetime.now()

        # remove any existing entries for this scene
        self.scenes = [s for s in self.scenes if s.path != scene.full_path]

        # add the new entry
        self.scenes.insert(0,
        self.scenes.insert(
            0,
            RecentScene(
                name=scene.name,
                path=scene.full_path,
                filename=scene.filename,
                date=now.isoformat(),
                cover_image=scene.assets.assets[scene.assets.cover_image] if scene.assets.cover_image else None
            ))
                date=now.isoformat(),
                cover_image=scene.assets.assets[scene.assets.cover_image]
                if scene.assets.cover_image
                else None,
            ),
        )

        # trim the list to max_entries
        self.scenes = self.scenes[:self.max_entries]
        self.scenes = self.scenes[: self.max_entries]

    def clean(self):
        """
        removes any entries that no longer exist
        """

        self.scenes = [s for s in self.scenes if os.path.exists(s.path)]
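The `push` method above does three things in order: drop any existing entry for the same path, insert the new entry at the front, and trim the list to `max_entries`. A minimal standalone sketch of that bookkeeping, with plain dicts standing in for the pydantic `RecentScene` model:

```python
# Sketch of RecentScenes.push: dedupe by path, insert at the front, trim.
# Plain dicts stand in for the RecentScene model in this example.
def push(scenes, entry, max_entries=10):
    # remove any existing entries for this path
    scenes = [s for s in scenes if s["path"] != entry["path"]]
    # add the new entry at the front
    scenes.insert(0, entry)
    # trim the list to max_entries
    return scenes[:max_entries]


scenes = [{"path": "a"}, {"path": "b"}]
scenes = push(scenes, {"path": "b"})  # "b" moves to the front, no duplicate
```

Because the dedupe step runs first, re-opening a scene simply promotes it to the head of the list rather than creating a second entry.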
@@ -170,46 +197,50 @@ class Config(BaseModel):
    game: Game

    agents: Dict[str, Agent] = {}

    creator: CreatorConfig = CreatorConfig()

    openai: OpenAIConfig = OpenAIConfig()

    runpod: RunPodConfig = RunPodConfig()

    chromadb: ChromaDB = ChromaDB()

    elevenlabs: ElevenLabsConfig = ElevenLabsConfig()

    coqui: CoquiConfig = CoquiConfig()

    tts: TTSConfig = TTSConfig()

    recent_scenes: RecentScenes = RecentScenes()

    class Config:
        extra = "ignore"

    def save(self, file_path: str = "./config.yaml"):
        save_config(self, file_path)


class SceneConfig(BaseModel):
    automated_actions: dict[str, bool]


class SceneAssetUpload(BaseModel):
    scene_cover_image:bool
    character_cover_image:str = None
    content:str = None


def load_config(file_path: str = "./config.yaml", as_model:bool=False) -> Union[dict, Config]:


class SceneAssetUpload(BaseModel):
    scene_cover_image: bool
    character_cover_image: str = None
    content: str = None


def load_config(
    file_path: str = "./config.yaml", as_model: bool = False
) -> Union[dict, Config]:
    """
    Load the config file from the given path.

    Should cache the config and only reload if the file modification time
    has changed since the last load
    """

    with open(file_path, "r") as file:
        config_data = yaml.safe_load(file)

@@ -225,13 +256,14 @@ def load_config(file_path: str = "./config.yaml", as_model:bool=False) -> Union[

    return config.model_dump()


def save_config(config, file_path: str = "./config.yaml"):
    """
    Save the config file to the given path.
    """

    log.debug("Saving config", file_path=file_path)

    # If config is a Config instance, convert it to a dictionary
    if isinstance(config, Config):
        config = config.model_dump(exclude_none=True)

@@ -245,5 +277,5 @@ def save_config(config, file_path: str = "./config.yaml"):

    with open(file_path, "w") as file:
        yaml.dump(config, file)

    emit("config_saved", data=config)
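The `load_config` / `save_config` pair above round-trips a plain dict through a file on disk. A dependency-free sketch of that round trip (the real code serializes with `yaml.safe_load` / `yaml.dump` and emits a `config_saved` signal; `json` stands in here so the example needs no third-party imports):

```python
import json
import os
import tempfile

# Sketch of the save_config / load_config round trip, using json in place
# of yaml so the example is dependency-free.
def save_config(config, file_path):
    with open(file_path, "w") as file:
        json.dump(config, file)


def load_config(file_path):
    with open(file_path, "r") as file:
        return json.load(file)


path = os.path.join(tempfile.mkdtemp(), "config.json")
save_config({"game": {"auto_save": True}}, path)
config = load_config(path)
```

The shape is the same as the diffed code: write the full config dict on every save, read and parse the whole file on every load.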
@@ -1,4 +1,5 @@
from contextvars import ContextVar

import structlog

__all__ = [
@@ -13,29 +14,34 @@ log = structlog.get_logger(__name__)
scene_is_loading = ContextVar("scene_is_loading", default=None)
rerun_context = ContextVar("rerun_context", default=None)


class SceneIsLoading:
    def __init__(self, scene):
        self.scene = scene

    def __enter__(self):
        self.token = scene_is_loading.set(self.scene)

    def __exit__(self, *args):
        scene_is_loading.reset(self.token)


class RerunContext:
    def __init__(self, scene, direction=None, method="replace", message:str = None):
    def __init__(self, scene, direction=None, method="replace", message: str = None):
        self.scene = scene
        self.direction = direction
        self.method = method
        self.message = message
        log.debug("RerunContext", scene=scene, direction=direction, method=method, message=message)
        log.debug(
            "RerunContext",
            scene=scene,
            direction=direction,
            method=method,
            message=message,
        )

    def __enter__(self):
        self.token = rerun_context.set(self)

    def __exit__(self, *args):
        rerun_context.reset(self.token)
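Both `SceneIsLoading` and `RerunContext` follow the same pattern: a module-level `ContextVar` is set on `__enter__` and restored on `__exit__` via the token returned by `set()`. A self-contained sketch of that pattern using the same names as the diff:

```python
from contextvars import ContextVar

# Sketch of the SceneIsLoading context manager: set the ContextVar on
# entry, reset it with the stored token on exit.
scene_is_loading = ContextVar("scene_is_loading", default=None)


class SceneIsLoading:
    def __init__(self, scene):
        self.scene = scene

    def __enter__(self):
        self.token = scene_is_loading.set(self.scene)

    def __exit__(self, *args):
        scene_is_loading.reset(self.token)


states = []
with SceneIsLoading("my-scene"):
    states.append(scene_is_loading.get())  # the scene, inside the block
states.append(scene_is_loading.get())  # back to the default outside
```

Resetting via the token (rather than setting the var back to `None`) correctly restores whatever value was active before the block, which matters when these contexts nest.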
@@ -4,9 +4,10 @@
    "ArchiveEntry",
]


@dataclass
class ArchiveEntry:
    text: str
    start: int = None
    end: int = None
    ts: str = None
@@ -1,57 +1,56 @@
handlers = {
}
handlers = {}


class AsyncSignal:
    def __init__(self, name):
        self.receivers = []
        self.name = name

    def connect(self, handler):
        if handler in self.receivers:
            return
        self.receivers.append(handler)

    def disconnect(self, handler):
        self.receivers.remove(handler)

    async def send(self, emission):
        for receiver in self.receivers:
            await receiver(emission)


def _register(name:str):
def _register(name: str):
    """
    Registers a signal handler

    Arguments:
        name (str): The name of the signal
        handler (signal): The signal handler
    """

    if name in handlers:
        raise ValueError(f"Signal {name} already registered")

    handlers[name] = AsyncSignal(name)
    return handlers[name]


def register(*names):
    """
    Registers many signal handlers

    Arguments:
        *names (str): The names of the signals
    """
    for name in names:
        _register(name)


def get(name:str):
def get(name: str):
    """
    Gets a signal handler

    Arguments:
        name (str): The name of the signal handler
    """
    return handlers.get(name)
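The `AsyncSignal` module above is small enough to exercise end to end: register a named signal, connect a receiver (duplicate connects are ignored), and `send()` awaits each receiver in turn. A runnable sketch mirroring the diffed code:

```python
import asyncio

# Standalone copy of the AsyncSignal pattern from the diff: named signals
# in a module-level registry, async receivers awaited one by one.
handlers = {}


class AsyncSignal:
    def __init__(self, name):
        self.receivers = []
        self.name = name

    def connect(self, handler):
        if handler in self.receivers:
            return  # connecting the same handler twice is a no-op
        self.receivers.append(handler)

    async def send(self, emission):
        for receiver in self.receivers:
            await receiver(emission)


def _register(name):
    if name in handlers:
        raise ValueError(f"Signal {name} already registered")
    handlers[name] = AsyncSignal(name)
    return handlers[name]


received = []


async def on_event(emission):
    received.append(emission)


async def main():
    sig = _register("demo")
    sig.connect(on_event)
    sig.connect(on_event)  # ignored, already connected
    await sig.send("hello")


asyncio.run(main())
```

Because `send` awaits receivers sequentially, a slow receiver delays the ones after it; that is a deliberate simplicity trade-off in this design rather than fan-out via `asyncio.gather`.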
@@ -2,13 +2,14 @@ from __future__ import annotations

import asyncio
import dataclasses
import structlog
from typing import TYPE_CHECKING, Any

from .signals import handlers
import structlog

from talemate.scene_message import SceneMessage

from .signals import handlers

if TYPE_CHECKING:
    from talemate.tale_mate import Character, Scene

@@ -21,6 +22,7 @@ __all__ = [

log = structlog.get_logger("talemate.emit.base")


class AbortCommand(IOError):
    pass

@@ -39,12 +41,15 @@ class Emission:


def emit(
    typ: str, message: str = None, character: Character = None, scene: Scene = None, **kwargs
    typ: str,
    message: str = None,
    character: Character = None,
    scene: Scene = None,
    **kwargs,
):
    if typ not in handlers:
        raise ValueError(f"Unknown message type: {typ}")

    if isinstance(message, SceneMessage):
        kwargs["id"] = message.id
        message_object = message
@@ -53,7 +58,14 @@ def emit(
        message_object = None

    handlers[typ].send(
        Emission(typ=typ, message=message, character=character, scene=scene, message_object=message_object, **kwargs)
        Emission(
            typ=typ,
            message=message,
            character=character,
            scene=scene,
            message_object=message_object,
            **kwargs,
        )
    )

@@ -80,7 +92,6 @@ async def wait_for_input(
    def input_receiver(emission: Emission):
        input_received["message"] = emission.message

    handlers["receive_input"].connect(input_receiver)

    handlers["request_input"].send(
@@ -97,7 +108,7 @@ async def wait_for_input(
        await asyncio.sleep(0.1)

    handlers["receive_input"].disconnect(input_receiver)

    if input_received["message"] == "!abort":
        raise AbortCommand()

@@ -145,4 +156,4 @@ class Emitter:
        self.emit("character", message, character=character)

    def player_message(self, message: str, character: Character):
        self.emit("player", message, character=character)
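The `emit()` function above dispatches by type name: unknown types raise immediately, known types forward an `Emission` payload to every connected receiver. A simplified synchronous sketch of that dispatch (the real code builds an `Emission` dataclass and sends through the blinker-style signal; a dict payload and a plain receiver list stand in here):

```python
# Simplified, synchronous sketch of emit(): look the signal up by type
# name, fail fast on unknown types, forward the payload to receivers.
handlers = {"status": []}


def emit(typ, message=None, **kwargs):
    if typ not in handlers:
        raise ValueError(f"Unknown message type: {typ}")
    for receiver in handlers[typ]:
        receiver({"typ": typ, "message": message, **kwargs})


seen = []
handlers["status"].append(seen.append)
emit("status", message="Auto state added.", status="success")
```

The fail-fast `ValueError` is what surfaces typos in signal names at the call site instead of silently dropping the event.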
@@ -18,7 +18,7 @@ ClientStatus = signal("client_status")
RequestClientStatus = signal("request_client_status")
AgentStatus = signal("agent_status")
RequestAgentStatus = signal("request_agent_status")
ClientBootstraps = signal("client_bootstraps")
PromptSent = signal("prompt_sent")

RemoveMessage = signal("remove_message")
@@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from talemate.tale_mate import Scene, Actor, SceneMessage
    from talemate.tale_mate import Actor, Scene, SceneMessage

__all__ = [
    "Event",
@@ -35,27 +35,33 @@ class CharacterStateEvent(Event):
    state: str
    character_name: str


@dataclass
class SceneStateEvent(Event):
    pass


@dataclass
class GameLoopBase(Event):
    pass


@dataclass
class GameLoopEvent(GameLoopBase):
    had_passive_narration: bool = False


@dataclass
class GameLoopStartEvent(GameLoopBase):
    pass


@dataclass
class GameLoopActorIterEvent(GameLoopBase):
    actor: Actor
    game_loop: GameLoopEvent


@dataclass
class GameLoopNewMessageEvent(GameLoopBase):
    message: SceneMessage
@@ -1,6 +1,7 @@
class TalemateError(Exception):
    pass


class TalemateInterrupt(Exception):
    """
    Exception to interrupt the game loop
@@ -8,6 +9,7 @@ class TalemateInterrupt(Exception):

    pass


class ExitScene(TalemateInterrupt):
    """
    Exception to exit the scene
@@ -15,18 +17,20 @@ class ExitScene(TalemateInterrupt):

    pass


class RestartSceneLoop(TalemateInterrupt):
    """
    Exception to switch the scene loop
    """

    pass


class ResetScene(TalemateInterrupt):
    """
    Exception to reset the scene
    """

    pass

@@ -34,7 +38,7 @@ class RenderPromptError(TalemateError):
    """
    Exception to raise when there is an error rendering a prompt
    """

    pass

@@ -42,11 +46,10 @@ class LLMAccuracyError(TalemateError):
    """
    Exception to raise when the LLM response is not processable
    """

    def __init__(self, message:str, model_name:str=None):
    def __init__(self, message: str, model_name: str = None):
        if model_name:
            message = f"{model_name} - {message}"

        super().__init__(message)
        self.model_name = model_name
@@ -1,5 +1,5 @@
import os
import fnmatch
import os

from talemate.config import load_config

@@ -27,7 +27,7 @@ def _list_files_and_directories(root: str, path: str) -> list:
    :return: List of files and directories in the given root directory.
    """
    # Define the file patterns to match
    patterns = ['characters/*.png', 'characters/*.webp', '*/*.json']
    patterns = ["characters/*.png", "characters/*.webp", "*/*.json"]

    items = []

@@ -42,4 +42,4 @@ def _list_files_and_directories(root: str, path: str) -> list:
                items.append(os.path.join(dirpath, filename))
            break

    return items
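The `patterns` list above uses shell-style globs. A quick sketch of how those three patterns behave when matched against relative paths with `fnmatch` (the full function walks the directory tree with `os.walk`; this example only shows the pattern matching step):

```python
import fnmatch

# The three glob patterns from _list_files_and_directories, matched
# against relative paths with fnmatch (where * also crosses "/").
patterns = ["characters/*.png", "characters/*.webp", "*/*.json"]


def matches(path):
    return any(fnmatch.fnmatch(path, p) for p in patterns)


ok = matches("characters/alice.png")   # character portrait, matched
nested = matches("scenes/intro.json")  # any subdir json, matched
miss = matches("characters/alice.jpg") # jpg is not in the patterns
```

Note that unlike shell globbing, `fnmatch`'s `*` matches path separators too, so `*/*.json` accepts json files at any depth of at least one directory.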
@@ -1,56 +1,62 @@
import asyncio
import os
from typing import TYPE_CHECKING, Any

import nest_asyncio
import pydantic
import structlog
import asyncio
import nest_asyncio
from talemate.prompts.base import Prompt, PrependTemplateDirectories
from talemate.instance import get_agent

from talemate.agents.director import DirectorAgent
from talemate.agents.memory import MemoryAgent
from talemate.instance import get_agent
from talemate.prompts.base import PrependTemplateDirectories, Prompt

if TYPE_CHECKING:
    from talemate.tale_mate import Scene

log = structlog.get_logger("game_state")


class Goal(pydantic.BaseModel):
    description: str
    id: int
    status: bool = False


class Instructions(pydantic.BaseModel):
    character: dict[str, str] = pydantic.Field(default_factory=dict)


class Ops(pydantic.BaseModel):
    run_on_start: bool = False


class GameState(pydantic.BaseModel):
    ops: Ops = Ops()
    variables: dict[str,Any] = pydantic.Field(default_factory=dict)
    variables: dict[str, Any] = pydantic.Field(default_factory=dict)
    goals: list[Goal] = pydantic.Field(default_factory=list)
    instructions: Instructions = pydantic.Field(default_factory=Instructions)

    @property
    def director(self) -> DirectorAgent:
        return get_agent('director')
        return get_agent("director")

    @property
    def memory(self) -> MemoryAgent:
        return get_agent('memory')
        return get_agent("memory")

    @property
    def scene(self) -> 'Scene':
    def scene(self) -> "Scene":
        return self.director.scene

    @property
    def has_scene_instructions(self) -> bool:
        return scene_has_instructions_template(self.scene)

    @property
    def game_won(self) -> bool:
        return self.variables.get("__game_won__") == True

    @property
    def scene_instructions(self) -> str:
        scene = self.scene
@@ -59,43 +65,52 @@ class GameState(pydantic.BaseModel):
        game_state = self
        if scene_has_instructions_template(self.scene):
            with PrependTemplateDirectories([scene.template_dir]):
                prompt = Prompt.get('instructions', {
                    'scene': scene,
                    'max_tokens': client.max_token_length,
                    'game_state': game_state
                })
                prompt = Prompt.get(
                    "instructions",
                    {
                        "scene": scene,
                        "max_tokens": client.max_token_length,
                        "game_state": game_state,
                    },
                )

                prompt.client = client
                instructions = prompt.render().strip()
                log.info("Initialized game state instructions", scene=scene, instructions=instructions)
                log.info(
                    "Initialized game state instructions",
                    scene=scene,
                    instructions=instructions,
                )
                return instructions

    def init(self, scene: 'Scene') -> 'GameState':
    def init(self, scene: "Scene") -> "GameState":
        return self

    def set_var(self, key: str, value: Any, commit: bool = False):
        self.variables[key] = value
        if commit:
            loop = asyncio.get_event_loop()
            loop.run_until_complete(self.memory.add(value, uid=f"game_state.{key}"))

    def has_var(self, key: str) -> bool:
        return key in self.variables

    def get_var(self, key: str) -> Any:
        return self.variables[key]

    def get_or_set_var(self, key: str, value: Any, commit: bool = False) -> Any:
        if not self.has_var(key):
            self.set_var(key, value, commit=commit)
        return self.get_var(key)


def scene_has_game_template(scene: 'Scene') -> bool:
def scene_has_game_template(scene: "Scene") -> bool:
    """Returns True if the scene has a game template."""
    game_template_path = os.path.join(scene.template_dir, 'game.jinja2')
    game_template_path = os.path.join(scene.template_dir, "game.jinja2")
    return os.path.exists(game_template_path)


def scene_has_instructions_template(scene: 'Scene') -> bool:
def scene_has_instructions_template(scene: "Scene") -> bool:
    """Returns True if the scene has an instructions template."""
    instructions_template_path = os.path.join(scene.template_dir, 'instructions.jinja2')
    instructions_template_path = os.path.join(scene.template_dir, "instructions.jinja2")
    return os.path.exists(instructions_template_path)
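The variable helpers on `GameState` have one subtlety worth calling out: `get_or_set_var` only writes when the key is missing, so an existing value is never clobbered. A standalone sketch of just those helpers (dropping the `commit` path that writes through to the memory agent):

```python
# Sketch of the GameState variable helpers; the memory-agent commit path
# from the real code is omitted here.
class GameStateVars:
    def __init__(self):
        self.variables = {}

    def set_var(self, key, value):
        self.variables[key] = value

    def has_var(self, key):
        return key in self.variables

    def get_var(self, key):
        return self.variables[key]

    def get_or_set_var(self, key, value):
        # only set when missing, then return whatever is stored
        if not self.has_var(key):
            self.set_var(key, value)
        return self.get_var(key)


state = GameStateVars()
first = state.get_or_set_var("mood", "calm")    # sets and returns "calm"
second = state.get_or_set_var("mood", "tense")  # key exists, keeps "calm"
```

This makes `get_or_set_var` safe to call as a "default-initializing read" from scene scripts without worrying about overwriting earlier state.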
@@ -2,21 +2,21 @@
Keep track of clients and agents
"""
import asyncio
import talemate.agents as agents
import talemate.client as clients
from talemate.emit import emit
from talemate.emit.signals import handlers
import talemate.client.bootstrap as bootstrap

import structlog

import talemate.agents as agents
import talemate.client as clients
import talemate.client.bootstrap as bootstrap
from talemate.emit import emit
from talemate.emit.signals import handlers

log = structlog.get_logger("talemate")

AGENTS = {}
CLIENTS = {}


def get_agent(typ: str, *create_args, **create_kwargs):
    agent = AGENTS.get(typ)
@@ -75,48 +75,51 @@ def client_instances():
def agent_instances():
    return AGENTS.items()


def agent_instances_with_client(client):
    """
    return a list of agents that have the specified client
    """

    for typ, agent in agent_instances():
        if getattr(agent, "client", None) == client:
            yield agent


def emit_agent_status_by_client(client):
    """
    Will emit status of all agents that have the specified client
    """

    for agent in agent_instances_with_client(client):
        emit_agent_status(agent.__class__, agent)


async def emit_clients_status():
    """
    Will emit status of all clients
    """
    #log.debug("emit", type="client status")
    # log.debug("emit", type="client status")
    for client in CLIENTS.values():
        if client:
            await client.status()


def _sync_emit_clients_status(*args, **kwargs):
    """
    Will emit status of all clients
    in synchronous mode
    """
    loop = asyncio.get_event_loop()
    loop.run_until_complete(emit_clients_status())


handlers["request_client_status"].connect(_sync_emit_clients_status)


async def emit_client_bootstraps():
    emit(
        "client_bootstraps",
        data=list(await bootstrap.list_all())
    )
    emit("client_bootstraps", data=list(await bootstrap.list_all()))


def sync_emit_clients_status():
    """
@@ -126,15 +129,20 @@ def sync_emit_clients_status():
    loop = asyncio.get_event_loop()
    loop.run_until_complete(emit_clients_status())


async def sync_client_bootstraps():
    """
    Will loop through all registered client bootstrap lists and spawn / update
    client instances from them.
    """

    for service_name, func in bootstrap.LISTS.items():
        async for client_bootstrap in func():
            log.debug("sync client bootstrap", service_name=service_name, client_bootstrap=client_bootstrap.dict())
            log.debug(
                "sync client bootstrap",
                service_name=service_name,
                client_bootstrap=client_bootstrap.dict(),
            )
            client = get_client(
                client_bootstrap.name,
                type=client_bootstrap.client_type.value,
@@ -143,6 +151,7 @@ async def sync_client_bootstraps():
            )
            await client.status()


def emit_agent_status(cls, agent=None):
    if not agent:
        emit(
@@ -167,9 +176,10 @@ def emit_agents_status(*args, **kwargs):
    """
    Will emit status of all agents
    """
    #log.debug("emit", type="agent status")
    # log.debug("emit", type="agent status")
    for typ, cls in agents.AGENT_CLASSES.items():
        agent = AGENTS.get(typ)
        emit_agent_status(cls, agent)


handlers["request_agent_status"].connect(emit_agents_status)
@@ -1,22 +1,26 @@
import json
import os

import structlog
from dotenv import load_dotenv

import talemate.events as events
import talemate.instance as instance
from talemate import Actor, Character, Player
from talemate.config import load_config
from talemate.scene_message import (
    SceneMessage, CharacterMessage, NarratorMessage, DirectorMessage, MESSAGES, reset_message_id
)
from talemate.world_state import WorldState
from talemate.game_state import GameState
from talemate.context import SceneIsLoading
from talemate.emit import emit
from talemate.status import set_loading, LoadingStatus
import talemate.instance as instance

import structlog
from talemate.game_state import GameState
from talemate.scene_message import (
    MESSAGES,
    CharacterMessage,
    DirectorMessage,
    NarratorMessage,
    SceneMessage,
    reset_message_id,
)
from talemate.status import LoadingStatus, set_loading
from talemate.world_state import WorldState

__all__ = [
    "load_scene",
@@ -29,6 +33,7 @@ __all__ = [

log = structlog.get_logger("talemate.load")


@set_loading("Loading scene...")
async def load_scene(scene, file_path, conv_client, reset: bool = False):
    """
@@ -61,8 +66,7 @@ async def load_scene_from_character_card(scene, file_path):
    """
    Load a character card (tavern etc.) from the given file path.
    """

    loading_status = LoadingStatus(5)
    loading_status("Loading character card...")
@@ -85,59 +89,68 @@ async def load_scene_from_character_card(scene, file_path):
    actor = Actor(character, conversation)

    scene.name = character.name

    loading_status("Initializing long-term memory...")

    await memory.set_db()

    await scene.add_actor(actor)

    log.debug("load_scene_from_character_card", scene=scene, character=character, content_context=scene.context)
    log.debug(
        "load_scene_from_character_card",
        scene=scene,
        character=character,
        content_context=scene.context,
    )

    loading_status("Determine character context...")

    if not scene.context:
        try:
            scene.context = await creator.determine_content_context_for_character(character)
            scene.context = await creator.determine_content_context_for_character(
                character
            )
            log.debug("content_context", content_context=scene.context)
        except Exception as e:
            log.error("determine_content_context_for_character", error=e)

    # attempt to convert to base attributes
    try:
        loading_status("Determine character attributes...")

        _, character.base_attributes = await creator.determine_character_attributes(character)
        _, character.base_attributes = await creator.determine_character_attributes(
            character
        )
        # lowercase keys
        character.base_attributes = {k.lower(): v for k, v in character.base_attributes.items()}
        character.base_attributes = {
            k.lower(): v for k, v in character.base_attributes.items()
        }

        # any values that are lists should be converted to strings joined by ,
        for k, v in character.base_attributes.items():
            if isinstance(v, list):
                character.base_attributes[k] = ",".join(v)

        # transfer description to character
        if character.base_attributes.get("description"):
            character.description = character.base_attributes.pop("description")

        await character.commit_to_memory(scene.get_helper("memory").agent)

        log.debug("base_attributes parsed", base_attributes=character.base_attributes)
    except Exception as e:
        log.warning("determine_character_attributes", error=e)
|
||||
|
||||
|
||||
scene.description = character.description
|
||||
|
||||
|
||||
if image:
|
||||
scene.assets.set_cover_image_from_file_path(file_path)
|
||||
character.cover_image = scene.assets.cover_image
|
||||
|
||||
|
||||
try:
|
||||
loading_status("Update world state ...")
|
||||
await scene.world_state.request_update(initial_only=True)
|
||||
await scene.world_state.request_update(initial_only=True)
|
||||
except Exception as e:
|
||||
log.error("world_state.request_update", error=e)
|
||||
|
||||
@@ -151,9 +164,9 @@ async def load_scene_from_data(
):
    loading_status = LoadingStatus(1)
    reset_message_id()

    memory = scene.get_helper("memory").agent

    scene.description = scene_data.get("description", "")
    scene.intro = scene_data.get("intro", "") or scene.description
    scene.name = scene_data.get("name", "Unknown Scene")
@@ -161,11 +174,10 @@ async def load_scene_from_data(
    scene.filename = None
    scene.goals = scene_data.get("goals", [])
    scene.immutable_save = scene_data.get("immutable_save", False)

    #reset = True

    # reset = True

    if not reset:
        scene.goal = scene_data.get("goal", 0)
        scene.memory_id = scene_data.get("memory_id", scene.memory_id)
        scene.saved_memory_session_id = scene_data.get("saved_memory_session_id", None)
@@ -181,33 +193,37 @@ async def load_scene_from_data(
        )
        scene.assets.cover_image = scene_data.get("assets", {}).get("cover_image", None)
        scene.assets.load_assets(scene_data.get("assets", {}).get("assets", {}))

        scene.sync_time()
        log.debug("scene time", ts=scene.ts)

    loading_status("Initializing long-term memory...")

    await memory.set_db()
    await memory.remove_unsaved_memory()

    await scene.world_state_manager.remove_all_empty_pins()

    if not scene.memory_session_id:
        scene.set_new_memory_session_id()

    for ah in scene.archived_history:
        if reset:
            break
        ts = ah.get("ts", "PT1S")

        if not ah.get("ts"):
            ah["ts"] = ts

        scene.signals["archive_add"].send(
            events.ArchiveEvent(scene=scene, event_type="archive_add", text=ah["text"], ts=ts)
            events.ArchiveEvent(
                scene=scene, event_type="archive_add", text=ah["text"], ts=ts
            )
        )

    for character_name, character_data in scene_data.get("inactive_characters", {}).items():
    for character_name, character_data in scene_data.get(
        "inactive_characters", {}
    ).items():
        scene.inactive_characters[character_name] = Character(**character_data)

    for character_name, cs in scene.character_states.items():
@@ -215,10 +231,10 @@ async def load_scene_from_data(

    for character_data in scene_data["characters"]:
        character = Character(**character_data)

        if character.name in scene.inactive_characters:
            scene.inactive_characters.pop(character.name)

        if not character.is_player:
            agent = instance.get_agent("conversation", client=conv_client)
            actor = Actor(character, agent)
@@ -226,13 +242,14 @@ async def load_scene_from_data(
            actor = Player(character, None)
        # Add the TestCharacter actor to the scene
        await scene.add_actor(actor)

    # the scene has been saved before (since we just loaded it), so we set the saved flag to True
    # as long as the scene has a memory_id.
    scene.saved = "memory_id" in scene_data

    return scene


async def load_character_into_scene(scene, scene_json_path, character_name):
    """
    Load a character from a scene json file and add it to the current scene.
@@ -244,10 +261,9 @@ async def load_character_into_scene(scene, scene_json_path, character_name):
    # Load the json file
    with open(scene_json_path, "r") as f:
        scene_data = json.load(f)

    agent = scene.get_helper("conversation").agent

    # Find the character in the characters list
    for character_data in scene_data["characters"]:
        if character_data["name"] == character_name:
@@ -264,7 +280,9 @@ async def load_character_into_scene(scene, scene_json_path, character_name):
            await scene.add_actor(actor)
            break
    else:
        raise ValueError(f"Character '{character_name}' not found in the scene file '{scene_json_path}'")
        raise ValueError(
            f"Character '{character_name}' not found in the scene file '{scene_json_path}'"
        )

    return scene
@@ -340,49 +358,47 @@ def default_player_character():


def _load_history(history):
    _history = []

    for text in history:
        if isinstance(text, str):
            _history.append(_prepare_legacy_history(text))

        elif isinstance(text, dict):
            _history.append(_prepare_history(text))

    return _history


def _prepare_history(entry):
    typ = entry.pop("typ", "scene_message")
    entry.pop("id", None)

    if entry.get("source") == "":
        entry.pop("source")

    cls = MESSAGES.get(typ, SceneMessage)

    return cls(**entry)


def _prepare_legacy_history(entry):
    """
    Converts legacy history to the new format

    Legacy: list<str>
    New: list<SceneMessage>
    """

    if entry.startswith("*"):
        cls = NarratorMessage
    elif entry.startswith("Director instructs"):
        cls = DirectorMessage
    else:
        cls = CharacterMessage

    return cls(entry)


def creative_environment():
    return {
@@ -392,6 +408,5 @@ def creative_environment():
        "history": [],
        "archived_history": [],
        "character_states": {},
        "characters": [
        ],
        "characters": [],
    }

@@ -1 +1 @@
from .base import Prompt, LoopedPrompt
from .base import LoopedPrompt, Prompt
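The `_prepare_legacy_history` branch above dispatches plain legacy strings onto message classes by prefix. A minimal runnable sketch of that dispatch, using illustrative stand-in classes (the real `NarratorMessage`, `DirectorMessage`, and `CharacterMessage` live in `talemate.scene_message` and are dataclasses, not `str` subclasses):

```python
# Stand-ins for the talemate message classes; only the names match the
# real project, the str-subclass implementation here is for illustration.
class NarratorMessage(str):
    pass


class DirectorMessage(str):
    pass


class CharacterMessage(str):
    pass


def prepare_legacy_history(entry: str):
    # prefix-based dispatch mirroring _prepare_legacy_history:
    # "*..." -> narration, "Director instructs ..." -> director, else dialogue
    if entry.startswith("*"):
        cls = NarratorMessage
    elif entry.startswith("Director instructs"):
        cls = DirectorMessage
    else:
        cls = CharacterMessage
    return cls(entry)
```

The fall-through to `CharacterMessage` means any line that is neither narration nor a director instruction is treated as character dialogue.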
File diff suppressed because it is too large
@@ -1,30 +1,32 @@
from contextvars import ContextVar

import pydantic

current_prompt_context = ContextVar("current_content_context", default=None)


class PromptContextState(pydantic.BaseModel):
    content: list[str] = pydantic.Field(default_factory=list)

    def push(self, content:str, proxy:list[str]):

    def push(self, content: str, proxy: list[str]):
        if content not in self.content:
            self.content.append(content)
            proxy.append(content)

    def has(self, content:str):

    def has(self, content: str):
        return content in self.content

    def extend(self, content:list[str], proxy:list[str]):

    def extend(self, content: list[str], proxy: list[str]):
        for item in content:
            self.push(item, proxy)


class PromptContext:

    def __enter__(self):
        self.state = PromptContextState()
        self.token = current_prompt_context.set(self.state)
        return self.state

    def __exit__(self, *args):
        current_prompt_context.reset(self.token)
        return False
        return False
@@ -0,0 +1 @@
A roleplaying session between a user and a talented actor. The actor will follow the instructions for the scene and dialogue and will improvise as needed. The actor will only respond as one character.
@@ -0,0 +1 @@
A chat between a user and a talented fiction narrator. The narrator will describe scenes and characters based on stories provided to him in easy-to-read and easy-to-understand yet exciting detail. The narrator will never remind us that what he writes is fictional.
@@ -0,0 +1 @@
A chat between a user and a talented fiction director. The director will give instructions to a specific character to help them guide the story towards a specific goal.
@@ -0,0 +1 @@
A chat between an author and a talented fiction editor. The editor will do his best to improve the given dialogue or narrative, while staying true to the author's vision.
@@ -0,0 +1 @@
A chat between a user and a talented fiction narrator. The narrator will describe scenes and characters based on stories provided to him in easy-to-read and easy-to-understand yet exciting detail. The narrator will never remind us that what he writes is fictional.
@@ -0,0 +1 @@
A chat between a user and a talented fiction narrator. The narrator will summarize the given text according to the instructions, making sure to keep the overall tone of the narrative and dialogue.
@@ -0,0 +1 @@
Instructions for a talented story analyst. The analyst will analyze parts of a story or dialogue and give truthful answers based on the dialogue or events given to him. The analyst will never make up facts or lie in his answers.
@@ -0,0 +1 @@
Instructions for a talented story analyst. The analyst will analyze parts of a story or dialogue and give truthful answers based on the dialogue or events given to him. The analyst will never make up facts or lie in his answers. The analyst loves making JSON lists.
@@ -41,7 +41,7 @@ class CharacterHub:

        if not os.path.exists("scenes/characters"):
            os.makedirs("scenes/characters")

        with open(f"scenes/characters/(unknown).png", "wb") as f:
            f.write(result.content)
@@ -2,11 +2,9 @@ import json

from talemate.scene_message import SceneMessage


class SceneEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, SceneMessage):
            return obj.__dict__()
        return super().default(obj)
@@ -1,41 +1,40 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import os
import pydantic
import hashlib

import base64
import hashlib
import os
from typing import TYPE_CHECKING, Any

import pydantic

if TYPE_CHECKING:
    from talemate import Scene

import structlog

__all__ = [
    "Asset",
    "SceneAssets"
]
__all__ = ["Asset", "SceneAssets"]

log = structlog.get_logger("talemate.scene_assets")


class Asset(pydantic.BaseModel):
    id: str
    file_type: str
    media_type: str

    def to_base64(self, asset_directory:str) -> str:

    def to_base64(self, asset_directory: str) -> str:
        """
        Returns the asset as a base64 encoded string.
        """

        asset_path = os.path.join(asset_directory, f"{self.id}.{self.file_type}")

        with open(asset_path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

class SceneAssets:

    def __init__(self, scene:Scene):

class SceneAssets:
    def __init__(self, scene: Scene):
        self.scene = scene
        self.assets = {}
        self.cover_image = None
@@ -45,64 +44,64 @@ class SceneAssets:
        """
        Returns the scene's asset path
        """

        scene_save_dir = self.scene.save_dir

        if not os.path.exists(scene_save_dir):
            raise FileNotFoundError(f"Scene save directory does not exist: {scene_save_dir}")
            raise FileNotFoundError(
                f"Scene save directory does not exist: {scene_save_dir}"
            )

        asset_path = os.path.join(scene_save_dir, "assets")

        if not os.path.exists(asset_path):
            os.makedirs(asset_path)

        return asset_path

    def asset_path(self, asset_id:str) -> str:

    def asset_path(self, asset_id: str) -> str:
        """
        Returns the path to the asset with the given id.
        """
        try:
            return os.path.join(self.asset_directory, f"{asset_id}.{self.assets[asset_id].file_type}")
            return os.path.join(
                self.asset_directory, f"{asset_id}.{self.assets[asset_id].file_type}"
            )
        except KeyError:
            log.error("asset_path", asset_id=asset_id, assets=self.assets)
            return None

    def dict(self, *args, **kwargs):
        return {
            "cover_image": self.cover_image,
            "assets": {
                asset.id: asset.dict() for asset in self.assets.values()
            }
            "assets": {asset.id: asset.dict() for asset in self.assets.values()},
        }

    def load_assets(self, assets_dict:dict):
        """
        Loads assets from a dictionary.
        """
        for asset_id, asset_dict in assets_dict.items():
            self.assets[asset_id] = Asset(**asset_dict)

    def set_cover_image(self, asset_bytes:bytes, file_extension:str, media_type:str):

    def load_assets(self, assets_dict: dict):
        """
        Loads assets from a dictionary.
        """

        for asset_id, asset_dict in assets_dict.items():
            self.assets[asset_id] = Asset(**asset_dict)

    def set_cover_image(self, asset_bytes: bytes, file_extension: str, media_type: str):
        # add the asset
        asset = self.add_asset(asset_bytes, file_extension, media_type)
        self.cover_image = asset.id

    def set_cover_image_from_file_path(self, file_path:str):

    def set_cover_image_from_file_path(self, file_path: str):
        """
        Sets the cover image from file path, calling add_asset_from_file_path
        """

        asset = self.add_asset_from_file_path(file_path)
        self.cover_image = asset.id

    def add_asset(self, asset_bytes:bytes, file_extension:str, media_type:str) -> Asset:
        self.cover_image = asset.id

    def add_asset(
        self, asset_bytes: bytes, file_extension: str, media_type: str
    ) -> Asset:
        """
        Takes the asset and stores it in the scene's assets folder.
        """
@@ -131,39 +130,36 @@ class SceneAssets:
        self.assets[asset_id] = asset

        return asset

    def add_asset_from_image_data(self, image_data:str) -> Asset:
    def add_asset_from_image_data(self, image_data: str) -> Asset:
        """
        Will add an asset from image data, extracting the media type from the
        data url and then decoding the base64 encoded data.

        Will call add_asset
        """

        media_type = image_data.split(";")[0].split(":")[1]
        image_bytes = base64.b64decode(image_data.split(",")[1])
        file_extension = media_type.split("/")[1]

        return self.add_asset(image_bytes, file_extension, media_type)

    def add_asset_from_file_path(self, file_path:str) -> Asset:

    def add_asset_from_file_path(self, file_path: str) -> Asset:
        """
        Will add an asset from a file path, first loading the file into memory
        and then calling add_asset
        """

        file_bytes = None
        with open(file_path, "rb") as f:
            file_bytes = f.read()

        file_extension = os.path.splitext(file_path)[1]

        # guess media type from extension, currently only supports images
        # for png, jpg and webp

        if file_extension == ".png":
            media_type = "image/png"
        elif file_extension in [".jpg", ".jpeg"]:
@@ -172,50 +168,44 @@ class SceneAssets:
            media_type = "image/webp"
        else:
            raise ValueError(f"Unsupported file extension: {file_extension}")

        return self.add_asset(file_bytes, file_extension, media_type)

    def get_asset(self, asset_id:str) -> Asset:

        return self.add_asset(file_bytes, file_extension, media_type)

    def get_asset(self, asset_id: str) -> Asset:
        """
        Returns the asset with the given id.
        """

        return self.assets[asset_id]

    def get_asset_bytes(self, asset_id:str) -> bytes:

    def get_asset_bytes(self, asset_id: str) -> bytes:
        """
        Returns the bytes of the asset with the given id.
        """

        asset_path = self.asset_path(asset_id)

        with open(asset_path, "rb") as f:
            return f.read()

    def get_asset_bytes_as_base64(self, asset_id:str) -> str:

    def get_asset_bytes_as_base64(self, asset_id: str) -> str:
        """
        Returns the bytes of the asset with the given id as a base64 encoded string.
        """

        bytes = self.get_asset_bytes(asset_id)

        return base64.b64encode(bytes).decode("utf-8")

    def remove_asset(self, asset_id:str):

    def remove_asset(self, asset_id: str):
        """
        Removes the asset with the given id.
        """

        asset = self.assets.pop(asset_id)

        asset_path = self.asset_directory

        asset_file_path = os.path.join(asset_path, f"{asset_id}.{asset.file_type}")
        os.remove(asset_file_path)

        os.remove(asset_file_path)
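`add_asset_from_image_data` above peels the media type and payload out of a base64 data URL before delegating to `add_asset`. That parsing step on its own, as a runnable sketch (the function name `parse_data_url` is illustrative, not part of talemate):

```python
import base64


def parse_data_url(image_data: str):
    # "data:image/png;base64,<payload>" -> ("image/png", b"<bytes>", "png"),
    # using the same split logic as add_asset_from_image_data
    media_type = image_data.split(";")[0].split(":")[1]
    image_bytes = base64.b64decode(image_data.split(",")[1])
    file_extension = media_type.split("/")[1]
    return media_type, image_bytes, file_extension
```

Note the sketch, like the original, assumes a well-formed `data:<type>;base64,<payload>` URL and does not validate malformed input.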
@@ -1,13 +1,16 @@
from dataclasses import dataclass, field

import isodate

_message_id = 0


def get_message_id():
    global _message_id
    _message_id += 1
    return _message_id


def reset_message_id():
    global _message_id
    _message_id = 0
@@ -15,38 +18,37 @@ def reset_message_id():

@dataclass
class SceneMessage:
    """
    Base class for all messages that are sent to the scene.
    """

    # the message itself
    message: str

    # the id of the message
    id: int = field(default_factory=get_message_id)

    # the source of the message (e.g. "ai", "progress_story", "director")
    source: str = ""

    typ = "scene"

    def __str__(self):
        return self.message

    def __int__(self):
        return self.id

    def __len__(self):
        return len(self.message)

    def __in__(self, other):
        return (other in self.message)

        return other in self.message

    def __contains__(self, other):
        return (self.message in other)

        return self.message in other

    def __dict__(self):
        return {
            "message": self.message,
@@ -54,66 +56,69 @@ class SceneMessage:
            "typ": self.typ,
            "source": self.source,
        }

    def __iter__(self):
        return iter(self.message)

    def split(self, *args, **kwargs):
        return self.message.split(*args, **kwargs)

    def startswith(self, *args, **kwargs):
        return self.message.startswith(*args, **kwargs)

    def endswith(self, *args, **kwargs):
        return self.message.endswith(*args, **kwargs)

    @property
    def secondary_source(self):
        return self.source


@dataclass
class CharacterMessage(SceneMessage):
    typ = "character"
    source: str = "ai"

    def __str__(self):
        return self.message

    @property
    def character_name(self):
        return self.message.split(":", 1)[0]

    @property
    def secondary_source(self):
        return self.character_name


@dataclass
class NarratorMessage(SceneMessage):
    source: str = "progress_story"
    typ = "narrator"


@dataclass
class DirectorMessage(SceneMessage):
    typ = "director"

    def __str__(self):
        """
        The director message is a special case and needs to be transformed
        from "Director instructs {charname}:" to "*{charname} inner monologue:"
        """

        transformed_message = self.message.replace("Director instructs ", "")
        char_name, message = transformed_message.split(":", 1)

        return f"[Story progression instructions for {char_name}: {message}]"


@dataclass
class TimePassageMessage(SceneMessage):
    ts: str = "PT0S"
    source: str = "manual"
    source: str = "manual"
    typ = "time"

    def __dict__(self):
        return {
            "message": self.message,
@@ -122,15 +127,17 @@ class TimePassageMessage(SceneMessage):
            "source": self.source,
            "ts": self.ts,
        }


@dataclass
class ReinforcementMessage(SceneMessage):
    typ = "reinforcement"

    def __str__(self):
        question, _ = self.source.split(":", 1)
        return f"[Context state: {question}: {self.message}]"


MESSAGES = {
    "scene": SceneMessage,
    "character": CharacterMessage,
@@ -138,4 +145,4 @@ MESSAGES = {
    "director": DirectorMessage,
    "time": TimePassageMessage,
    "reinforcement": ReinforcementMessage,
}
}
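SceneMessage ids come from the module-level counter shown above: `get_message_id` is the dataclass `default_factory`, and `reset_message_id` is called on scene load so ids restart. The counter pattern in isolation:

```python
# module-level counter, as in talemate.scene_message
_message_id = 0


def get_message_id() -> int:
    # monotonically increasing id shared by every message in the process
    global _message_id
    _message_id += 1
    return _message_id


def reset_message_id() -> None:
    # called at the start of a scene load so ids begin at 1 again
    global _message_id
    _message_id = 0
```

Because the counter is global rather than per scene, resetting it on load keeps saved and reloaded scenes from accumulating ever-larger ids.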
@@ -1,40 +1,38 @@
import asyncio
import json
import os

import starlette.websockets
import websockets
import structlog
import traceback

import starlette.websockets
import structlog
import websockets

import talemate.instance as instance
from talemate import Scene
from talemate import VERSION, Scene
from talemate.config import load_config
from talemate.load import load_scene
from talemate.server.websocket_server import WebsocketHandler
from talemate.config import load_config

from talemate import VERSION

log = structlog.get_logger("talemate")


async def websocket_endpoint(websocket, path):
    # Create a queue for outgoing messages
    message_queue = asyncio.Queue()
    handler = WebsocketHandler(websocket, message_queue)
    scene_task = None

    log.info("frontend connected")

    try:
        # Create a task to send messages from the queue
        async def send_messages():
            while True:
                # check if there are messages in the queue
                if message_queue.empty():
                    await asyncio.sleep(0.01)
                    continue

                message = await message_queue.get()
                await websocket.send(json.dumps(message))

@@ -49,15 +47,19 @@ async def websocket_endpoint(websocket, path):
        send_status_task = asyncio.create_task(send_status())

        # create a task that will retrieve client bootstrap information
        async def send_client_bootstraps():
            while True:
                try:
                    await instance.sync_client_bootstraps()
                except Exception as e:
                    log.error("send_client_bootstraps", error=e, traceback=traceback.format_exc())
                    log.error(
                        "send_client_bootstraps",
                        error=e,
                        traceback=traceback.format_exc(),
                    )
                await asyncio.sleep(15)

        send_client_bootstraps_task = asyncio.create_task(send_client_bootstraps())

        while True:
@@ -66,7 +68,7 @@ async def websocket_endpoint(websocket, path):
            action_type = data.get("type")

            scene_data = None

            log.debug("frontend message", action_type=action_type)

            if action_type == "load_scene":
@@ -86,16 +88,18 @@ async def websocket_endpoint(websocket, path):
                        "message": "Scene file loaded ...",
                        "id": "scene.loaded",
                        "status": "success",
                        "data": {"hidden":True}
                        "data": {"hidden": True},
                    }
                )

                if scene_data and filename:
                    file_path = handler.handle_character_card_upload(scene_data, filename)

                    file_path = handler.handle_character_card_upload(
                        scene_data, filename
                    )

                log.info("load_scene", file_path=file_path, reset=reset)

                # Create a task to load the scene in the background
                # Create a task to load the scene in the background
                scene_task = asyncio.create_task(
                    handler.load_scene(
                        file_path, reset=reset, callback=scene_loading_done
@@ -140,11 +144,7 @@ async def websocket_endpoint(websocket, path):
            elif action_type == "request_app_config":
                log.info("request_app_config")
                await message_queue.put(
                    {
                        "type": "app_config",
                        "data": load_config(),
                        "version": VERSION
                    }
                    {"type": "app_config", "data": load_config(), "version": VERSION}
                )
            else:
                log.info("Routing to sub-handler", action_type=action_type)
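The `send_messages` coroutine above polls an `asyncio.Queue` and forwards each message as JSON over the websocket. A reduced sketch of the same polling loop, with the websocket swapped for a plain list sink and a bounded item count so the loop terminates (both are assumptions made for illustration):

```python
import asyncio
import json


async def drain_queue(queue: asyncio.Queue, sink: list, max_items: int) -> None:
    # poll the queue, serializing each message to JSON like send_messages does
    sent = 0
    while sent < max_items:
        if queue.empty():
            # nothing queued yet; yield to the event loop briefly
            await asyncio.sleep(0.01)
            continue
        message = await queue.get()
        sink.append(json.dumps(message))
        sent += 1


async def demo() -> list:
    queue: asyncio.Queue = asyncio.Queue()
    sink: list = []
    await queue.put({"type": "status", "status": "success"})
    await drain_queue(queue, sink, max_items=1)
    return sink
```

The empty-check-plus-sleep pattern avoids blocking the event loop between messages; an alternative would be to simply `await queue.get()`, which suspends until an item arrives.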
@@ -1,18 +1,16 @@
import os
import pydantic
import asyncio
import os
from typing import Union

import pydantic
import structlog

from talemate.prompts import Prompt
from talemate.tale_mate import Character, Actor, Player

from typing import Union
from talemate.tale_mate import Actor, Character, Player

log = structlog.get_logger("talemate.server.character_creator")


class StepData(pydantic.BaseModel):
    template: str
    is_player_character: bool
@@ -26,69 +24,70 @@ class StepData(pydantic.BaseModel):
    description: str = None
    questions: list[str] = []
    scenario_context: str = None


class CharacterCreationData(pydantic.BaseModel):
    base_attributes: dict[str, str] = {}
    is_player_character: bool = False
    character: object = None
    initial_prompt: str = None
    scenario_context: str = None


class CharacterCreatorServerPlugin:
    router = "character_creator"

    def __init__(self, websocket_handler):
        self.websocket_handler = websocket_handler
        self.character_creation_data = None

    @property
    def scene(self):
        return self.websocket_handler.scene

    async def handle(self, data:dict):

    async def handle(self, data: dict):
        action = data.get("action")

        log.info("Character creator action", action=action)

        if action == "submit":
            step = data.get("step")

            fn = getattr(self, f"handle_submit_step{step}", None)

            if fn is None:
                raise NotImplementedError(f"Unknown step {step}")

            return await fn(data)

        elif action == "request_templates":
            return await self.handle_request_templates(data)

    async def handle_request_templates(self, data:dict):
        choices = Prompt.get("creator.character-human",{}).list_templates("character-attributes-*.jinja2")

    async def handle_request_templates(self, data: dict):
        choices = Prompt.get("creator.character-human", {}).list_templates(
            "character-attributes-*.jinja2"
        )

        # each choice is a filename, we want to remove the extension and the directory
        choices = [os.path.splitext(os.path.basename(c))[0] for c in choices]

        # finally also strip the 'character-' prefix
        choices = [c.replace("character-attributes-", "") for c in choices]

        self.websocket_handler.queue_put({
            "type": "character_creator",
            "action": "send_templates",
            "templates": sorted(choices),
            "content_context": self.websocket_handler.config["creator"]["content_context"],
        })

        self.websocket_handler.queue_put(
            {
                "type": "character_creator",
                "action": "send_templates",
                "templates": sorted(choices),
                "content_context": self.websocket_handler.config["creator"][
                    "content_context"
                ],
            }
        )

        await asyncio.sleep(0.01)

    def apply_step_data(self, data:dict):
    def apply_step_data(self, data: dict):
        step_data = StepData(**data)
        if not self.character_creation_data:
            self.character_creation_data = CharacterCreationData(
@@ -97,7 +96,7 @@ class CharacterCreatorServerPlugin:
                is_player_character=step_data.is_player_character,
            )

        character=Character(
        character = Character(
            name="",
            description="",
            greeting_text="",
@@ -111,22 +110,25 @@ class CharacterCreatorServerPlugin:
        character.gender = step_data.base_attributes.get("gender")
        character.color = "red"
        character.example_dialogue = step_data.dialogue_examples

        return character, step_data

    async def handle_submit_step2(self, data:dict):

    async def handle_submit_step2(self, data: dict):
        creator = self.scene.get_helper("creator").agent
        character, step_data = self.apply_step_data(data)

        self.emit_step_start(2)

        def emit_attribute(name, value):
            self.websocket_handler.queue_put({
                "type": "character_creator",
                "action": "base_attribute",
                "name": name,
                "value": value,
            })

            self.websocket_handler.queue_put(
                {
                    "type": "character_creator",
                    "action": "base_attribute",
                    "name": name,
                    "value": value,
                }
            )

        base_attributes = await creator.create_character_attributes(
            step_data.character_prompt,
            step_data.template,
@@ -136,53 +138,54 @@ class CharacterCreatorServerPlugin:
            custom_attributes=step_data.custom_attributes,
            predefined_attributes=step_data.base_attributes,
        )

        base_attributes["scenario_context"] = step_data.scenario_context

        character.base_attributes = base_attributes
        character.gender = base_attributes["gender"]
        character.name = base_attributes["name"]

        log.info("Created character", name=base_attributes.get("name"))

        self.emit_step_done(2)

    async def handle_submit_step3(self, data:dict):

    async def handle_submit_step3(self, data: dict):
        creator = self.scene.get_helper("creator").agent
        character, step_data = self.apply_step_data(data)

        self.emit_step_start(3)

        description = await creator.create_character_description(
            character,
            content_context=step_data.scenario_context,
        )

        character.description = description

        self.websocket_handler.queue_put({
            "type": "character_creator",
            "action": "description",
            "description": description,
        })

        self.websocket_handler.queue_put(
            {
                "type": "character_creator",
                "action": "description",
                "description": description,
            }
        )

        self.emit_step_done(3)

    async def handle_submit_step4(self, data:dict):
|
||||
|
||||
async def handle_submit_step4(self, data: dict):
|
||||
creator = self.scene.get_helper("creator").agent
|
||||
character, step_data = self.apply_step_data(data)
|
||||
|
||||
def emit_detail(question, answer):
|
||||
self.websocket_handler.queue_put({
|
||||
"type": "character_creator",
|
||||
"action": "detail",
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
})
|
||||
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": "character_creator",
|
||||
"action": "detail",
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
}
|
||||
)
|
||||
|
||||
self.emit_step_start(4)
|
||||
|
||||
character_details = await creator.create_character_details(
|
||||
@@ -193,28 +196,30 @@ class CharacterCreatorServerPlugin:
|
||||
content_context=self.character_creation_data.scenario_context,
|
||||
)
|
||||
character.details = list(character_details.values())
|
||||
|
||||
|
||||
self.emit_step_done(4)
|
||||
|
||||
|
||||
async def handle_submit_step5(self, data:dict):
|
||||
|
||||
async def handle_submit_step5(self, data: dict):
|
||||
creator = self.scene.get_helper("creator").agent
|
||||
character, step_data = self.apply_step_data(data)
|
||||
|
||||
self.websocket_handler.queue_put({
|
||||
"type": "character_creator",
|
||||
"action": "set_generating_step",
|
||||
"step": 5,
|
||||
})
|
||||
|
||||
def emit_example(key, example):
|
||||
self.websocket_handler.queue_put({
|
||||
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": "character_creator",
|
||||
"action": "example_dialogue",
|
||||
"example": example,
|
||||
})
|
||||
|
||||
"action": "set_generating_step",
|
||||
"step": 5,
|
||||
}
|
||||
)
|
||||
|
||||
def emit_example(key, example):
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": "character_creator",
|
||||
"action": "example_dialogue",
|
||||
"example": example,
|
||||
}
|
||||
)
|
||||
|
||||
dialogue_guide = await creator.create_character_example_dialogue(
|
||||
character,
|
||||
step_data.template,
|
||||
@@ -223,63 +228,61 @@ class CharacterCreatorServerPlugin:
|
||||
content_context=self.character_creation_data.scenario_context,
|
||||
example_callback=emit_example,
|
||||
)
|
||||
|
||||
|
||||
character.dialogue_guide = dialogue_guide
|
||||
self.emit_step_done(5)
|
||||
|
||||
|
||||
async def handle_submit_step6(self, data:dict):
|
||||
|
||||
async def handle_submit_step6(self, data: dict):
|
||||
character, step_data = self.apply_step_data(data)
|
||||
|
||||
|
||||
# check if acter with character name already exists
|
||||
|
||||
|
||||
for actor in self.scene.actors:
|
||||
if actor.character.name == character.name:
|
||||
|
||||
if character.is_player and not actor.character.is_player:
|
||||
log.info("Character already exists, but is not a player", name=character.name)
|
||||
|
||||
log.info(
|
||||
"Character already exists, but is not a player",
|
||||
name=character.name,
|
||||
)
|
||||
|
||||
await self.scene.remove_actor(actor)
|
||||
break
|
||||
|
||||
|
||||
log.info("Character already exists", name=character.name)
|
||||
actor.character = character
|
||||
self.scene.emit_status()
|
||||
self.emit_step_done(6)
|
||||
return
|
||||
|
||||
|
||||
if character.is_player:
|
||||
actor = Player(character, self.scene.get_helper("conversation").agent)
|
||||
else:
|
||||
actor = Actor(character, self.scene.get_helper("conversation").agent)
|
||||
|
||||
|
||||
log.info("Adding actor", name=character.name, actor=actor)
|
||||
|
||||
|
||||
character.base_attributes["scenario_context"] = step_data.scenario_context
|
||||
character.base_attributes["_template"] = step_data.template
|
||||
character.base_attributes["_prompt"] = step_data.character_prompt
|
||||
|
||||
|
||||
|
||||
await self.scene.add_actor(actor)
|
||||
self.scene.emit_status()
|
||||
self.emit_step_done(6)
|
||||
|
||||
|
||||
def emit_step_start(self, step):
|
||||
|
||||
self.websocket_handler.queue_put({
|
||||
"type": "character_creator",
|
||||
"action": "set_generating_step",
|
||||
"step": step,
|
||||
})
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": "character_creator",
|
||||
"action": "set_generating_step",
|
||||
"step": step,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def emit_step_done(self, step):
|
||||
|
||||
self.websocket_handler.queue_put({
|
||||
"type": "character_creator",
|
||||
"action": "set_generating_step_done",
|
||||
"step": step,
|
||||
})
|
||||
|
||||
|
||||
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": "character_creator",
|
||||
"action": "set_generating_step_done",
|
||||
"step": step,
|
||||
}
|
||||
)
|
||||
|
||||
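The old and new bodies above differ only in how each `queue_put` payload is laid out; the underlying pattern is a plugin pushing typed messages onto the websocket handler's outgoing queue. A minimal sketch of that pattern (hypothetical stand-in classes, not Talemate's actual implementation):

```python
class FakeWebsocketHandler:
    """Stand-in for the websocket handler: just collects queued messages."""

    def __init__(self):
        self.queue = []

    def queue_put(self, message: dict):
        self.queue.append(message)


class StepEmitter:
    """Sketch of the emit_step_start / emit_step_done message pattern."""

    def __init__(self, websocket_handler):
        self.websocket_handler = websocket_handler

    def emit_step_start(self, step: int):
        self.websocket_handler.queue_put(
            {"type": "character_creator", "action": "set_generating_step", "step": step}
        )

    def emit_step_done(self, step: int):
        self.websocket_handler.queue_put(
            {
                "type": "character_creator",
                "action": "set_generating_step_done",
                "step": step,
            }
        )


handler = FakeWebsocketHandler()
emitter = StepEmitter(handler)
emitter.emit_step_start(2)
emitter.emit_step_done(2)
print([m["action"] for m in handler.queue])
# ['set_generating_step', 'set_generating_step_done']
```

The frontend matches messages on the `type`/`action` pair, so the reformatting commit can reflow these dict literals freely without changing the protocol.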
@@ -1,73 +1,80 @@
import os
import pydantic
import asyncio
import structlog
import json
import os

import pydantic
import structlog

from talemate.load import load_character_into_scene

log = structlog.get_logger("talemate.server.character_importer")


class ListCharactersData(pydantic.BaseModel):
    scene_path: str


class ImportCharacterData(pydantic.BaseModel):
    scene_path: str
    character_name: str


class CharacterImporterServerPlugin:

    router = "character_importer"

    def __init__(self, websocket_handler):
        self.websocket_handler = websocket_handler

    @property
    def scene(self):
        return self.websocket_handler.scene

    async def handle(self, data:dict):
    async def handle(self, data: dict):
        log.info("Character importer action", action=data.get("action"))

        fn = getattr(self, f"handle_{data.get('action')}", None)

        if fn is None:
            return

        await fn(data)

    async def handle_list_characters(self, data):
        list_characters_data = ListCharactersData(**data)

        scene_path = list_characters_data.scene_path

        with open(scene_path, "r") as f:
            scene_data = json.load(f)

        self.websocket_handler.queue_put({
            "type": "character_importer",
            "action": "list_characters",
            "characters": [character["name"] for character in scene_data.get("characters", [])]
        })
        self.websocket_handler.queue_put(
            {
                "type": "character_importer",
                "action": "list_characters",
                "characters": [
                    character["name"] for character in scene_data.get("characters", [])
                ],
            }
        )

        await asyncio.sleep(0)

    async def handle_import(self, data):
        import_character_data = ImportCharacterData(**data)

        scene = self.websocket_handler.scene

        await load_character_into_scene(
            scene,
            import_character_data.scene_path,
            import_character_data.character_name,
        )

        scene.emit_status()

        self.websocket_handler.queue_put({
            "type": "character_importer",
            "action": "import_character_done",
        })
        self.websocket_handler.queue_put(
            {
                "type": "character_importer",
                "action": "import_character_done",
            }
        )
@@ -1,165 +1,190 @@
import pydantic
import structlog

from talemate import VERSION
from talemate.client.registry import CLIENT_CLASSES
from talemate.config import Config as AppConfigData, load_config, save_config
from talemate.client.model_prompts import model_prompt
from talemate.client.registry import CLIENT_CLASSES
from talemate.config import Config as AppConfigData
from talemate.config import load_config, save_config
from talemate.emit import emit

log = structlog.get_logger("talemate.server.config")


class ConfigPayload(pydantic.BaseModel):
    config: AppConfigData


class DefaultCharacterPayload(pydantic.BaseModel):
    name: str
    gender: str
    description: str
    color: str = "#3362bb"


class SetLLMTemplatePayload(pydantic.BaseModel):
    template_file: str
    model: str


class DetermineLLMTemplatePayload(pydantic.BaseModel):
    model: str


class ConfigPlugin:

    router = "config"

    def __init__(self, websocket_handler):
        self.websocket_handler = websocket_handler

    async def handle(self, data:dict):
    async def handle(self, data: dict):
        log.info("Config action", action=data.get("action"))

        fn = getattr(self, f"handle_{data.get('action')}", None)

        if fn is None:
            return

        await fn(data)

    async def handle_save(self, data):
        app_config_data = ConfigPayload(**data)
        current_config = load_config()

        current_config.update(app_config_data.dict().get("config"))

        save_config(current_config)

        self.websocket_handler.config = current_config
        self.websocket_handler.queue_put({
            "type": "app_config",
            "data": load_config(),
            "version": VERSION
        })
        self.websocket_handler.queue_put({
            "type": "config",
            "action": "save_complete",
        })
        self.websocket_handler.queue_put(
            {"type": "app_config", "data": load_config(), "version": VERSION}
        )
        self.websocket_handler.queue_put(
            {
                "type": "config",
                "action": "save_complete",
            }
        )

    async def handle_save_default_character(self, data):
        log.info("Saving default character", data=data["data"])

        payload = DefaultCharacterPayload(**data["data"])

        current_config = load_config()

        current_config["game"]["default_player_character"] = payload.model_dump()

        log.info("Saving default character", character=current_config["game"]["default_player_character"])
        log.info(
            "Saving default character",
            character=current_config["game"]["default_player_character"],
        )

        save_config(current_config)

        self.websocket_handler.config = current_config
        self.websocket_handler.queue_put({
            "type": "app_config",
            "data": load_config(),
            "version": VERSION
        })
        self.websocket_handler.queue_put({
            "type": "config",
            "action": "save_default_character_complete",
        })
        self.websocket_handler.queue_put(
            {"type": "app_config", "data": load_config(), "version": VERSION}
        )
        self.websocket_handler.queue_put(
            {
                "type": "config",
                "action": "save_default_character_complete",
            }
        )

    async def handle_request_std_llm_templates(self, data):
        log.info("Requesting std llm templates")

        self.websocket_handler.queue_put({
            "type": "config",
            "action": "std_llm_templates",
            "data": {
                "templates": model_prompt.std_templates,
        self.websocket_handler.queue_put(
            {
                "type": "config",
                "action": "std_llm_templates",
                "data": {
                    "templates": model_prompt.std_templates,
                },
            }
        })
        )

    async def handle_set_llm_template(self, data):
        payload = SetLLMTemplatePayload(**data["data"])

        copied_to = model_prompt.create_user_override(payload.template_file, payload.model)

        log.info("Copied template", copied_to=copied_to, template=payload.template_file, model=payload.model)

        prompt_template_example, prompt_template_file = model_prompt(payload.model, "sysmsg", "prompt<|BOT|>{LLM coercion}")

        log.info("Prompt template example", prompt_template_example=prompt_template_example, prompt_template_file=prompt_template_file)

        self.websocket_handler.queue_put({
            "type": "config",
            "action": "set_llm_template_complete",
            "data": {
                "prompt_template_example": prompt_template_example,
                "has_prompt_template": True if prompt_template_example else False,
                "template_file": prompt_template_file,

        copied_to = model_prompt.create_user_override(
            payload.template_file, payload.model
        )

        log.info(
            "Copied template",
            copied_to=copied_to,
            template=payload.template_file,
            model=payload.model,
        )

        prompt_template_example, prompt_template_file = model_prompt(
            payload.model, "sysmsg", "prompt<|BOT|>{LLM coercion}"
        )

        log.info(
            "Prompt template example",
            prompt_template_example=prompt_template_example,
            prompt_template_file=prompt_template_file,
        )

        self.websocket_handler.queue_put(
            {
                "type": "config",
                "action": "set_llm_template_complete",
                "data": {
                    "prompt_template_example": prompt_template_example,
                    "has_prompt_template": True if prompt_template_example else False,
                    "template_file": prompt_template_file,
                },
            }
        })
        )

    async def handle_determine_llm_template(self, data):
        payload = DetermineLLMTemplatePayload(**data["data"])

        log.info("Determining LLM template", model=payload.model)

        template = model_prompt.query_hf_for_prompt_template_suggestion(payload.model)

        log.info("Template suggestion", template=template)

        if not template:
            emit("status", message="No template found for model", status="warning")
        else:
            await self.handle_set_llm_template({
                "data": {
                    "template_file": template,
                    "model": payload.model,
            await self.handle_set_llm_template(
                {
                    "data": {
                        "template_file": template,
                        "model": payload.model,
                    }
                }
            })

        self.websocket_handler.queue_put({
            "type": "config",
            "action": "determine_llm_template_complete",
            "data": {
                "template": template,
            )

        self.websocket_handler.queue_put(
            {
                "type": "config",
                "action": "determine_llm_template_complete",
                "data": {
                    "template": template,
                },
            }
        })
        )

    async def handle_request_client_types(self, data):
        log.info("Requesting client types")

        clients = {
            client_type: CLIENT_CLASSES[client_type].Meta().model_dump() for client_type in CLIENT_CLASSES
            client_type: CLIENT_CLASSES[client_type].Meta().model_dump()
            for client_type in CLIENT_CLASSES
        }

        self.websocket_handler.queue_put({
            "type": "config",
            "action": "client_types",
            "data": clients,
        })
        self.websocket_handler.queue_put(
            {
                "type": "config",
                "action": "client_types",
                "data": clients,
            }
        )
@@ -1,7 +1,8 @@
import uuid
from typing import Any, Union

import pydantic
import structlog
from typing import Union, Any
import uuid

from talemate.config import load_config, save_config

@@ -11,43 +12,40 @@ log = structlog.get_logger("talemate.server.quick_settings")
class SetQuickSettingsPayload(pydantic.BaseModel):
    setting: str
    value: Any


class QuickSettingsPlugin:
    router = "quick_settings"

    @property
    def scene(self):
        return self.websocket_handler.scene

    def __init__(self, websocket_handler):
        self.websocket_handler = websocket_handler

    async def handle(self, data:dict):
    async def handle(self, data: dict):
        log.info("quick settings action", action=data.get("action"))

        fn = getattr(self, f"handle_{data.get('action')}", None)

        if fn is None:
            return

        await fn(data)

    async def handle_set(self, data:dict):
    async def handle_set(self, data: dict):
        payload = SetQuickSettingsPayload(**data)

        if payload.setting == "auto_save":
            self.scene.config["game"]["general"]["auto_save"] = payload.value
        elif payload.setting == "auto_progress":
            self.scene.config["game"]["general"]["auto_progress"] = payload.value
        else:
            raise NotImplementedError(f"Setting {payload.setting} not implemented.")

        save_config(self.scene.config)

        self.websocket_handler.queue_put({
            "type": self.router,
            "action": "set_done",
            "data": payload.model_dump()
        })
        self.websocket_handler.queue_put(
            {"type": self.router, "action": "set_done", "data": payload.model_dump()}
        )
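Every plugin touched by this commit shares the same `handle` body: look up a `handle_<action>` method by name and silently ignore unknown actions. A minimal standalone sketch of that dispatch pattern (hypothetical `EchoPlugin`, not from the repository):

```python
import asyncio


class EchoPlugin:
    """Hypothetical plugin using the same getattr-based action dispatch."""

    router = "echo"

    def __init__(self):
        self.handled = []

    async def handle(self, data: dict):
        # Resolve handle_<action>; unknown actions are simply ignored.
        fn = getattr(self, f"handle_{data.get('action')}", None)
        if fn is None:
            return
        await fn(data)

    async def handle_ping(self, data: dict):
        self.handled.append(data)


plugin = EchoPlugin()
asyncio.run(plugin.handle({"action": "ping"}))
asyncio.run(plugin.handle({"action": "unknown"}))  # no handler, so a no-op
print(len(plugin.handled))  # 1
```

Because dispatch is purely name-based, black can reflow these handler bodies without affecting routing.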
@@ -1,8 +1,8 @@
import os

import argparse
import asyncio
import os
import sys

import structlog
import websockets

@@ -10,13 +10,16 @@ from talemate.server.api import websocket_endpoint

log = structlog.get_logger("talemate.server.run")


def run_server(args):
    """
    Run the talemate web server using the provided arguments.

    :param args: command line arguments parsed by argparse
    """
    start_server = websockets.serve(websocket_endpoint, args.host, args.port, max_size=2 ** 23)
    start_server = websockets.serve(
        websocket_endpoint, args.host, args.port, max_size=2**23
    )
    asyncio.get_event_loop().run_until_complete(start_server)
    log.info("talemate backend started", host=args.host, port=args.port)
    asyncio.get_event_loop().run_forever()
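The `run_server` change above only reflows the `websockets.serve(...)` call (and black normalizes `2 ** 23` to `2**23`); the `max_size` argument itself is unchanged and caps each incoming websocket message at 8 MiB. A quick check of that value:

```python
# max_size=2**23 passed to websockets.serve: 8 MiB per incoming message
max_size = 2**23
print(max_size, max_size // (1024 * 1024))  # 8388608 8
```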
@@ -1,152 +1,166 @@
import os
import pydantic
import asyncio
import structlog
import json
import os
from typing import Union

import pydantic
import structlog

from talemate.load import load_character_into_scene

log = structlog.get_logger("talemate.server.character_importer")


class ListScenesData(pydantic.BaseModel):
    scene_path: str


class CreateSceneData(pydantic.BaseModel):
    name: Union[str, None] = None
    description: Union[str, None] = None
    intro: Union[str, None] = None
    intro: Union[str, None] = None
    content_context: Union[str, None] = None
    prompt: Union[str, None] = None


class SceneCreatorServerPlugin:

    router = "scene_creator"

    def __init__(self, websocket_handler):
        self.websocket_handler = websocket_handler

    @property
    def scene(self):
        return self.websocket_handler.scene

    async def handle(self, data:dict):
    async def handle(self, data: dict):
        log.info("Scene importer action", action=data.get("action"))

        fn = getattr(self, f"handle_{data.get('action')}", None)

        if fn is None:
            return

        await fn(data)

    async def handle_generate_description(self, data):
        create_scene_data = CreateSceneData(**data)

        scene = self.websocket_handler.scene

        creator = scene.get_helper("creator").agent

        self.websocket_handler.queue_put({
            "type": "scene_creator",
            "action": "set_generating",
        })
        self.websocket_handler.queue_put(
            {
                "type": "scene_creator",
                "action": "set_generating",
            }
        )

        description = await creator.create_scene_description(
            create_scene_data.prompt,
            create_scene_data.content_context,
        )

        log.info("Generated scene description", description=description)

        scene.description = description

        self.send_scene_update()

        self.websocket_handler.queue_put({
            "type": "scene_creator",
            "action": "set_generating_done",
        })
        self.websocket_handler.queue_put(
            {
                "type": "scene_creator",
                "action": "set_generating_done",
            }
        )

    async def handle_generate_name(self, data):
        create_scene_data = CreateSceneData(**data)

        scene = self.websocket_handler.scene

        creator = scene.get_helper("creator").agent

        self.websocket_handler.queue_put({
            "type": "scene_creator",
            "action": "set_generating",
        })
        self.websocket_handler.queue_put(
            {
                "type": "scene_creator",
                "action": "set_generating",
            }
        )

        name = await creator.create_scene_name(
            create_scene_data.prompt,
            create_scene_data.content_context,
            scene.description,
        )

        log.info("Generated scene name", name=name)

        scene.name = name

        self.send_scene_update()

        self.websocket_handler.queue_put({
            "type": "scene_creator",
            "action": "set_generating_done",
        })
        self.websocket_handler.queue_put(
            {
                "type": "scene_creator",
                "action": "set_generating_done",
            }
        )

    async def handle_generate_intro(self, data):
        create_scene_data = CreateSceneData(**data)

        scene = self.websocket_handler.scene

        creator = scene.get_helper("creator").agent

        self.websocket_handler.queue_put({
            "type": "scene_creator",
            "action": "set_generating",
        })
        self.websocket_handler.queue_put(
            {
                "type": "scene_creator",
                "action": "set_generating",
            }
        )

        intro = await creator.create_scene_intro(
            create_scene_data.prompt,
            create_scene_data.content_context,
            scene.description,
            scene.name,
        )

        log.info("Generated scene intro", intro=intro)

        scene.intro = intro

        self.send_scene_update()

        self.websocket_handler.queue_put({
            "type": "scene_creator",
            "action": "set_generating_done",
        })
        self.websocket_handler.queue_put(
            {
                "type": "scene_creator",
                "action": "set_generating_done",
            }
        )

    async def handle_generate(self, data):
        create_scene_data = CreateSceneData(**data)

        scene = self.websocket_handler.scene

        creator = scene.get_helper("creator").agent

        self.websocket_handler.queue_put({
            "type": "scene_creator",
            "action": "set_generating",
        })
        self.websocket_handler.queue_put(
            {
                "type": "scene_creator",
                "action": "set_generating",
            }
        )

        description = await creator.create_scene_description(
            create_scene_data.prompt,
            create_scene_data.content_context,
        )

        log.info("Generated scene description", description=description)

        name = await creator.create_scene_name(
@@ -154,55 +168,61 @@ class SceneCreatorServerPlugin:
            create_scene_data.content_context,
            description,
        )

        log.info("Generated scene name", name=name)

        intro = await creator.create_scene_intro(
            create_scene_data.prompt,
            create_scene_data.content_context,
            description,
            name,
        )

        log.info("Generated scene intro", intro=intro)

        scene.name = name
        scene.description = description
        scene.intro = intro
        scene.scenario_context = create_scene_data.content_context

        self.send_scene_update()

        self.websocket_handler.queue_put({
            "type": "scene_creator",
            "action": "set_generating_done",
        })
        self.websocket_handler.queue_put(
            {
                "type": "scene_creator",
                "action": "set_generating_done",
            }
        )

    async def handle_edit(self, data):
        scene = self.websocket_handler.scene
        create_scene_data = CreateSceneData(**data)

        scene.description = create_scene_data.description
        scene.name = create_scene_data.name
        scene.intro = create_scene_data.intro
        scene.scenario_context = create_scene_data.content_context

        self.websocket_handler.queue_put({
            "type": "scene_creator",
            "action": "scene_saved",
        })
        self.websocket_handler.queue_put(
            {
                "type": "scene_creator",
                "action": "scene_saved",
            }
        )

    async def handle_load(self, data):
        self.send_scene_update()
        await asyncio.sleep(0)

    def send_scene_update(self):
        scene = self.websocket_handler.scene

        self.websocket_handler.queue_put({
            "type": "scene_creator",
            "action": "scene_update",
            "description": scene.description,
            "name": scene.name,
            "intro": scene.intro,
        })
        self.websocket_handler.queue_put(
            {
                "type": "scene_creator",
                "action": "scene_update",
                "description": scene.description,
                "name": scene.name,
                "intro": scene.intro,
            }
        )
@@ -4,23 +4,21 @@ import talemate.instance as instance

log = structlog.get_logger("talemate.server.tts")


class TTSPlugin:
    router = "tts"

    def __init__(self, websocket_handler):
        self.websocket_handler = websocket_handler
        self.tts = None

    async def handle(self, data:dict):
    async def handle(self, data: dict):
        action = data.get("action")

        if action == "test":
            return await self.handle_test(data)

    async def handle_test(self, data:dict):
    async def handle_test(self, data: dict):
        tts_agent = instance.get_agent("tts")

        await tts_agent.generate("Welcome to talemate!")
        await tts_agent.generate("Welcome to talemate!")
@@ -2,24 +2,29 @@ import asyncio
|
||||
import base64
|
||||
import os
|
||||
import traceback
|
||||
|
||||
import structlog
|
||||
|
||||
import talemate.instance as instance
|
||||
from talemate import Helper, Scene
|
||||
from talemate.config import load_config, save_config, SceneAssetUpload
|
||||
from talemate.client.registry import CLIENT_CLASSES
|
||||
from talemate.config import SceneAssetUpload, load_config, save_config
|
||||
from talemate.emit import Emission, Receiver, abort_wait_for_input, emit
|
||||
from talemate.files import list_scenes_directory
|
||||
from talemate.load import load_scene, load_scene_from_data, load_scene_from_character_card
|
||||
from talemate.load import (
|
||||
load_scene,
|
||||
load_scene_from_character_card,
|
||||
load_scene_from_data,
|
||||
)
|
||||
from talemate.scene_assets import Asset
|
||||
|
||||
from talemate.client.registry import CLIENT_CLASSES
|
||||
|
||||
from talemate.server import character_creator
|
||||
from talemate.server import character_importer
|
||||
from talemate.server import scene_creator
|
||||
from talemate.server import config
from talemate.server import world_state_manager
from talemate.server import quick_settings
from talemate.server import (
    character_creator,
    character_importer,
    config,
    quick_settings,
    scene_creator,
    world_state_manager,
)

log = structlog.get_logger("talemate.server.websocket_server")

@@ -27,8 +32,6 @@ AGENT_INSTANCES = {}

class WebsocketHandler(Receiver):

    def __init__(self, socket, out_queue, llm_clients=dict()):
        self.agents = {typ: {"name": typ} for typ in instance.agent_types()}
        self.socket = socket
@@ -40,7 +43,7 @@ class WebsocketHandler(Receiver):

        for name, agent_config in self.config.get("agents", {}).items():
            self.agents[name] = agent_config

        self.llm_clients = self.config.get("clients", llm_clients)

        instance.get_agent("memory", self.scene)
@@ -48,16 +51,26 @@ class WebsocketHandler(Receiver):
        # unconveniently named function, this `connect` method is called
        # to connect signals handlers to the websocket handler
        self.connect()

        self.connect_llm_clients()

        self.routes = {
            character_creator.CharacterCreatorServerPlugin.router: character_creator.CharacterCreatorServerPlugin(self),
            character_importer.CharacterImporterServerPlugin.router: character_importer.CharacterImporterServerPlugin(self),
            scene_creator.SceneCreatorServerPlugin.router: scene_creator.SceneCreatorServerPlugin(self),
        self.routes = {
            character_creator.CharacterCreatorServerPlugin.router: character_creator.CharacterCreatorServerPlugin(
                self
            ),
            character_importer.CharacterImporterServerPlugin.router: character_importer.CharacterImporterServerPlugin(
                self
            ),
            scene_creator.SceneCreatorServerPlugin.router: scene_creator.SceneCreatorServerPlugin(
                self
            ),
            config.ConfigPlugin.router: config.ConfigPlugin(self),
            world_state_manager.WorldStateManagerPlugin.router: world_state_manager.WorldStateManagerPlugin(self),
            quick_settings.QuickSettingsPlugin.router: quick_settings.QuickSettingsPlugin(self),
            world_state_manager.WorldStateManagerPlugin.router: world_state_manager.WorldStateManagerPlugin(
                self
            ),
            quick_settings.QuickSettingsPlugin.router: quick_settings.QuickSettingsPlugin(
                self
            ),
        }

        # self.request_scenes_list()
@@ -85,34 +98,36 @@ class WebsocketHandler(Receiver):
                log.error("Error connecting to client", client_name=client_name, e=e)
                continue

            log.info("Configured client", client_name=client_name, client_type=client.client_type)
            log.info(
                "Configured client",
                client_name=client_name,
                client_type=client.client_type,
            )

        self.connect_agents()

    def connect_agents(self):

        if not self.llm_clients:
            instance.emit_agents_status()
            return

        for agent_typ, agent_config in self.agents.items():
            try:
                client = self.llm_clients.get(agent_config.get("client"))["client"]
            except TypeError as e:
                client = None

            if not client:
                # select first client
                print("selecting first client", self.llm_clients)
                client = list(self.llm_clients.values())[0]["client"]
                agent_config["client"] = client.name

            log.debug("Linked agent", agent_typ=agent_typ, client=client.name)
            agent = instance.get_agent(agent_typ, client=client)
            agent.client = client
            agent.apply_config(**agent_config)

        instance.emit_agents_status()

    def init_scene(self):
@@ -127,20 +142,21 @@ class WebsocketHandler(Receiver):
            log.debug("init agent", agent_typ=agent_typ, agent_config=agent_config)
            agent = instance.get_agent(agent_typ, **agent_config)

            #if getattr(agent, "client", None):
            # if getattr(agent, "client", None):
            #     self.llm_clients[agent.client.name] = agent.client

            scene.add_helper(Helper(agent))

        return scene

    async def load_scene(self, path_or_data, reset=False, callback=None, file_name=None):
    async def load_scene(
        self, path_or_data, reset=False, callback=None, file_name=None
    ):
        try:

            if self.scene:
                instance.get_agent("memory").close_db(self.scene)
                self.scene.disconnect()

            scene = self.init_scene()

            if not scene:
@@ -172,18 +188,17 @@ class WebsocketHandler(Receiver):
        existing = set(self.llm_clients.keys())

        self.llm_clients = {}

        log.info("Configuring clients", clients=clients)

        for client in clients:

            client.pop("status", None)
            client_cls = CLIENT_CLASSES.get(client["type"])

            if not client_cls:
                log.error("Client type not found", client=client)
                continue

            client_config = self.llm_clients[client["name"]] = {
                "name": client["name"],
                "type": client["type"],
@@ -208,17 +223,17 @@ class WebsocketHandler(Receiver):
        for name in removed:
            log.debug("Destroying client", name=name)
            instance.destroy_client(name)

        self.config["clients"] = self.llm_clients

        self.connect_llm_clients()
        save_config(self.config)

        instance.sync_emit_clients_status()

    def configure_agents(self, agents):
        self.agents = {typ: {} for typ in instance.agent_types()}

        log.debug("Configuring agents")

        for agent in agents:
@@ -235,7 +250,7 @@ class WebsocketHandler(Receiver):

                if getattr(agent_instance, "actions", None):
                    self.agents[name]["actions"] = agent.get("actions", {})

                agent_instance.apply_config(**self.agents[name])
                log.debug("Configured agent", name=name)
                continue
@@ -253,16 +268,21 @@ class WebsocketHandler(Receiver):

            agent_instance = instance.get_agent(name, **self.agents[name])
            agent_instance.client = self.llm_clients[agent["client"]]["client"]

            if agent_instance.has_toggle:
                self.agents[name]["enabled"] = agent["enabled"]

            if getattr(agent_instance, "actions", None):
                self.agents[name]["actions"] = agent.get("actions", {})

            agent_instance.apply_config(**self.agents[name])

            log.debug("Configured agent", name=name, client_name=self.llm_clients[agent["client"]]["name"], client=self.llm_clients[agent["client"]]["client"])

            log.debug(
                "Configured agent",
                name=name,
                client_name=self.llm_clients[agent["client"]]["name"],
                client=self.llm_clients[agent["client"]]["client"],
            )

        self.config["agents"] = self.agents
        save_config(self.config)
@@ -300,16 +320,15 @@ class WebsocketHandler(Receiver):
                "character": emission.character.name if emission.character else "",
            }
        )

    def handle_director(self, emission: Emission):

        if emission.character:
            character = emission.character.name
        elif emission.message_object.source:
            character = emission.message_object.source
        else:
            character = ""

        self.queue_put(
            {
                "type": "director",
@@ -381,7 +400,7 @@ class WebsocketHandler(Receiver):
                "status": emission.status,
            }
        )

    def handle_config_saved(self, emission: Emission):
        self.queue_put(
            {
@@ -389,7 +408,7 @@ class WebsocketHandler(Receiver):
                "data": emission.data,
            }
        )

    def handle_archived_history(self, emission: Emission):
        self.queue_put(
            {
@@ -523,7 +542,6 @@ class WebsocketHandler(Receiver):

    def request_scenes_list(self, query: str = ""):
        scenes_list = list_scenes_directory()

        if query:
            filtered_list = [
@@ -547,8 +565,10 @@ class WebsocketHandler(Receiver):
        )

    def request_scene_history(self):
        history = [archived_history["text"] for archived_history in self.scene.archived_history]
        history = [
            archived_history["text"] for archived_history in self.scene.archived_history
        ]

        self.queue_put(
            {
                "type": "scene_history",
@@ -558,14 +578,13 @@ class WebsocketHandler(Receiver):

    async def request_client_status(self):
        await instance.emit_clients_status()

    def request_scene_assets(self, asset_ids:list[str]):
    def request_scene_assets(self, asset_ids: list[str]):
        scene_assets = self.scene.assets

        for asset_id in asset_ids:

            asset = scene_assets.get_asset_bytes_as_base64(asset_id)

            self.queue_put(
                {
                    "type": "scene_asset",
@@ -574,16 +593,16 @@ class WebsocketHandler(Receiver):
                    "media_type": scene_assets.get_asset(asset_id).media_type,
                }
            )

    def request_assets(self, assets:list[dict]):
    def request_assets(self, assets: list[dict]):
        # way to request scene assets without loading the scene
        #
        # assets is a list of dicts with keys:
        # path must be turned into absolute path
        # path must begin with Scene.scenes_dir()

        _assets = {}

        for asset_dict in assets:
            try:
                asset_id, asset = self._asset(**asset_dict)
@@ -591,32 +610,37 @@ class WebsocketHandler(Receiver):
                log.error("request_assets", error=traceback.format_exc(), **asset_dict)
                continue
            _assets[asset_id] = asset

        self.queue_put(
            {
                "type": "assets",
                "assets": _assets,
            }
        )
        )

    def _asset(self, path: str, **asset):
        absolute_path = os.path.abspath(path)

        if not absolute_path.startswith(Scene.scenes_dir()):
            log.error("_asset", error="Invalid path", path=absolute_path, scenes_dir=Scene.scenes_dir())
            log.error(
                "_asset",
                error="Invalid path",
                path=absolute_path,
                scenes_dir=Scene.scenes_dir(),
            )
            return

        asset_path = os.path.join(os.path.dirname(absolute_path), "assets")
        asset = Asset(**asset)
        return asset.id, {
            "base64": asset.to_base64(asset_path),
            "media_type": asset.media_type,
        }

    def add_scene_asset(self, data:dict):
    def add_scene_asset(self, data: dict):
        asset_upload = SceneAssetUpload(**data)
        asset = self.scene.assets.add_asset_from_image_data(asset_upload.content)

        if asset_upload.scene_cover_image:
            self.scene.assets.cover_image = asset.id
            self.scene.emit_status()
@@ -624,9 +648,12 @@ class WebsocketHandler(Receiver):
            character = self.scene.get_character(asset_upload.character_cover_image)
            old_cover_image = character.cover_image
            character.cover_image = asset.id
            if not self.scene.assets.cover_image or old_cover_image == self.scene.assets.cover_image:
            if (
                not self.scene.assets.cover_image
                or old_cover_image == self.scene.assets.cover_image
            ):
                self.scene.assets.cover_image = asset.id
                self.scene.emit_status()
            self.scene.emit_status()
            self.request_scene_assets([character.cover_image])
            self.queue_put(
                {
@@ -637,16 +664,14 @@ class WebsocketHandler(Receiver):
                    "character": character.name,
                }
            )

    def delete_message(self, message_id):
        self.scene.delete_message(message_id)

    def edit_message(self, message_id, new_text):
        self.scene.edit_message(message_id, new_text)

    def apply_scene_config(self, scene_config:dict):
    def apply_scene_config(self, scene_config: dict):
        self.scene.apply_scene_config(scene_config)
        self.queue_put(
            {
@@ -654,28 +679,25 @@ class WebsocketHandler(Receiver):
                "data": self.scene.scene_config,
            }
        )

    def handle_character_card_upload(self, image_data_url:str, filename:str) -> str:
    def handle_character_card_upload(self, image_data_url: str, filename: str) -> str:
        image_type = image_data_url.split(";")[0].split(":")[1]
        image_data = base64.b64decode(image_data_url.split(",")[1])
        characters_path = os.path.join("./scenes", "characters")

        filepath = os.path.join(characters_path, filename)

        with open(filepath, "wb") as f:
            f.write(image_data)

        return filepath

    async def route(self, data:dict):
    async def route(self, data: dict):
        route = data["type"]

        if route not in self.routes:
            return

        plugin = self.routes[route]
        try:
            await plugin.handle(data)
@@ -687,4 +709,4 @@ class WebsocketHandler(Receiver):
                "type": "error",
                "error": str(e),
            }
        )
        )
@@ -1,22 +1,30 @@
import uuid
from typing import Any, Union

import pydantic
import structlog
from typing import Union, Any
import uuid

from talemate.world_state.manager import WorldStateManager, WorldStateTemplates, StateReinforcementTemplate
from talemate.world_state.manager import (
    StateReinforcementTemplate,
    WorldStateManager,
    WorldStateTemplates,
)

log = structlog.get_logger("talemate.server.world_state_manager")

class UpdateCharacterAttributePayload(pydantic.BaseModel):
    name: str
    attribute: str
    value: str

class UpdateCharacterDetailPayload(pydantic.BaseModel):
    name: str
    detail: str
    value: str

class SetCharacterDetailReinforcementPayload(pydantic.BaseModel):
    name: str
    question: str
@@ -25,462 +33,545 @@ class SetCharacterDetailReinforcementPayload(pydantic.BaseModel):
    answer: str = ""
    update_state: bool = False
    insert: str = "sequential"

class CharacterDetailReinforcementPayload(pydantic.BaseModel):
    name: str
    question: str
    reset: bool = False

class SaveWorldEntryPayload(pydantic.BaseModel):
    id:str
    id: str
    text: str
    meta: dict = {}

class DeleteWorldEntryPayload(pydantic.BaseModel):
    id: str

class SetWorldEntryReinforcementPayload(pydantic.BaseModel):
    question: str
    instructions: Union[str, None] = None
    interval: int = 10
    answer: str = ""
    update_state: bool = False
    insert: str = "never"
    insert: str = "never"

class WorldEntryReinforcementPayload(pydantic.BaseModel):
    question: str
    reset: bool = False

class QueryContextDBPayload(pydantic.BaseModel):
    query: str
    meta: dict = {}

class UpdateContextDBPayload(pydantic.BaseModel):
    text: str
    meta: dict = {}
    id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))

class DeleteContextDBPayload(pydantic.BaseModel):
    id: Any

class UpdatePinPayload(pydantic.BaseModel):
    entry_id: str
    condition: Union[str, None] = None
    condition_state: bool = False
    active: bool = False

class RemovePinPayload(pydantic.BaseModel):
    entry_id: str
    entry_id: str

class SaveWorldStateTemplatePayload(pydantic.BaseModel):
    template: StateReinforcementTemplate

class DeleteWorldStateTemplatePayload(pydantic.BaseModel):
    template: StateReinforcementTemplate

class WorldStateManagerPlugin:

    router = "world_state_manager"

    @property
    def scene(self):
        return self.websocket_handler.scene

    @property
    def world_state_manager(self):
        return WorldStateManager(self.scene)

    def __init__(self, websocket_handler):
        self.websocket_handler = websocket_handler

    async def handle(self, data:dict):
    async def handle(self, data: dict):
        log.info("World state manager action", action=data.get("action"))

        fn = getattr(self, f"handle_{data.get('action')}", None)

        if fn is None:
            return

        await fn(data)

    async def signal_operation_done(self):
        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "operation_done",
            "data": {}
        })
        self.websocket_handler.queue_put(
            {"type": "world_state_manager", "action": "operation_done", "data": {}}
        )

        if self.scene.auto_save:
            await self.scene.save(auto=True)

    async def handle_get_character_list(self, data):
        character_list = await self.world_state_manager.get_character_list()
        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "character_list",
            "data": character_list.model_dump()
        })
        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "character_list",
                "data": character_list.model_dump(),
            }
        )

    async def handle_get_character_details(self, data):
        character_details = await self.world_state_manager.get_character_details(data["name"])
        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "character_details",
            "data": character_details.model_dump()
        })
        character_details = await self.world_state_manager.get_character_details(
            data["name"]
        )
        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "character_details",
                "data": character_details.model_dump(),
            }
        )

    async def handle_get_world(self, data):
        world = await self.world_state_manager.get_world()
        log.debug("World", world=world)
        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "world",
            "data": world.model_dump()
        })
        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "world",
                "data": world.model_dump(),
            }
        )

    async def handle_get_pins(self, data):
        context_pins = await self.world_state_manager.get_pins()
        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "pins",
            "data": context_pins.model_dump()
        })
        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "pins",
                "data": context_pins.model_dump(),
            }
        )

    async def handle_get_templates(self, data):
        templates = await self.world_state_manager.get_templates()
        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "templates",
            "data": templates.model_dump()
        })
        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "templates",
                "data": templates.model_dump(),
            }
        )
    async def handle_update_character_attribute(self, data):
        payload = UpdateCharacterAttributePayload(**data)

        await self.world_state_manager.update_character_attribute(payload.name, payload.attribute, payload.value)

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "character_attribute_updated",
            "data": payload.model_dump()
        })

        await self.world_state_manager.update_character_attribute(
            payload.name, payload.attribute, payload.value
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "character_attribute_updated",
                "data": payload.model_dump(),
            }
        )

        # resend character details
        await self.handle_get_character_details({"name":payload.name})
        await self.handle_get_character_details({"name": payload.name})

        await self.signal_operation_done()

    async def handle_update_character_description(self, data):
        payload = UpdateCharacterAttributePayload(**data)

        await self.world_state_manager.update_character_description(payload.name, payload.value)

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "character_description_updated",
            "data": payload.model_dump()
        })

        # resend character details
        await self.handle_get_character_details({"name":payload.name})
        await self.signal_operation_done()

    async def handle_update_character_detail(self, data):
        payload = UpdateCharacterDetailPayload(**data)

        await self.world_state_manager.update_character_detail(payload.name, payload.detail, payload.value)

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "character_detail_updated",
            "data": payload.model_dump()
        })

        await self.world_state_manager.update_character_description(
            payload.name, payload.value
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "character_description_updated",
                "data": payload.model_dump(),
            }
        )

        # resend character details
        await self.handle_get_character_details({"name":payload.name})
        await self.handle_get_character_details({"name": payload.name})
        await self.signal_operation_done()

    async def handle_update_character_detail(self, data):
        payload = UpdateCharacterDetailPayload(**data)

        await self.world_state_manager.update_character_detail(
            payload.name, payload.detail, payload.value
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "character_detail_updated",
                "data": payload.model_dump(),
            }
        )

        # resend character details
        await self.handle_get_character_details({"name": payload.name})
        await self.signal_operation_done()

    async def handle_set_character_detail_reinforcement(self, data):
        payload = SetCharacterDetailReinforcementPayload(**data)

        await self.world_state_manager.add_detail_reinforcement(
            payload.name,
            payload.question,
            payload.instructions,
            payload.interval,
            payload.name,
            payload.question,
            payload.instructions,
            payload.interval,
            payload.answer,
            payload.insert,
            payload.update_state
            payload.update_state,
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "character_detail_reinforcement_set",
                "data": payload.model_dump(),
            }
        )

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "character_detail_reinforcement_set",
            "data": payload.model_dump()
        })

        # resend character details
        await self.handle_get_character_details({"name":payload.name})
        await self.handle_get_character_details({"name": payload.name})
        await self.signal_operation_done()

    async def handle_run_character_detail_reinforcement(self, data):
        payload = CharacterDetailReinforcementPayload(**data)

        log.debug("Run character detail reinforcement", name=payload.name, question=payload.question, reset=payload.reset)

        await self.world_state_manager.run_detail_reinforcement(
            payload.name,
            payload.question,
            reset=payload.reset

        log.debug(
            "Run character detail reinforcement",
            name=payload.name,
            question=payload.question,
            reset=payload.reset,
        )

        await self.world_state_manager.run_detail_reinforcement(
            payload.name, payload.question, reset=payload.reset
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "character_detail_reinforcement_run",
                "data": payload.model_dump(),
            }
        )

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "character_detail_reinforcement_run",
            "data": payload.model_dump()
        })

        # resend character details
        await self.handle_get_character_details({"name":payload.name})
        await self.handle_get_character_details({"name": payload.name})
        await self.signal_operation_done()

    async def handle_delete_character_detail_reinforcement(self, data):
        payload = CharacterDetailReinforcementPayload(**data)

        await self.world_state_manager.delete_detail_reinforcement(payload.name, payload.question)

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "character_detail_reinforcement_deleted",
            "data": payload.model_dump()
        })

        # resend character details
        await self.handle_get_character_details({"name":payload.name})
        await self.signal_operation_done()

        await self.world_state_manager.delete_detail_reinforcement(
            payload.name, payload.question
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "character_detail_reinforcement_deleted",
                "data": payload.model_dump(),
            }
        )

        # resend character details
        await self.handle_get_character_details({"name": payload.name})
        await self.signal_operation_done()
    async def handle_save_world_entry(self, data):
        payload = SaveWorldEntryPayload(**data)

        log.debug("Save world entry", id=payload.id, text=payload.text, meta=payload.meta)

        await self.world_state_manager.save_world_entry(payload.id, payload.text, payload.meta)

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "world_entry_saved",
            "data": payload.model_dump()
        })

        log.debug(
            "Save world entry", id=payload.id, text=payload.text, meta=payload.meta
        )

        await self.world_state_manager.save_world_entry(
            payload.id, payload.text, payload.meta
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "world_entry_saved",
                "data": payload.model_dump(),
            }
        )

        await self.handle_get_world({})
        await self.signal_operation_done()

        self.scene.world_state.emit()

    async def handle_delete_world_entry(self, data):
        payload = DeleteWorldEntryPayload(**data)

        log.debug("Delete world entry", id=payload.id)

        await self.world_state_manager.delete_context_db_entry(payload.id)

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "world_entry_deleted",
            "data": payload.model_dump()
        })

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "world_entry_deleted",
                "data": payload.model_dump(),
            }
        )

        await self.handle_get_world({})
        await self.signal_operation_done()

        self.scene.world_state.emit()
        self.scene.emit_status()

    async def handle_set_world_state_reinforcement(self, data):
        payload = SetWorldEntryReinforcementPayload(**data)

        log.debug("Set world state reinforcement", question=payload.question, instructions=payload.instructions, interval=payload.interval, answer=payload.answer, insert=payload.insert, update_state=payload.update_state)

        log.debug(
            "Set world state reinforcement",
            question=payload.question,
            instructions=payload.instructions,
            interval=payload.interval,
            answer=payload.answer,
            insert=payload.insert,
            update_state=payload.update_state,
        )

        await self.world_state_manager.add_detail_reinforcement(
            None,
            payload.question,
            payload.instructions,
            payload.interval,
            payload.question,
            payload.instructions,
            payload.interval,
            payload.answer,
            payload.insert,
            payload.update_state
            payload.update_state,
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "world_state_reinforcement_set",
                "data": payload.model_dump(),
            }
        )

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "world_state_reinforcement_set",
            "data": payload.model_dump()
        })

        # resend world
        await self.handle_get_world({})
        await self.signal_operation_done()

    async def handle_run_world_state_reinforcement(self, data):
        payload = WorldEntryReinforcementPayload(**data)

        await self.world_state_manager.run_detail_reinforcement(None, payload.question, payload.reset)

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "world_state_reinforcement_ran",
            "data": payload.model_dump()
        })

        await self.world_state_manager.run_detail_reinforcement(
            None, payload.question, payload.reset
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "world_state_reinforcement_ran",
                "data": payload.model_dump(),
            }
        )

        # resend world
        await self.handle_get_world({})
        await self.signal_operation_done()

    async def handle_delete_world_state_reinforcement(self, data):
        payload = WorldEntryReinforcementPayload(**data)

        await self.world_state_manager.delete_detail_reinforcement(None, payload.question)

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "world_state_reinforcement_deleted",
            "data": payload.model_dump()
        })

        await self.world_state_manager.delete_detail_reinforcement(
            None, payload.question
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "world_state_reinforcement_deleted",
                "data": payload.model_dump(),
            }
        )

        # resend world
        await self.handle_get_world({})
        await self.signal_operation_done()
    async def handle_query_context_db(self, data):
        payload = QueryContextDBPayload(**data)

        log.debug("Query context db", query=payload.query, meta=payload.meta)

        context_db = await self.world_state_manager.get_context_db_entries(payload.query, **payload.meta)

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "context_db_result",
            "data": context_db.model_dump()
        })

        context_db = await self.world_state_manager.get_context_db_entries(
            payload.query, **payload.meta
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "context_db_result",
                "data": context_db.model_dump(),
            }
        )

        await self.signal_operation_done()

    async def handle_update_context_db(self, data):
        payload = UpdateContextDBPayload(**data)

        log.debug("Update context db", text=payload.text, meta=payload.meta, id=payload.id)

        await self.world_state_manager.update_context_db_entry(payload.id, payload.text, payload.meta)

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "context_db_updated",
            "data": payload.model_dump()
        })

        log.debug(
            "Update context db", text=payload.text, meta=payload.meta, id=payload.id
        )

        await self.world_state_manager.update_context_db_entry(
            payload.id, payload.text, payload.meta
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "context_db_updated",
                "data": payload.model_dump(),
            }
        )

        await self.signal_operation_done()

    async def handle_delete_context_db(self, data):
        payload = DeleteContextDBPayload(**data)

        log.debug("Delete context db", id=payload.id)

        await self.world_state_manager.delete_context_db_entry(payload.id)

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "context_db_deleted",
            "data": payload.model_dump()
        })

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "context_db_deleted",
                "data": payload.model_dump(),
            }
        )

        await self.signal_operation_done()

    async def handle_set_pin(self, data):
        payload = UpdatePinPayload(**data)

        log.debug("Set pin", entry_id=payload.entry_id, condition=payload.condition, condition_state=payload.condition_state, active=payload.active)

        await self.world_state_manager.set_pin(payload.entry_id, payload.condition, payload.condition_state, payload.active)

        self.websocket_handler.queue_put({
            "type": "world_state_manager",
            "action": "pin_set",
            "data": payload.model_dump()
        })

        log.debug(
            "Set pin",
            entry_id=payload.entry_id,
            condition=payload.condition,
            condition_state=payload.condition_state,
            active=payload.active,
        )

        await self.world_state_manager.set_pin(
            payload.entry_id, payload.condition, payload.condition_state, payload.active
        )

        self.websocket_handler.queue_put(
            {
                "type": "world_state_manager",
                "action": "pin_set",
                "data": payload.model_dump(),
            }
        )

        await self.handle_get_pins({})
        await self.signal_operation_done()
        await self.scene.load_active_pins()
        self.scene.emit_status()

    async def handle_remove_pin(self, data):
        payload = RemovePinPayload(**data)
|
||||
log.debug("Remove pin", entry_id=payload.entry_id)
|
||||
|
||||
|
||||
await self.world_state_manager.remove_pin(payload.entry_id)
|
||||
|
||||
self.websocket_handler.queue_put({
|
||||
"type": "world_state_manager",
|
||||
"action": "pin_removed",
|
||||
"data": payload.model_dump()
|
||||
})
|
||||
|
||||
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": "world_state_manager",
|
||||
"action": "pin_removed",
|
||||
"data": payload.model_dump(),
|
||||
}
|
||||
)
|
||||
|
||||
await self.handle_get_pins({})
|
||||
await self.signal_operation_done()
|
||||
await self.scene.load_active_pins()
|
||||
self.scene.emit_status()
|
||||
|
||||
|
||||
async def handle_save_template(self, data):
|
||||
|
||||
payload = SaveWorldStateTemplatePayload(**data)
|
||||
|
||||
|
||||
log.debug("Save world state template", template=payload.template)
|
||||
|
||||
|
||||
await self.world_state_manager.save_template(payload.template)
|
||||
|
||||
self.websocket_handler.queue_put({
|
||||
"type": "world_state_manager",
|
||||
"action": "template_saved",
|
||||
"data": payload.model_dump()
|
||||
})
|
||||
|
||||
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": "world_state_manager",
|
||||
"action": "template_saved",
|
||||
"data": payload.model_dump(),
|
||||
}
|
||||
)
|
||||
|
||||
await self.handle_get_templates({})
|
||||
await self.signal_operation_done()
|
||||
|
||||
|
||||
async def handle_delete_template(self, data):
|
||||
|
||||
payload = DeleteWorldStateTemplatePayload(**data)
|
||||
template = payload.template
|
||||
|
||||
log.debug("Delete world state template", template=template.name, template_type=template.type)
|
||||
|
||||
|
||||
log.debug(
|
||||
"Delete world state template",
|
||||
template=template.name,
|
||||
template_type=template.type,
|
||||
)
|
||||
|
||||
await self.world_state_manager.remove_template(template.type, template.name)
|
||||
|
||||
self.websocket_handler.queue_put({
|
||||
"type": "world_state_manager",
|
||||
"action": "template_deleted",
|
||||
"data": payload.model_dump()
|
||||
})
|
||||
|
||||
|
||||
self.websocket_handler.queue_put(
|
||||
{
|
||||
"type": "world_state_manager",
|
||||
"action": "template_deleted",
|
||||
"data": payload.model_dump(),
|
||||
}
|
||||
)
|
||||
|
||||
await self.handle_get_templates({})
|
||||
await self.signal_operation_done()
|
||||
await self.signal_operation_done()
|
||||
|
||||
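Every handler in this hunk follows the same shape: parse the incoming payload, do the work, then push a `{"type": "world_state_manager", "action": ..., "data": ...}` envelope onto the websocket queue and signal completion. A minimal standalone sketch of that pattern — the queue class and the dict payload here are stand-ins, not Talemate's actual `websocket_handler` or pydantic payload models:

```python
import asyncio


class FakeQueue:
    """Stand-in for the websocket handler's outgoing message queue."""

    def __init__(self):
        self.messages = []

    def queue_put(self, message):
        self.messages.append(message)


async def handle_set_pin(queue, data):
    # parse -> act -> notify, mirroring the handlers in the diff above
    payload = {"entry_id": data["entry_id"], "active": data.get("active", False)}
    queue.queue_put(
        {
            "type": "world_state_manager",
            "action": "pin_set",
            "data": payload,
        }
    )
    return payload


queue = FakeQueue()
asyncio.run(handle_set_pin(queue, {"entry_id": "e1", "active": True}))
```

The front end dispatches on the `type`/`action` pair, so every handler emits the same envelope with only `action` and `data` varying.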
@@ -1,6 +1,7 @@
-from talemate.emit import emit
 import structlog
+
+from talemate.emit import emit
 
 __all__ = [
     "set_loading",
     "LoadingStatus",
@@ -10,11 +11,10 @@ log = structlog.get_logger("talemate.status")
 
 
 class set_loading:
-
-    def __init__(self, message, set_busy:bool=True):
+    def __init__(self, message, set_busy: bool = True):
         self.message = message
         self.set_busy = set_busy
 
     def __call__(self, fn):
         async def wrapper(*args, **kwargs):
             if self.set_busy:
@@ -23,15 +23,19 @@ class set_loading:
                 return await fn(*args, **kwargs)
             finally:
                 emit("status", message="", status="idle")
 
         return wrapper
 
+
 class LoadingStatus:
-
-    def __init__(self, max_steps:int):
+    def __init__(self, max_steps: int):
         self.max_steps = max_steps
         self.current_step = 0
 
-    def __call__(self, message:str):
+    def __call__(self, message: str):
         self.current_step += 1
-        emit("status", message=f"{message} [{self.current_step}/{self.max_steps}]", status="busy")
+        emit(
+            "status",
+            message=f"{message} [{self.current_step}/{self.max_steps}]",
+            status="busy",
+        )
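`set_loading` is a decorator class: calling an instance with a coroutine function returns a wrapper that marks the app busy before running it and always resets the status to idle afterwards. A simplified, self-contained sketch — `emit` is a list-backed stub here, and the exact busy emit falls outside the visible hunk, so that line is an assumption:

```python
import asyncio

emitted = []


def emit(typ, message="", status=""):
    # stand-in for talemate.emit.emit; records what would be sent to the UI
    emitted.append((status, message))


class set_loading:
    def __init__(self, message, set_busy: bool = True):
        self.message = message
        self.set_busy = set_busy

    def __call__(self, fn):
        async def wrapper(*args, **kwargs):
            if self.set_busy:
                emit("status", message=self.message, status="busy")  # assumed
            try:
                return await fn(*args, **kwargs)
            finally:
                # always clear the busy state, even if fn raises
                emit("status", message="", status="idle")

        return wrapper


@set_loading("Generating world state")
async def do_work():
    return 42


result = asyncio.run(do_work())
```

The `try`/`finally` is the important part of the diffed code: the idle status is emitted even when the wrapped coroutine raises.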
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,23 +1,26 @@
 import base64
 import datetime
 import io
 import json
 import re
 import textwrap
-import structlog
-import isodate
-import datetime
 from typing import List, Union
-from thefuzz import fuzz
 
+import isodate
+import structlog
 from colorama import Back, Fore, Style, init
-from PIL import Image
 from nltk.tokenize import sent_tokenize
+from PIL import Image
+from thefuzz import fuzz
 
 from talemate.scene_message import SceneMessage
 
 log = structlog.get_logger("talemate.util")
 
 # Initialize colorama
 init(autoreset=True)
 
 
 def fix_unquoted_keys(s):
     unquoted_key_pattern = r"(?<!\\)(?:(?<=\{)|(?<=,))\s*(\w+)\s*:"
     fixed_string = re.sub(
@@ -279,25 +282,25 @@ def replace_conditional(input_string: str, params) -> str:
     return modified_string
 
 
-def strip_partial_sentences(text:str) -> str:
+def strip_partial_sentences(text: str) -> str:
     # Sentence ending characters
-    sentence_endings = ['.', '!', '?', '"', "*"]
-
+    sentence_endings = [".", "!", "?", '"', "*"]
+
     if not text:
         return text
 
     # Check if the last character is already a sentence ending
     if text[-1] in sentence_endings:
         return text
 
     # Split the text into words
     words = text.split()
 
     # Iterate over the words in reverse order until a sentence ending is found
     for i in range(len(words) - 1, -1, -1):
         if words[i][-1] in sentence_endings:
-            return ' '.join(words[:i+1])
-
+            return " ".join(words[: i + 1])
+
     # If no sentence ending is found, return the original text
     return text
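`strip_partial_sentences` is pure stdlib, so the function in this hunk can be restated and run as-is to see the trailing-fragment trimming:

```python
def strip_partial_sentences(text: str) -> str:
    # Sentence ending characters (same list as in the diff above)
    sentence_endings = [".", "!", "?", '"', "*"]

    if not text:
        return text

    # Already ends on a sentence boundary: nothing to trim
    if text[-1] in sentence_endings:
        return text

    words = text.split()

    # Walk backwards until a word that ends a sentence is found
    for i in range(len(words) - 1, -1, -1):
        if words[i][-1] in sentence_endings:
            return " ".join(words[: i + 1])

    # No boundary at all: return the input unchanged
    return text
```

Note that if no sentence boundary exists anywhere, the whole (partial) text is kept rather than discarded.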
@@ -335,8 +338,8 @@ def clean_message(message: str) -> str:
     message = message.replace("[", "*").replace("]", "*")
     return message
 
+
 def clean_dialogue(dialogue: str, main_name: str) -> str:
-
     # re split by \n{not main_name}: with a max count of 1
     pattern = r"\n(?!{}:).*".format(re.escape(main_name))
 
@@ -346,10 +349,11 @@ def clean_dialogue(dialogue: str, main_name: str) -> str:
         dialogue = f"{main_name}: {dialogue}"
 
     return clean_message(strip_partial_sentences(dialogue))
 
+
 def clean_id(name: str) -> str:
     """
-    Cleans up a id name by removing all characters that aren't a-zA-Z0-9_-
+    Cleans up a id name by removing all characters that aren't a-zA-Z0-9_-
 
     Spaces are allowed.
 
@@ -361,9 +365,10 @@ def clean_id(name: str) -> str:
     """
     # Remove all characters that aren't a-zA-Z0-9_-
     cleaned_name = re.sub(r"[^a-zA-Z0-9_\- ]", "", name)
 
     return cleaned_name
 
+
 def duration_to_timedelta(duration):
     """Convert an isodate.Duration object or a datetime.timedelta object to a datetime.timedelta object."""
     # Check if the duration is already a timedelta object
@@ -375,6 +380,7 @@ def duration_to_timedelta(duration):
         seconds = duration.tdelta.seconds
     return datetime.timedelta(days=days, seconds=seconds)
 
+
 def timedelta_to_duration(delta):
     """Convert a datetime.timedelta object to an isodate.Duration object."""
     # Extract days and convert to years, months, and days
@@ -389,7 +395,15 @@ def timedelta_to_duration(delta):
     seconds %= 3600
     minutes = seconds // 60
     seconds %= 60
-    return isodate.Duration(years=years, months=months, days=days, hours=hours, minutes=minutes, seconds=seconds)
+    return isodate.Duration(
+        years=years,
+        months=months,
+        days=days,
+        hours=hours,
+        minutes=minutes,
+        seconds=seconds,
+    )
+
 
 def parse_duration_to_isodate_duration(duration_str):
     """Parse ISO 8601 duration string and ensure the result is an isodate.Duration."""
@@ -398,6 +412,7 @@ def parse_duration_to_isodate_duration(duration_str):
         return timedelta_to_duration(parsed_duration)
     return parsed_duration
 
+
 def iso8601_diff(duration_str1, duration_str2):
     # Parse the ISO 8601 duration strings ensuring they are isodate.Duration objects
     duration1 = parse_duration_to_isodate_duration(duration_str1)
@@ -409,14 +424,14 @@ def iso8601_diff(duration_str1, duration_str2):
 
     # Calculate the difference
     difference_timedelta = abs(timedelta1 - timedelta2)
-
+
     # Convert back to Duration for further processing
     difference = timedelta_to_duration(difference_timedelta)
 
     return difference
 
+
 def iso8601_duration_to_human(iso_duration, suffix: str = " ago"):
-
     # Parse the ISO8601 duration string into an isodate duration object
     if not isinstance(iso_duration, isodate.Duration):
         duration = isodate.parse_duration(iso_duration)
@@ -463,7 +478,7 @@ def iso8601_duration_to_human(iso_duration, suffix: str = " ago"):
     # Construct the human-readable string
     if len(components) > 1:
         last = components.pop()
-        human_str = ', '.join(components) + ' and ' + last
+        human_str = ", ".join(components) + " and " + last
     elif components:
         human_str = components[0]
     else:
@@ -471,16 +486,17 @@ def iso8601_duration_to_human(iso_duration, suffix: str = " ago"):
 
     return f"{human_str}{suffix}"
 
+
 def iso8601_diff_to_human(start, end):
     if not start or not end:
         return ""
 
     diff = iso8601_diff(start, end)
 
     return iso8601_duration_to_human(diff)
 
-def iso8601_add(date_a:str, date_b:str) -> str:
+
+def iso8601_add(date_a: str, date_b: str) -> str:
     """
     Adds two ISO 8601 durations together.
     """
@@ -488,80 +504,84 @@ def iso8601_add(date_a:str, date_b:str) -> str:
     if not date_a or not date_b:
         return "PT0S"
 
-    new_ts = isodate.parse_duration(date_a.strip()) + isodate.parse_duration(date_b.strip())
+    new_ts = isodate.parse_duration(date_a.strip()) + isodate.parse_duration(
+        date_b.strip()
+    )
     return isodate.duration_isoformat(new_ts)
 
+
 def iso8601_correct_duration(duration: str) -> str:
     # Split the string into date and time components using 'T' as the delimiter
     parts = duration.split("T")
 
     # Handle the date component
     date_component = parts[0]
     time_component = ""
 
     # If there's a time component, process it
     if len(parts) > 1:
         time_component = parts[1]
 
         # Check if the time component has any date values (Y, M, D) and move them to the date component
         for char in "YD":  # Removed 'M' from this loop
             if char in time_component:
                 index = time_component.index(char)
-                date_component += time_component[:index+1]
-                time_component = time_component[index+1:]
+                date_component += time_component[: index + 1]
+                time_component = time_component[index + 1 :]
 
     # If the date component contains any time values (H, M, S), move them to the time component
     for char in "HMS":
         if char in date_component:
             index = date_component.index(char)
             time_component = date_component[index:] + time_component
             date_component = date_component[:index]
 
     # Combine the corrected date and time components
     corrected_duration = date_component
     if time_component:
         corrected_duration += "T" + time_component
 
     return corrected_duration
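`iso8601_correct_duration` uses only string operations, so it can be restated standalone to see how misplaced designators get shuffled across the `T` split (note the quirk visible in the diff: `M` found in the date part is always treated as minutes and moved to the time part):

```python
def iso8601_correct_duration(duration: str) -> str:
    # Same logic as the function above: split on 'T' and relocate designators
    parts = duration.split("T")

    date_component = parts[0]
    time_component = ""

    if len(parts) > 1:
        time_component = parts[1]

        # Move date designators (Y, D) that ended up after the 'T'
        for char in "YD":
            if char in time_component:
                index = time_component.index(char)
                date_component += time_component[: index + 1]
                time_component = time_component[index + 1 :]

    # Move time designators (H, M, S) that ended up before the 'T'
    for char in "HMS":
        if char in date_component:
            index = date_component.index(char)
            time_component = date_component[index:] + time_component
            date_component = date_component[:index]

    corrected = date_component
    if time_component:
        corrected += "T" + time_component
    return corrected
```

So `"PT1D"` (a day wrongly placed in the time part) becomes `"P1D"`, while an already well-formed `"P2DT3H"` passes through unchanged.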
 
 def fix_faulty_json(data: str) -> str:
     # Fix missing commas
-    data = re.sub(r'}\s*{', '},{', data)
-    data = re.sub(r']\s*{', '],{', data)
-    data = re.sub(r'}\s*\[', '},{', data)
-    data = re.sub(r']\s*\[', '],[', data)
-
+    data = re.sub(r"}\s*{", "},{", data)
+    data = re.sub(r"]\s*{", "],{", data)
+    data = re.sub(r"}\s*\[", "},{", data)
+    data = re.sub(r"]\s*\[", "],[", data)
+
     # Fix trailing commas
-    data = re.sub(r',\s*}', '}', data)
-    data = re.sub(r',\s*]', ']', data)
-
+    data = re.sub(r",\s*}", "}", data)
+    data = re.sub(r",\s*]", "]", data)
+
     try:
         json.loads(data)
     except json.JSONDecodeError:
         try:
-            json.loads(data+"}")
-            return data+"}"
+            json.loads(data + "}")
+            return data + "}"
         except json.JSONDecodeError:
             try:
-                json.loads(data+"]")
-                return data+"]"
+                json.loads(data + "]")
+                return data + "]"
             except json.JSONDecodeError:
                 return data
 
     return data
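`fix_faulty_json` is also self-contained (just `re` and `json`), so the hunk above can be restated and exercised directly — kept verbatim, including the `}\s*\[` case that substitutes `},{` as in the source:

```python
import json
import re


def fix_faulty_json(data: str) -> str:
    # Insert commas between adjacent objects/arrays
    data = re.sub(r"}\s*{", "},{", data)
    data = re.sub(r"]\s*{", "],{", data)
    data = re.sub(r"}\s*\[", "},{", data)
    data = re.sub(r"]\s*\[", "],[", data)

    # Strip trailing commas before closers
    data = re.sub(r",\s*}", "}", data)
    data = re.sub(r",\s*]", "]", data)

    # If the result still doesn't parse, try appending a single closer
    try:
        json.loads(data)
    except json.JSONDecodeError:
        try:
            json.loads(data + "}")
            return data + "}"
        except json.JSONDecodeError:
            try:
                json.loads(data + "]")
                return data + "]"
            except json.JSONDecodeError:
                return data

    return data
```

A trailing comma is simply stripped, and a single missing `}` or `]` is recovered by appending and re-parsing; anything worse is returned as-is for the caller to handle.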
 
 def extract_json(s):
     """
     Extracts a JSON string from the beginning of the input string `s`.
 
     Parameters:
     s (str): The input string containing a JSON string at the beginning.
 
     Returns:
     str: The extracted JSON string.
     dict: The parsed JSON object.
 
     Raises:
     ValueError: If a valid JSON string is not found.
     """
@@ -571,138 +591,154 @@ def extract_json(s):
     json_string_start = None
     s = s.lstrip()  # Strip white spaces and line breaks from the beginning
     i = 0
 
     log.debug("extract_json", s=s)
 
     # Iterate through the string.
     while i < len(s):
         # Count the opening and closing curly brackets.
-        if s[i] == '{' or s[i] == '[':
+        if s[i] == "{" or s[i] == "[":
             bracket_stack.append(s[i])
             open_brackets += 1
             if json_string_start is None:
                 json_string_start = i
-        elif s[i] == '}' or s[i] == ']':
+        elif s[i] == "}" or s[i] == "]":
             bracket_stack
             close_brackets += 1
             # Check if the brackets match, indicating a complete JSON string.
             if open_brackets == close_brackets:
-                json_string = s[json_string_start:i+1]
+                json_string = s[json_string_start : i + 1]
                 # Try to parse the JSON string.
                 return json_string, json.loads(json_string)
         i += 1
 
     if json_string_start is None:
         raise ValueError("No JSON string found.")
 
     json_string = s[json_string_start:]
     while bracket_stack:
         char = bracket_stack.pop()
-        if char == '{':
-            json_string += '}'
-        elif char == '[':
-            json_string += ']'
-
+        if char == "{":
+            json_string += "}"
+        elif char == "[":
+            json_string += "]"
+
     json_object = json.loads(json_string)
     return json_string, json_object
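The bracket-balance scan in `extract_json` can be restated standalone for the common case of a complete JSON value followed by trailing text. Declarations for the counters (which fall outside the visible hunk) are assumed, the `log.debug` call is dropped, and the stray no-op `bracket_stack` statement from the source is omitted; note the stack is never popped on close, so recovery of truncated input can append too many closers:

```python
import json


def extract_json(s):
    # Scan from the start, tracking bracket balance until the JSON closes
    open_brackets = 0
    close_brackets = 0
    bracket_stack = []
    json_string_start = None
    s = s.lstrip()
    i = 0

    while i < len(s):
        if s[i] == "{" or s[i] == "[":
            bracket_stack.append(s[i])
            open_brackets += 1
            if json_string_start is None:
                json_string_start = i
        elif s[i] == "}" or s[i] == "]":
            close_brackets += 1
            # Balanced brackets mean a complete JSON value
            if open_brackets == close_brackets:
                json_string = s[json_string_start : i + 1]
                return json_string, json.loads(json_string)
        i += 1

    if json_string_start is None:
        raise ValueError("No JSON string found.")

    # Balance was never reached: close whatever is still open
    json_string = s[json_string_start:]
    while bracket_stack:
        char = bracket_stack.pop()
        if char == "{":
            json_string += "}"
        elif char == "[":
            json_string += "]"

    json_object = json.loads(json_string)
    return json_string, json_object
```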
 
-def similarity_score(line: str, lines: list[str], similarity_threshold: int = 95) -> tuple[bool, int, str]:
+
+def similarity_score(
+    line: str, lines: list[str], similarity_threshold: int = 95
+) -> tuple[bool, int, str]:
     """
     Checks if a line is similar to any of the lines in the list of lines.
 
     Arguments:
     line (str): The line to check.
     lines (list): The list of lines to check against.
     threshold (int): The similarity threshold to use when comparing lines.
 
     Returns:
     bool: Whether a similar line was found.
     int: The similarity score of the line. If no similar line was found, the highest similarity score is returned.
     str: The similar line that was found. If no similar line was found, None is returned.
     """
 
     highest_similarity = 0
 
     for existing_line in lines:
         similarity = fuzz.ratio(line, existing_line)
         highest_similarity = max(highest_similarity, similarity)
-        #print("SIMILARITY", similarity, existing_line[:32]+"...")
+        # print("SIMILARITY", similarity, existing_line[:32]+"...")
         if similarity >= similarity_threshold:
             return True, similarity, existing_line
 
     return False, highest_similarity, None
 
-def dedupe_sentences(line_a:str, line_b:str, similarity_threshold:int=95, debug:bool=False, split_on_comma:bool=True) -> str:
+
+def dedupe_sentences(
+    line_a: str,
+    line_b: str,
+    similarity_threshold: int = 95,
+    debug: bool = False,
+    split_on_comma: bool = True,
+) -> str:
     """
     Will split both lines into sentences and then compare each sentence in line_a
     against similar sentences in line_b. If a similar sentence is found, it will be
     removed from line_a.
 
     The similarity threshold is used to determine if two sentences are similar.
 
     Arguments:
     line_a (str): The first line.
     line_b (str): The second line.
     similarity_threshold (int): The similarity threshold to use when comparing sentences.
     debug (bool): Whether to log debug messages.
     split_on_comma (bool): Whether to split line_b sentences on commas as well.
 
     Returns:
     str: the cleaned line_a.
     """
 
     line_a_sentences = sent_tokenize(line_a)
     line_b_sentences = sent_tokenize(line_b)
 
     cleaned_line_a_sentences = []
 
     if split_on_comma:
         # collect all sentences from line_b that contain a comma
         line_b_sentences_with_comma = []
         for line_b_sentence in line_b_sentences:
             if "," in line_b_sentence:
                 line_b_sentences_with_comma.append(line_b_sentence)
 
         # then split all sentences in line_b_sentences_with_comma on the comma
         # and extend line_b_sentences with the split sentences, making sure
         # to strip whitespace from the beginning and end of each sentence
 
         for line_b_sentence in line_b_sentences_with_comma:
             line_b_sentences.extend([s.strip() for s in line_b_sentence.split(",")])
 
     for line_a_sentence in line_a_sentences:
         similar_found = False
         for line_b_sentence in line_b_sentences:
             similarity = fuzz.ratio(line_a_sentence, line_b_sentence)
             if similarity >= similarity_threshold:
                 if debug:
-                    log.debug("DEDUPE SENTENCE", similarity=similarity, line_a_sentence=line_a_sentence, line_b_sentence=line_b_sentence)
+                    log.debug(
+                        "DEDUPE SENTENCE",
+                        similarity=similarity,
+                        line_a_sentence=line_a_sentence,
+                        line_b_sentence=line_b_sentence,
+                    )
                 similar_found = True
                 break
         if not similar_found:
             cleaned_line_a_sentences.append(line_a_sentence)
 
     return " ".join(cleaned_line_a_sentences)
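`dedupe_sentences` depends on `nltk.sent_tokenize` and `thefuzz`. The same idea can be sketched with only the stdlib, using a naive regex sentence splitter and `difflib.SequenceMatcher` as a rough stand-in for `fuzz.ratio` — both stand-ins are assumptions for illustration, not what Talemate actually uses:

```python
import re
from difflib import SequenceMatcher


def ratio(a: str, b: str) -> int:
    # rough 0-100 stand-in for thefuzz's fuzz.ratio
    return round(SequenceMatcher(None, a, b).ratio() * 100)


def dedupe_sentences(line_a: str, line_b: str, similarity_threshold: int = 95) -> str:
    # naive sentence split on terminal punctuation (stand-in for sent_tokenize)
    def split(s):
        return [x.strip() for x in re.split(r"(?<=[.!?])\s+", s) if x.strip()]

    line_b_sentences = split(line_b)

    cleaned = []
    for sentence in split(line_a):
        # drop any sentence of line_a that closely matches one in line_b
        if not any(
            ratio(sentence, other) >= similarity_threshold
            for other in line_b_sentences
        ):
            cleaned.append(sentence)
    return " ".join(cleaned)
```

This is what the conversation agent uses the real function for: stripping sentences the model repeated from earlier context out of a fresh generation.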
 
-def dedupe_string_old(s: str, min_length: int = 32, similarity_threshold: int = 95, debug: bool = False) -> str:
+
+def dedupe_string_old(
+    s: str, min_length: int = 32, similarity_threshold: int = 95, debug: bool = False
+) -> str:
     """
     Removes duplicate lines from a string.
 
     Arguments:
     s (str): The input string.
     min_length (int): The minimum length of a line to be checked for duplicates.
     similarity_threshold (int): The similarity threshold to use when comparing lines.
     debug (bool): Whether to log debug messages.
 
     Returns:
     str: The deduplicated string.
     """
 
     lines = s.split("\n")
     deduped = []
 
     for line in lines:
         stripped_line = line.strip()
         if len(stripped_line) > min_length:
@@ -712,33 +748,40 @@ def dedupe_string_old(s: str, min_length: int = 32, similarity_threshold: int =
                 if similarity >= similarity_threshold:
                     similar_found = True
                     if debug:
-                        log.debug("DEDUPE", similarity=similarity, line=line, existing_line=existing_line)
+                        log.debug(
+                            "DEDUPE",
+                            similarity=similarity,
+                            line=line,
+                            existing_line=existing_line,
+                        )
                     break
             if not similar_found:
                 deduped.append(line)
         else:
             deduped.append(line)  # Allow shorter strings without dupe check
 
     return "\n".join(deduped)
 
-def dedupe_string(s: str, min_length: int = 32, similarity_threshold: int = 95, debug: bool = False) -> str:
+
+def dedupe_string(
+    s: str, min_length: int = 32, similarity_threshold: int = 95, debug: bool = False
+) -> str:
     """
     Removes duplicate lines from a string going from the bottom up.
 
     Arguments:
     s (str): The input string.
     min_length (int): The minimum length of a line to be checked for duplicates.
     similarity_threshold (int): The similarity threshold to use when comparing lines.
     debug (bool): Whether to log debug messages.
 
     Returns:
     str: The deduplicated string.
     """
 
     lines = s.split("\n")
     deduped = []
 
     for line in reversed(lines):
         stripped_line = line.strip()
         if len(stripped_line) > min_length:
@@ -748,36 +791,42 @@ def dedupe_string(s: str, min_length: int = 32, similarity_threshold: int = 95,
                 if similarity >= similarity_threshold:
                     similar_found = True
                     if debug:
-                        log.debug("DEDUPE", similarity=similarity, line=line, existing_line=existing_line)
+                        log.debug(
+                            "DEDUPE",
+                            similarity=similarity,
+                            line=line,
+                            existing_line=existing_line,
+                        )
                     break
             if not similar_found:
                 deduped.append(line)
         else:
             deduped.append(line)  # Allow shorter strings without dupe check
 
     return "\n".join(reversed(deduped))
 
+
 def remove_extra_linebreaks(s: str) -> str:
     """
     Removes extra line breaks from a string.
 
     Parameters:
     s (str): The input string.
 
     Returns:
     str: The string with extra line breaks removed.
     """
     return re.sub(r"\n{3,}", "\n\n", s)
 
-def replace_exposition_markers(s:str) -> str:
+
+def replace_exposition_markers(s: str) -> str:
     s = s.replace("(", "*").replace(")", "*")
     s = s.replace("[", "*").replace("]", "*")
     return s
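The two small helpers above are trivial to exercise: parentheses and brackets are normalized into asterisk exposition markers, and runs of three or more newlines collapse down to a single blank line. Restated verbatim:

```python
import re


def remove_extra_linebreaks(s: str) -> str:
    # collapse 3+ consecutive newlines down to exactly two
    return re.sub(r"\n{3,}", "\n\n", s)


def replace_exposition_markers(s: str) -> str:
    # normalize (...) and [...] exposition into *...* markers
    s = s.replace("(", "*").replace(")", "*")
    s = s.replace("[", "*").replace("]", "*")
    return s
```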
 
-def ensure_dialog_format(line:str, talking_character:str=None) -> str:
-
-    #if "*" not in line and '"' not in line:
+
+def ensure_dialog_format(line: str, talking_character: str = None) -> str:
+    # if "*" not in line and '"' not in line:
     #    if talking_character:
     #        line = line[len(talking_character)+1:].lstrip()
     #    return f"{talking_character}: \"{line}\""
@@ -785,7 +834,7 @@ def ensure_dialog_format(line:str, talking_character:str=None) -> str:
     #
 
     if talking_character:
-        line = line[len(talking_character)+1:].lstrip()
+        line = line[len(talking_character) + 1 :].lstrip()
 
     lines = []
 
@@ -793,52 +842,53 @@ def ensure_dialog_format(line:str, talking_character:str=None) -> str:
         try:
             _line = ensure_dialog_line_format(_line)
         except Exception as exc:
-            log.error("ensure_dialog_format", msg="Error ensuring dialog line format", line=_line, exc_info=exc)
+            log.error(
+                "ensure_dialog_format",
+                msg="Error ensuring dialog line format",
+                line=_line,
+                exc_info=exc,
+            )
             pass
 
         lines.append(_line)
 
     if len(lines) > 1:
         line = "\n".join(lines)
     else:
         line = lines[0]
 
     if talking_character:
         line = f"{talking_character}: {line}"
-
     return line
 
-def ensure_dialog_line_format(line:str):
-
+
+def ensure_dialog_line_format(line: str):
     """
-    a Python function that standardizes the formatting of dialogue and action/thought
-    descriptions in text strings. This function is intended for use in a text-based
+    a Python function that standardizes the formatting of dialogue and action/thought
+    descriptions in text strings. This function is intended for use in a text-based
     game where spoken dialogue is encased in double quotes (" ") and actions/thoughts are
     encased in asterisks (* *). The function must correctly format strings, ensuring that
     each spoken sentence and action/thought is properly encased
     """
 
     i = 0
 
     segments = []
     segment = None
     segment_open = None
 
     line = line.strip()
 
     line = line.replace('"*', '"').replace('*"', '"')
 
     for i in range(len(line)):
         c = line[i]
 
-        #print("segment_open", segment_open)
-        #print("segment", segment)
-
-        if c in ['"', '*']:
+        # print("segment_open", segment_open)
+        # print("segment", segment)
+
+        if c in ['"', "*"]:
             if segment_open == c:
                 # open segment is the same as the current character
                 # closing
@@ -849,15 +899,15 @@ def ensure_dialog_line_format(line:str):
             elif segment_open is not None and segment_open != c:
                 # open segment is not the same as the current character
                 # opening - close the current segment and open a new one
 
                 # if we are at the last character we append the segment
-                if i == len(line)-1 and segment.strip():
+                if i == len(line) - 1 and segment.strip():
                     segment += c
                     segments += [segment.strip()]
                     segment_open = None
                     segment = None
                     continue
 
                 segments += [segment.strip()]
                 segment_open = c
                 segment = c
@@ -871,27 +921,27 @@ def ensure_dialog_line_format(line:str):
                 segment = c
             else:
                 segment += c
 
     if segment is not None:
         if segment.strip().strip("*").strip('"'):
             segments += [segment.strip()]
 
     for i in range(len(segments)):
         segment = segments[i]
-        if segment in ['"', '*']:
+        if segment in ['"', "*"]:
             if i > 0:
-                prev_segment = segments[i-1]
-                if prev_segment and prev_segment[-1] not in ['"', '*']:
-                    segments[i-1] = f"{prev_segment}{segment}"
+                prev_segment = segments[i - 1]
+                if prev_segment and prev_segment[-1] not in ['"', "*"]:
+                    segments[i - 1] = f"{prev_segment}{segment}"
                     segments[i] = ""
                     continue
 
     for i in range(len(segments)):
         segment = segments[i]
 
         if not segment:
             continue
 
         if segment[0] == "*" and segment[-1] != "*":
             segment += "*"
         elif segment[-1] == "*" and segment[0] != "*":
@@ -900,43 +950,42 @@ def ensure_dialog_line_format(line:str):
             segment += '"'
         elif segment[-1] == '"' and segment[0] != '"':
             segment = '"' + segment
-        elif segment[0] in ['"', '*'] and segment[-1] == segment[0]:
+        elif segment[0] in ['"', "*"] and segment[-1] == segment[0]:
             continue
 
         segments[i] = segment
 
     for i in range(len(segments)):
         segment = segments[i]
-        if not segment or segment[0] in ['"', '*']:
+        if not segment or segment[0] in ['"', "*"]:
             continue
 
-        prev_segment = segments[i-1] if i > 0 else None
-        next_segment = segments[i+1] if i < len(segments)-1 else None
-
+        prev_segment = segments[i - 1] if i > 0 else None
+        next_segment = segments[i + 1] if i < len(segments) - 1 else None
+
         if prev_segment and prev_segment[-1] == '"':
             segments[i] = f"*{segment}*"
-        elif prev_segment and prev_segment[-1] == '*':
-            segments[i] = f"\"{segment}\""
+        elif prev_segment and prev_segment[-1] == "*":
+            segments[i] = f'"{segment}"'
         elif next_segment and next_segment[0] == '"':
             segments[i] = f"*{segment}*"
-        elif next_segment and next_segment[0] == '*':
-            segments[i] = f"\"{segment}\""
+        elif next_segment and next_segment[0] == "*":
+            segments[i] = f'"{segment}"'
 
     for i in range(len(segments)):
         segments[i] = clean_uneven_markers(segments[i], '"')
-        segments[i] = clean_uneven_markers(segments[i], '*')
+        segments[i] = clean_uneven_markers(segments[i], "*")
 
     final = " ".join(segment for segment in segments if segment).strip()
-    final = final.replace('","', '').replace('"."', '')
+    final = final.replace('","', "").replace('"."', "")
     return final
 
-def clean_uneven_markers(chunk:str, marker:str):
-
+
+def clean_uneven_markers(chunk: str, marker: str):
     # if there is an uneven number of quotes, remove the last one if its
     # at the end of the chunk. If its in the middle, add a quote to the endc
     count = chunk.count(marker)
 
     if count % 2 == 1:
         if chunk.endswith(marker):
             chunk = chunk[:-1]
@@ -946,5 +995,5 @@ def clean_uneven_markers(chunk:str, marker:str):
             chunk = chunk.replace(marker, "")
         else:
             chunk += marker
-
     return chunk
@@ -1,31 +1,36 @@
-from pydantic import BaseModel, Field, field_validator
-from talemate.emit import emit
-import structlog
 import traceback
-from typing import Union, Any
 from enum import Enum
+from typing import Any, Union
+
+import structlog
+from pydantic import BaseModel, Field, field_validator
 
-import talemate.instance as instance
-from talemate.prompts import Prompt
 import talemate.automated_action as automated_action
+import talemate.instance as instance
+from talemate.emit import emit
+from talemate.prompts import Prompt
 
 ANY_CHARACTER = "__any_character__"
 
 log = structlog.get_logger("talemate")
 
 
 class CharacterState(BaseModel):
     snapshot: Union[str, None] = None
     emotion: Union[str, None] = None
 
 
 class ObjectState(BaseModel):
     snapshot: Union[str, None] = None
 
 
 class InsertionMode(Enum):
     sequential = "sequential"
     conversation_context = "conversation-context"
     all_context = "all-context"
     never = "never"
 
 
 class Reinforcement(BaseModel):
     question: str
     answer: Union[str, None] = None
@@ -34,11 +39,10 @@ class Reinforcement(BaseModel):
     character: Union[str, None] = None
     instructions: Union[str, None] = None
     insert: str = "sequential"
 
     @property
     def as_context_line(self) -> str:
         if self.character:
             if self.question.strip().endswith("?"):
                 return f"{self.character}: {self.question} {self.answer}"
         else:
@@ -46,57 +50,61 @@ class Reinforcement(BaseModel):
             if self.question.strip().endswith("?"):
                 return f"{self.question} {self.answer}"
 
         return f"{self.question}: {self.answer}"
 
 
 class ManualContext(BaseModel):
     id: str
     text: str
     meta: dict[str, Any] = {}
 
 
 class ContextPin(BaseModel):
     entry_id: str
     condition: Union[str, None] = None
     condition_state: bool = False
     active: bool = False
 
 
 class WorldState(BaseModel):
     # characters in the scene by name
     characters: dict[str, CharacterState] = {}
 
     # objects in the scene by name
     items: dict[str, ObjectState] = {}
 
     # location description
     location: Union[str, None] = None
 
     # reinforcers
     reinforce: list[Reinforcement] = []
 
     # pins
     pins: dict[str, ContextPin] = {}
 
     # manual context
     manual_context: dict[str, ManualContext] = {}
 
     @property
     def agent(self):
         return instance.get_agent("world_state")
 
     @property
     def scene(self):
         return self.agent.scene
 
     @property
     def pretty_json(self):
         return self.model_dump_json(indent=2)
 
     @property
     def as_list(self):
         return self.render().as_list
 
-    def filter_reinforcements(self, character:str=ANY_CHARACTER, insert:list[str]=None) -> list[Reinforcement]:
+    def filter_reinforcements(
+        self, character: str = ANY_CHARACTER, insert: list[str] = None
+    ) -> list[Reinforcement]:
         """
         Returns a filtered list of Reinforcement objects based on character and insert criteria.
@@ -107,24 +115,23 @@ class WorldState(BaseModel):
         """
         Returns a filtered set of results as list
         """
 
         result = []
 
         for reinforcement in self.reinforce:
             if not reinforcement.answer:
                 continue
 
             if character != ANY_CHARACTER and reinforcement.character != character:
                 continue
 
             if insert and reinforcement.insert not in insert:
                 continue
 
             result.append(reinforcement)
 
         return result
 
     def reset(self):
         """
         Resets the WorldState instance to its initial state by clearing characters, items, and location.
@@ -134,8 +141,8 @@ class WorldState(BaseModel):
         """
         self.characters = {}
         self.items = {}
         self.location = None
 
     def emit(self, status="update"):
         """
         Emits the current world state with the given status.
@@ -144,8 +151,8 @@ class WorldState(BaseModel):
         - status: The status of the world state to emit, which influences the handling of the update event.
         """
         emit("world_state", status=status, data=self.model_dump())
 
-    async def request_update(self, initial_only:bool=False):
+    async def request_update(self, initial_only: bool = False):
         """
         Requests an update of the world state from the WorldState agent. If initial_only is true, emits current state without requesting if characters exist.
@@ -156,85 +163,97 @@ class WorldState(BaseModel):
         if initial_only and self.characters:
             self.emit()
             return
 
         # if auto is true, we need to check if agent has automatic update enabled
         if initial_only and not self.agent.actions["update_world_state"].enabled:
             self.emit()
             return
 
         self.emit(status="requested")
 
         try:
             world_state = await self.agent.request_world_state()
         except Exception as e:
             self.emit()
-            log.error("world_state.request_update", error=e, traceback=traceback.format_exc())
+            log.error(
+                "world_state.request_update", error=e, traceback=traceback.format_exc()
+            )
             return
 
         previous_characters = self.characters
         previous_items = self.items
         scene = self.agent.scene
         character_names = scene.character_names
         self.characters = {}
         self.items = {}
 
         for character_name, character in world_state.get("characters", {}).items():
             # character name may not always come back exactly as we have
             # it defined in the scene. We assign the correct name by checking occurences
             # of both names in each other.
 
             if character_name not in character_names:
                 for _character_name in character_names:
-                    if _character_name.lower() in character_name.lower() or character_name.lower() in _character_name.lower():
-                        log.debug("world_state adjusting character name", from_name=character_name, to_name=_character_name)
+                    if (
+                        _character_name.lower() in character_name.lower()
+                        or character_name.lower() in _character_name.lower()
+                    ):
+                        log.debug(
+                            "world_state adjusting character name",
+                            from_name=character_name,
+                            to_name=_character_name,
+                        )
                         character_name = _character_name
                         break
 
             if not character:
                 continue
 
             # if emotion is not set, see if a previous state exists
             # and use that emotion
 
             if "emotion" not in character:
-                log.debug("emotion not set", character_name=character_name, character=character, characters=previous_characters)
+                log.debug(
+                    "emotion not set",
+                    character_name=character_name,
+                    character=character,
+                    characters=previous_characters,
+                )
                 if character_name in previous_characters:
                     character["emotion"] = previous_characters[character_name].emotion
 
             self.characters[character_name] = CharacterState(**character)
             log.debug("world_state", character=character)
 
         for item_name, item in world_state.get("items", {}).items():
             if not item:
                 continue
             self.items[item_name] = ObjectState(**item)
             log.debug("world_state", item=item)
 
         # deactivate persiting for now
         # await self.persist()
         self.emit()
 
     async def persist(self):
         """
         Persists the world state snapshots of characters and items into the memory agent.
 
         TODO: neeeds re-thinking.
 
         Its better to use state reinforcement to track states, persisting the small world
         state snapshots most of the time does not have enough context to be useful.
 
         Arguments:
         - None
         """
 
         memory = instance.get_agent("memory")
         world_state = instance.get_agent("world_state")
 
         # first we check if any of the characters were refered
         # to with an alias
 
         states = []
         scene = self.agent.scene
@@ -247,10 +266,10 @@ class WorldState(BaseModel):
                         "typ": "world_state",
                         "character": character_name,
                         "ts": scene.ts,
-                    }
+                    },
                 }
             )
 
         for item_name in self.items.keys():
             states.append(
                 {
@@ -260,18 +279,17 @@ class WorldState(BaseModel):
                         "typ": "world_state",
                         "item": item_name,
                         "ts": scene.ts,
-                    }
+                    },
                 }
             )
 
         log.debug("world_state.persist", states=states)
 
         if not states:
             return
 
         await memory.add_many(states)
 
     async def request_update_inline(self):
         """
         Requests an inline update of the world state from the WorldState agent and immediately emits the state.
@@ -279,22 +297,21 @@ class WorldState(BaseModel):
         Arguments:
         - None
         """
 
         self.emit(status="requested")
 
         world_state = await self.agent.request_world_state_inline()
 
         self.emit()
 
     async def add_reinforcement(
-        self,
-        question:str,
-        character:str=None,
-        instructions:str=None,
-        interval:int=10,
-        answer:str="",
-        insert:str="sequential",
+        self,
+        question: str,
+        character: str = None,
+        instructions: str = None,
+        interval: int = 10,
+        answer: str = "",
+        insert: str = "sequential",
     ) -> Reinforcement:
         """
         Adds or updates a reinforcement in the world state. If a reinforcement with the same question and character exists, it is updated.
@@ -307,50 +324,59 @@ class WorldState(BaseModel):
         - answer: The answer to the reinforcement question.
         - insert: The method of inserting the reinforcement into the context.
         """
 
         # if reinforcement already exists, update it
 
         idx, reinforcement = await self.find_reinforcement(question, character)
 
         if reinforcement:
             # update the reinforcement object
 
             reinforcement.instructions = instructions
             reinforcement.interval = interval
             reinforcement.answer = answer
 
             old_insert_method = reinforcement.insert
 
             reinforcement.insert = insert
 
             # find the reinforcement message i nthe scene history and update the answer
             if old_insert_method == "sequential":
-                message = self.agent.scene.find_message(typ="reinforcement", source=f"{question}:{character if character else ''}")
+                message = self.agent.scene.find_message(
+                    typ="reinforcement",
+                    source=f"{question}:{character if character else ''}",
+                )
 
                 if old_insert_method != insert and message:
                     # if it used to be sequential we need to remove its ReinforcmentMessage
                     # from the scene history
 
                     self.scene.pop_history(typ="reinforcement", source=message.source)
 
                 elif message:
                     message.message = answer
             elif insert == "sequential":
                 # if it used to be something else and is now sequential, we need to run the state
                 # next loop
                 reinforcement.due = 0
 
             # update the character detail if character name is specified
             if character:
                 character = self.agent.scene.get_character(character)
                 await character.set_detail(question, answer)
 
             return reinforcement
 
-        log.debug("world_state.add_reinforcement", question=question, character=character, instructions=instructions, interval=interval, answer=answer, insert=insert)
+        log.debug(
+            "world_state.add_reinforcement",
+            question=question,
+            character=character,
+            instructions=instructions,
+            interval=interval,
+            answer=answer,
+            insert=insert,
+        )
 
         reinforcement = Reinforcement(
             question=question,
             character=character,
@@ -359,12 +385,12 @@ class WorldState(BaseModel):
             answer=answer,
             insert=insert,
         )
 
         self.reinforce.append(reinforcement)
 
         return reinforcement
 
-    async def find_reinforcement(self, question:str, character:str=None):
+    async def find_reinforcement(self, question: str, character: str = None):
         """
         Finds a reinforcement based on the question and character provided. Returns the index in the list and the reinforcement object.
@@ -373,11 +399,14 @@ class WorldState(BaseModel):
         - character: The character to whom the reinforcement is linked. Use None for global reinforcements.
         """
         for idx, reinforcement in enumerate(self.reinforce):
-            if reinforcement.question == question and reinforcement.character == character:
+            if (
+                reinforcement.question == question
+                and reinforcement.character == character
+            ):
                 return idx, reinforcement
         return None, None
 
-    def reinforcements_for_character(self, character:str):
+    def reinforcements_for_character(self, character: str):
         """
         Returns a dictionary of reinforcements specifically for a given character.
@@ -385,13 +414,13 @@ class WorldState(BaseModel):
         - character: The name of the character for whom reinforcements should be retrieved.
         """
         reinforcements = {}
 
         for reinforcement in self.reinforce:
             if reinforcement.character == character:
                 reinforcements[reinforcement.question] = reinforcement
 
         return reinforcements
 
     def reinforcements_for_world(self):
         """
         Returns a dictionary of global reinforcements not linked to any specific character.
@@ -400,56 +429,57 @@ class WorldState(BaseModel):
         - None
         """
         reinforcements = {}
 
         for reinforcement in self.reinforce:
             if not reinforcement.character:
                 reinforcements[reinforcement.question] = reinforcement
 
         return reinforcements
 
-    async def remove_reinforcement(self, idx:int):
+    async def remove_reinforcement(self, idx: int):
         """
         Removes a reinforcement from the world state.
 
         Arguments:
         - idx: The index of the reinforcement to remove.
         """
 
         # find all instances of the reinforcement in the scene history
         # and remove them
-        source=f"{self.reinforce[idx].question}:{self.reinforce[idx].character if self.reinforce[idx].character else ''}"
+        source = f"{self.reinforce[idx].question}:{self.reinforce[idx].character if self.reinforce[idx].character else ''}"
         self.agent.scene.pop_history(typ="reinforcement", source=source, all=True)
 
         self.reinforce.pop(idx)
 
     def render(self):
         """
         Renders the world state as a string.
         """
 
         return Prompt.get(
             "world_state.render",
             vars={
                 "characters": self.characters,
                 "items": self.items,
                 "location": self.location,
-            }
+            },
         )
 
     async def commit_to_memory(self, memory_agent):
-        await memory_agent.add_many([
-            manual_context.model_dump() for manual_context in self.manual_context.values()
-        ])
+        await memory_agent.add_many(
+            [
+                manual_context.model_dump()
+                for manual_context in self.manual_context.values()
+            ]
        )
 
     def manual_context_for_world(self) -> dict[str, ManualContext]:
         """
         Returns all manual context entries where meta["typ"] == "world_state"
         """
 
         return {
             manual_context.id: manual_context
             for manual_context in self.manual_context.values()
             if manual_context.meta.get("typ") == "world_state"
         }
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pydantic
|
||||
import structlog
|
||||
|
||||
from talemate.config import StateReinforcementTemplate, WorldStateTemplates, save_config
|
||||
from talemate.instance import get_agent
|
||||
from talemate.config import WorldStateTemplates, StateReinforcementTemplate, save_config
|
||||
from talemate.world_state import Reinforcement, ManualContext, ContextPin, InsertionMode
|
||||
from talemate.world_state import ContextPin, InsertionMode, ManualContext, Reinforcement
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from talemate.tale_mate import Scene
|
||||
|
||||
log = structlog.get_logger("talemate.server.world_state_manager")
|
||||
|
||||
|
||||
class CharacterSelect(pydantic.BaseModel):
|
||||
name: str
|
||||
active: bool = True
|
||||
is_player: bool = False
|
||||
|
||||
|
||||
|
||||
class ContextDBEntry(pydantic.BaseModel):
|
||||
text: str
|
||||
meta: dict
|
||||
id: Any
|
||||
|
||||
|
||||
|
||||
class ContextDB(pydantic.BaseModel):
|
||||
entries: list[ContextDBEntry] = []
|
||||
|
||||
|
||||
|
||||
class CharacterDetails(pydantic.BaseModel):
|
||||
name: str
|
||||
active: bool = True
|
||||
is_player: bool = False
|
||||
description: str = ""
|
||||
base_attributes: dict[str,str] = {}
|
||||
details: dict[str,str] = {}
|
||||
base_attributes: dict[str, str] = {}
|
||||
details: dict[str, str] = {}
|
||||
reinforcements: dict[str, Reinforcement] = {}
|
||||
|
||||
|
||||
class World(pydantic.BaseModel):
|
||||
entries: dict[str, ManualContext] = {}
|
||||
reinforcements: dict[str, Reinforcement] = {}
|
||||
|
||||
|
||||
|
||||
class CharacterList(pydantic.BaseModel):
|
||||
characters: dict[str, CharacterSelect] = {}
|
||||
|
||||
|
||||
|
||||
class HistoryEntry(pydantic.BaseModel):
|
||||
text: str
|
||||
start: int = None
|
||||
end: int = None
|
||||
ts: str = None
|
||||
ts: str = None
|
||||
|
||||
|
||||
class History(pydantic.BaseModel):
|
||||
history: list[HistoryEntry] = []
|
||||
|
||||
|
||||
|
||||
class AnnotatedContextPin(pydantic.BaseModel):
|
||||
pin: ContextPin
|
||||
text: str
|
||||
time_aware_text: str
|
||||
|
||||
|
||||
|
||||
class ContextPins(pydantic.BaseModel):
|
||||
pins: dict[str, AnnotatedContextPin] = []
|
||||
|
||||
|
||||
|
||||
class WorldStateManager:
|
||||
|
||||
@property
|
||||
def memory_agent(self):
|
||||
"""
|
||||
@@ -69,8 +79,8 @@ class WorldStateManager:
|
||||
The memory agent instance responsible for managing memory-related operations.
|
||||
"""
|
||||
return get_agent("memory")
|
||||
|
||||
def __init__(self, scene:'Scene'):
|
||||
|
||||
def __init__(self, scene: "Scene"):
|
||||
"""
|
||||
Initializes the WorldStateManager with a given scene.
|
||||
|
||||
@@ -79,7 +89,7 @@ class WorldStateManager:
|
||||
"""
|
||||
self.scene = scene
|
||||
self.world_state = scene.world_state
|
||||
|
||||
|
||||
async def get_character_list(self) -> CharacterList:
|
||||
"""
|
||||
Retrieves a list of characters from the current scene.
|
||||
@@ -87,18 +97,22 @@ class WorldStateManager:
|
||||
Returns:
|
||||
A CharacterList object containing the characters with their select properties from the scene.
|
||||
"""
|
||||
|
||||
|
||||
characters = CharacterList()
|
||||
|
||||
|
||||
for character in self.scene.get_characters():
|
||||
characters.characters[character.name] = CharacterSelect(name=character.name, active=True, is_player=character.is_player)
|
||||
|
||||
characters.characters[character.name] = CharacterSelect(
|
||||
name=character.name, active=True, is_player=character.is_player
|
||||
)
|
||||
|
||||
for character in self.scene.inactive_characters.values():
|
||||
characters.characters[character.name] = CharacterSelect(name=character.name, active=False, is_player=character.is_player)
|
||||
|
||||
characters.characters[character.name] = CharacterSelect(
|
||||
name=character.name, active=False, is_player=character.is_player
|
||||
)
|
||||
|
||||
return characters
|
||||
|
||||
async def get_character_details(self, character_name:str) -> CharacterDetails:
|
||||
|
||||
async def get_character_details(self, character_name: str) -> CharacterDetails:
|
||||
"""
|
||||
Fetches and returns the details for a specific character by name.
|
||||
|
||||
@@ -108,21 +122,28 @@ class WorldStateManager:
|
||||
Returns:
|
||||
A CharacterDetails object containing the character's details, attributes, and reinforcements.
|
||||
"""
|
||||
|
||||
|
||||
character = self.scene.get_character(character_name)
|
||||
|
||||
details = CharacterDetails(name=character.name, active=True, description=character.description, is_player=character.is_player)
|
||||
|
||||
|
||||
details = CharacterDetails(
|
||||
name=character.name,
|
||||
active=True,
|
||||
description=character.description,
|
||||
is_player=character.is_player,
|
||||
)
|
||||
|
||||
for key, value in character.base_attributes.items():
|
||||
details.base_attributes[key] = value
|
||||
|
||||
|
||||
for key, value in character.details.items():
|
||||
details.details[key] = value
|
||||
|
||||
details.reinforcements = self.world_state.reinforcements_for_character(character_name)
|
||||
|
||||
|
||||
details.reinforcements = self.world_state.reinforcements_for_character(
|
||||
character_name
|
||||
)
|
||||
|
||||
return details
|
||||
|
||||
|
||||
async def get_world(self) -> World:
|
||||
"""
|
||||
Retrieves the current state of the world, including entries and reinforcements.
|
||||
@@ -132,10 +153,12 @@ class WorldStateManager:
|
||||
"""
|
||||
return World(
|
||||
entries=self.world_state.manual_context_for_world(),
|
||||
reinforcements=self.world_state.reinforcements_for_world()
|
||||
reinforcements=self.world_state.reinforcements_for_world(),
|
||||
)
|
||||
|
||||
async def get_context_db_entries(self, query:str, limit:int=20, **meta) -> ContextDB:
|
||||
|
||||
async def get_context_db_entries(
|
||||
self, query: str, limit: int = 20, **meta
|
||||
) -> ContextDB:
|
||||
"""
|
||||
Retrieves entries from the context database based on a query and metadata.
|
||||
|
||||
@@ -147,22 +170,24 @@ class WorldStateManager:
|
||||
Returns:
|
||||
A ContextDB object containing the found entries.
|
||||
"""
|
||||
|
||||
|
||||
if query.startswith("id:"):
|
||||
_entries = await self.memory_agent.get_document(id=query[3:])
|
||||
_entries = list(_entries.values())
|
||||
else:
|
||||
_entries = await self.memory_agent.multi_query([query], iterate=limit, max_tokens=9999999, **meta)
|
||||
|
||||
_entries = await self.memory_agent.multi_query(
|
||||
[query], iterate=limit, max_tokens=9999999, **meta
|
||||
)
|
||||
|
||||
entries = []
|
||||
for entry in _entries:
|
||||
entries.append(ContextDBEntry(text=entry.raw, meta=entry.meta, id=entry.id))
|
||||
|
||||
|
||||
context_db = ContextDB(entries=entries)
|
||||
|
||||
|
||||
return context_db
|
||||
|
||||
async def get_pins(self, active:bool=None) -> ContextPins:
|
||||
|
||||
async def get_pins(self, active: bool = None) -> ContextPins:
|
||||
"""
|
||||
Retrieves context pins that meet the specified activity condition.
|
||||
|
||||
@@ -172,31 +197,36 @@ class WorldStateManager:
|
||||
Returns:
|
||||
A ContextPins object containing the matching annotated context pins.
|
||||
"""
|
||||
|
||||
|
||||
pins = self.world_state.pins
|
||||
|
||||
candidates = [pin for pin in pins.values() if pin.active == active or active is None]
|
||||
|
||||
|
||||
candidates = [
|
||||
pin for pin in pins.values() if pin.active == active or active is None
|
||||
]
|
||||
|
||||
_ids = [pin.entry_id for pin in candidates]
|
||||
_pins = {}
|
||||
documents = await self.memory_agent.get_document(id=_ids)
|
||||
|
||||
|
||||
for pin in sorted(candidates, key=lambda x: x.active, reverse=True):
|
||||
|
||||
if pin.entry_id not in documents:
|
||||
text = ""
|
||||
time_aware_text = ""
|
||||
else:
|
||||
text = documents[pin.entry_id].raw
|
||||
time_aware_text = str(documents[pin.entry_id])
|
||||
|
||||
annotated_pin = AnnotatedContextPin(pin=pin, text=text, time_aware_text=time_aware_text)
|
||||
|
||||
|
||||
annotated_pin = AnnotatedContextPin(
|
||||
pin=pin, text=text, time_aware_text=time_aware_text
|
||||
)
|
||||
|
||||
_pins[pin.entry_id] = annotated_pin
|
||||
|
||||
|
||||
return ContextPins(pins=_pins)
|
||||
|
||||
async def update_character_attribute(self, character_name:str, attribute:str, value:str):
|
||||
|
||||
async def update_character_attribute(
|
||||
self, character_name: str, attribute: str, value: str
|
||||
):
|
||||
"""
|
||||
Updates the attribute of a character to a new value.
|
||||
|
||||
@@ -207,8 +237,10 @@ class WorldStateManager:
|
||||
"""
|
||||
character = self.scene.get_character(character_name)
|
||||
await character.set_base_attribute(attribute, value)
|
||||
|
||||
async def update_character_detail(self, character_name:str, detail:str, value:str):
|
||||
|
||||
async def update_character_detail(
|
||||
self, character_name: str, detail: str, value: str
|
||||
):
|
||||
"""
|
||||
Updates a specific detail of a character to a new value.
|
||||
|
||||
@@ -219,8 +251,8 @@ class WorldStateManager:
|
||||
"""
|
||||
character = self.scene.get_character(character_name)
|
||||
await character.set_detail(detail, value)
|
||||
|
||||
async def update_character_description(self, character_name:str, description:str):
|
||||
|
||||
async def update_character_description(self, character_name: str, description: str):
|
||||
"""
|
||||
Updates the description of a character to a new value.
|
||||
|
||||
@@ -230,16 +262,16 @@ class WorldStateManager:
|
||||
"""
|
||||
character = self.scene.get_character(character_name)
|
||||
await character.set_description(description)
|
||||
|
||||
|
||||
async def add_detail_reinforcement(
|
||||
self,
|
||||
character_name:str,
|
||||
question:str,
|
||||
instructions:str=None,
|
||||
interval:int=10,
|
||||
answer:str="",
|
||||
insert:str="sequential",
|
||||
run_immediately:bool=False
|
||||
self,
|
||||
character_name: str,
|
||||
question: str,
|
||||
instructions: str = None,
|
||||
interval: int = 10,
|
||||
answer: str = "",
|
||||
insert: str = "sequential",
|
||||
run_immediately: bool = False,
|
||||
) -> Reinforcement:
|
||||
"""
|
||||
Adds a detail reinforcement for a character with specified parameters.
|
||||
@@ -262,16 +294,18 @@ class WorldStateManager:
|
||||
reinforcement = await self.world_state.add_reinforcement(
|
||||
question, character_name, instructions, interval, answer, insert
|
||||
)
|
||||
|
||||
|
||||
if run_immediately:
|
||||
await world_state_agent.update_reinforcement(question, character_name)
|
||||
else:
|
||||
# if not running immediately, we need to emit the world state manually
|
||||
self.world_state.emit()
|
||||
|
||||
|
||||
return reinforcement
|
||||
|
||||
async def run_detail_reinforcement(self, character_name:str, question:str, reset:bool=False):
|
||||
|
||||
async def run_detail_reinforcement(
|
||||
self, character_name: str, question: str, reset: bool = False
|
||||
):
|
||||
"""
|
||||
Executes the detail reinforcement for a specific character and question.
|
||||
|
||||
@@ -280,9 +314,11 @@ class WorldStateManager:
|
||||
question: The query/question that the reinforcement corresponds to.
|
||||
"""
|
||||
world_state_agent = get_agent("world_state")
|
||||
await world_state_agent.update_reinforcement(question, character_name, reset=reset)
|
||||
|
||||
async def delete_detail_reinforcement(self, character_name:str, question:str):
|
||||
await world_state_agent.update_reinforcement(
|
||||
question, character_name, reset=reset
|
||||
)
|
||||
|
||||
async def delete_detail_reinforcement(self, character_name: str, question: str):
|
||||
"""
|
||||
Deletes a detail reinforcement for a specified character and question.
|
||||
|
||||
@@ -290,13 +326,15 @@ class WorldStateManager:
|
||||
character_name: The name of the character whose reinforcement is to be deleted.
|
||||
question: The query/question of the reinforcement to be deleted.
|
||||
"""
|
||||
|
||||
idx, reinforcement = await self.world_state.find_reinforcement(question, character_name)
|
||||
|
||||
idx, reinforcement = await self.world_state.find_reinforcement(
|
||||
question, character_name
|
||||
)
|
||||
if idx is not None:
|
||||
await self.world_state.remove_reinforcement(idx)
|
||||
self.world_state.emit()
|
||||
|
||||
async def save_world_entry(self, entry_id:str, text:str, meta:dict):
|
||||
|
||||
async def save_world_entry(self, entry_id: str, text: str, meta: dict):
|
||||
"""
|
||||
Saves a manual world entry with specified text and metadata.
|
||||
|
||||
@@ -308,8 +346,8 @@ class WorldStateManager:
|
||||
meta["source"] = "manual"
|
||||
meta["typ"] = "world_state"
|
||||
await self.update_context_db_entry(entry_id, text, meta)
|
||||
|
||||
async def update_context_db_entry(self, entry_id:str, text:str, meta:dict):
|
||||
|
||||
async def update_context_db_entry(self, entry_id: str, text: str, meta: dict):
|
||||
"""
|
||||
Updates an entry in the context database with new text and metadata.
|
||||
|
||||
@@ -318,47 +356,42 @@ class WorldStateManager:
|
||||
text: The new text content for the world entry.
|
||||
meta: A dictionary containing updated metadata for the world entry.
|
||||
"""
|
||||
|
||||
|
||||
if meta.get("source") == "manual":
|
||||
# manual context needs to be updated in the world state
|
||||
self.world_state.manual_context[entry_id] = ManualContext(
|
||||
text=text,
|
||||
meta=meta,
|
||||
id=entry_id
|
||||
text=text, meta=meta, id=entry_id
|
||||
)
|
||||
elif meta.get("typ") == "details":
|
||||
# character detail needs to be mirrored to the
|
||||
# character detail needs to be mirrored to the
|
||||
# character object in the scene
|
||||
character_name = meta.get("character")
|
||||
character = self.scene.get_character(character_name)
|
||||
character.details[meta.get("detail")] = text
|
||||
|
||||
|
||||
await self.memory_agent.add_many([
|
||||
{
|
||||
"id": entry_id,
|
||||
"text": text,
|
||||
"meta": meta
|
||||
}
|
||||
])
|
||||
|
||||
async def delete_context_db_entry(self, entry_id:str):
|
||||
|
||||
await self.memory_agent.add_many([{"id": entry_id, "text": text, "meta": meta}])
|
||||
|
||||
async def delete_context_db_entry(self, entry_id: str):
|
||||
"""
|
||||
Deletes a specific entry from the context database using its identifier.
|
||||
|
||||
Arguments:
|
||||
entry_id: The identifier of the world entry to be deleted.
|
||||
"""
|
||||
await self.memory_agent.delete({
|
||||
"ids": entry_id
|
||||
})
|
||||
|
||||
await self.memory_agent.delete({"ids": entry_id})
|
||||
|
||||
if entry_id in self.world_state.manual_context:
|
||||
del self.world_state.manual_context[entry_id]
|
||||
|
||||
|
||||
await self.remove_pin(entry_id)
|
||||
|
||||
async def set_pin(self, entry_id:str, condition:str=None, condition_state:bool=False, active:bool=False):
|
||||
|
||||
async def set_pin(
|
||||
self,
|
||||
entry_id: str,
|
||||
condition: str = None,
|
||||
condition_state: bool = False,
|
||||
active: bool = False,
|
||||
):
|
||||
"""
|
||||
Creates or updates a pin on a context entry with conditional activation.
|
||||
|
||||
@@ -368,33 +401,32 @@ class WorldStateManager:
|
||||
condition_state: The boolean state that enables the pin; defaults to False.
|
||||
active: A flag indicating whether the pin should be active; defaults to False.
|
||||
"""
|
||||
|
||||
|
||||
if not condition:
|
||||
condition = None
|
||||
condition_state = False
|
||||
|
||||
|
||||
pin = ContextPin(
|
||||
entry_id=entry_id,
|
||||
condition=condition,
|
||||
condition_state=condition_state,
|
||||
active=active
|
||||
active=active,
|
||||
)
|
||||
|
||||
|
||||
self.world_state.pins[entry_id] = pin
|
||||
|
||||
|
||||
async def remove_all_empty_pins(self):
|
||||
|
||||
"""
|
||||
Removes all pins that come back with empty `text` attributes from get_pins.
|
||||
"""
|
||||
|
||||
|
||||
pins = await self.get_pins()
|
||||
|
||||
|
||||
for pin in pins.pins.values():
|
||||
if not pin.text:
|
||||
await self.remove_pin(pin.pin.entry_id)
|
||||
|
||||
async def remove_pin(self, entry_id:str):
|
||||
|
||||
async def remove_pin(self, entry_id: str):
|
||||
"""
|
||||
Removes an existing pin from a context entry using its identifier.
|
||||
|
||||
@@ -403,8 +435,7 @@ class WorldStateManager:
|
||||
"""
|
||||
if entry_id in self.world_state.pins:
|
||||
del self.world_state.pins[entry_id]
|
||||
|
||||
|
||||
|
||||
async def get_templates(self) -> WorldStateTemplates:
|
||||
"""
|
||||
Retrieves the current world state templates from scene configuration.
|
||||
@@ -415,9 +446,8 @@ class WorldStateManager:
|
||||
templates = self.scene.config["game"]["world_state"]["templates"]
|
||||
world_state_templates = WorldStateTemplates(**templates)
|
||||
return world_state_templates
|
||||
|
||||
|
||||
async def save_template(self, template:StateReinforcementTemplate):
|
||||
|
||||
async def save_template(self, template: StateReinforcementTemplate):
|
||||
"""
|
||||
Saves a state reinforcement template to the scene configuration.
|
||||
|
||||
@@ -428,17 +458,19 @@ class WorldStateManager:
|
||||
If the template is set to auto-create, it will be applied immediately.
|
||||
"""
|
||||
config = self.scene.config
|
||||
|
||||
|
||||
template_type = template.type
|
||||
|
||||
config["game"]["world_state"]["templates"][template_type][template.name] = template.model_dump()
|
||||
|
||||
|
||||
config["game"]["world_state"]["templates"][template_type][
|
||||
template.name
|
||||
] = template.model_dump()
|
||||
|
||||
save_config(self.scene.config)
|
||||
|
||||
|
||||
if template.auto_create:
|
||||
await self.auto_apply_template(template)
|
||||
|
||||
async def remove_template(self, template_type:str, template_name:str):
|
||||
|
||||
async def remove_template(self, template_type: str, template_name: str):
|
||||
"""
|
||||
Removes a specific state reinforcement template from scene configuration.
|
||||
|
||||
@@ -450,14 +482,18 @@ class WorldStateManager:
|
||||
If the specified template is not found, logs a warning.
|
||||
"""
|
||||
config = self.scene.config
|
||||
|
||||
|
||||
try:
|
||||
del config["game"]["world_state"]["templates"][template_type][template_name]
|
||||
save_config(self.scene.config)
|
||||
except KeyError:
|
||||
log.warning("world state template not found", template_type=template_type, template_name=template_name)
|
||||
log.warning(
|
||||
"world state template not found",
|
||||
template_type=template_type,
|
||||
template_name=template_name,
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
async def apply_all_auto_create_templates(self):
|
||||
"""
|
||||
Applies all auto-create state reinforcement templates.
|
||||
@@ -467,18 +503,18 @@ class WorldStateManager:
|
||||
"""
|
||||
templates = self.scene.config["game"]["world_state"]["templates"]
|
||||
world_state_templates = WorldStateTemplates(**templates)
|
||||
|
||||
|
||||
candidates = []
|
||||
|
||||
|
||||
for template in world_state_templates.state_reinforcement.values():
|
||||
if template.auto_create:
|
||||
candidates.append(template)
|
||||
|
||||
|
||||
for template in candidates:
|
||||
log.info("applying template", template=template)
|
||||
await self.auto_apply_template(template)
|
||||
|
||||
async def auto_apply_template(self, template:StateReinforcementTemplate):
|
||||
|
||||
async def auto_apply_template(self, template: StateReinforcementTemplate):
|
||||
"""
|
||||
Automatically applies a state reinforcement template based on its type.
|
||||
|
||||
@@ -490,8 +526,10 @@ class WorldStateManager:
|
||||
"""
|
||||
fn = getattr(self, f"auto_apply_template_{template.type}")
|
||||
await fn(template)
|
||||
|
||||
async def auto_apply_template_state_reinforcement(self, template:StateReinforcementTemplate):
|
||||
|
||||
async def auto_apply_template_state_reinforcement(
|
||||
self, template: StateReinforcementTemplate
|
||||
):
|
||||
"""
|
||||
Applies a state reinforcement template to characters based on the template's state type.
|
||||
|
||||
@@ -501,21 +539,27 @@ class WorldStateManager:
|
||||
Note:
|
||||
The characters to apply the template to are determined by the state_type in the template.
|
||||
"""
|
||||
|
||||
|
||||
characters = []
|
||||
|
||||
|
||||
if template.state_type == "npc":
|
||||
characters = [character.name for character in self.scene.get_npc_characters()]
|
||||
characters = [
|
||||
character.name for character in self.scene.get_npc_characters()
|
||||
]
|
||||
elif template.state_type == "character":
|
||||
characters = [character.name for character in self.scene.get_characters()]
|
||||
elif template.state_type == "player":
|
||||
characters = [self.scene.get_player_character().name]
|
||||
|
||||
|
||||
for character_name in characters:
|
||||
await self.apply_template_state_reinforcement(template, character_name)
|
||||
|
||||
|
||||
async def apply_template_state_reinforcement(self, template:StateReinforcementTemplate, character_name:str=None, run_immediately:bool=False) -> Reinforcement:
|
||||
|
||||
async def apply_template_state_reinforcement(
|
||||
self,
|
||||
template: StateReinforcementTemplate,
|
||||
character_name: str = None,
|
||||
run_immediately: bool = False,
|
||||
) -> Reinforcement:
|
||||
"""
|
||||
Applies a state reinforcement template to a specific character, if provided.
|
||||
|
||||
@@ -530,22 +574,30 @@ class WorldStateManager:
|
||||
Raises:
|
||||
ValueError: If a character name is required but not provided.
|
||||
"""
|
||||
|
||||
|
||||
if not character_name and template.state_type in ["npc", "character", "player"]:
|
||||
raise ValueError("Character name required for this template type.")
|
||||
|
||||
|
||||
player_name = self.scene.get_player_character().name
|
||||
|
||||
formatted_query = template.query.format(character_name=character_name, player_name=player_name)
|
||||
formatted_instructions = template.instructions.format(character_name=character_name, player_name=player_name) if template.instructions else None
|
||||
|
||||
|
||||
formatted_query = template.query.format(
|
||||
character_name=character_name, player_name=player_name
|
||||
)
|
||||
formatted_instructions = (
|
||||
template.instructions.format(
|
||||
character_name=character_name, player_name=player_name
|
||||
)
|
||||
if template.instructions
|
||||
else None
|
||||
)
|
||||
|
||||
if character_name:
|
||||
details = await self.get_character_details(character_name)
|
||||
|
||||
|
||||
# if reinforcement already exists, skip
|
||||
if formatted_query in details.reinforcements:
|
||||
return None
|
||||
|
||||
|
||||
return await self.add_detail_reinforcement(
|
||||
character_name,
|
||||
formatted_query,
|
||||
@@ -553,4 +605,4 @@ class WorldStateManager:
|
||||
template.interval,
|
||||
insert=template.insert,
|
||||
run_immediately=run_immediately,
|
||||
)
|
||||
)
|
||||
|
||||
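The `apply_template_state_reinforcement` hunk above reformats a conditional `.format()` pattern: the query is always formatted, while the instructions are formatted only when present, otherwise left as `None`. A minimal standalone sketch of that pattern — `Template` here is a hypothetical stand-in for illustration, not Talemate's actual `StateReinforcementTemplate` class:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Template:
    # hypothetical stand-in for StateReinforcementTemplate
    query: str
    instructions: Optional[str] = None


def format_template(template: Template, character_name: str, player_name: str):
    # always format the query; extra keyword arguments to str.format are ignored
    formatted_query = template.query.format(
        character_name=character_name, player_name=player_name
    )
    # format the instructions only when they are set, otherwise keep None
    formatted_instructions = (
        template.instructions.format(
            character_name=character_name, player_name=player_name
        )
        if template.instructions
        else None
    )
    return formatted_query, formatted_instructions


q, i = format_template(Template(query="What is {character_name} doing?"), "Aria", "Sam")
print(q)  # -> What is Aria doing?
print(i)  # -> None
```

The conditional expression keeps the `None` default in one assignment instead of an `if`/`else` statement, which is why black expands it across several lines in the new version.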
Some files were not shown because too many files have changed in this diff.