* pre-commit

* linting

* add linting to workflow

* ruff.toml added
veguAI committed 2025-06-29 19:51:08 +03:00 (via GitHub)
parent 9eb4c48d79
commit fb2fa31f13
206 changed files with 11468 additions and 9720 deletions

View File

@@ -42,6 +42,11 @@ jobs:
source .venv/bin/activate
uv pip install -e ".[dev]"
- name: Run linting
run: |
source .venv/bin/activate
uv run pre-commit run --all-files
- name: Setup configuration file
run: |
cp config.example.yaml config.yaml
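The new lint step mirrors the pre-commit configuration added below; the same check can be reproduced locally with `uv run pre-commit run --all-files` from an activated `.venv`.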

.pre-commit-config.yaml (new file, 16 lines)
View File

@@ -0,0 +1,16 @@
fail_fast: false
exclude: |
(?x)^(
tests/data/.*
|install-utils/.*
)$
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.12.1
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
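A usage sketch, assuming the dev extras below are installed: `pre-commit install` wires these hooks into git so ruff lints and formats staged files on every commit, while `pre-commit run --all-files` (the command the CI step above invokes) checks the whole repository. Pinning `rev: v0.12.1` keeps local runs and CI on the same ruff release.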

View File

@@ -1,60 +1,63 @@
import os
import re
import subprocess
from pathlib import Path
import argparse
def find_image_references(md_file):
"""Find all image references in a markdown file."""
with open(md_file, 'r', encoding='utf-8') as f:
with open(md_file, "r", encoding="utf-8") as f:
content = f.read()
pattern = r'!\[.*?\]\((.*?)\)'
pattern = r"!\[.*?\]\((.*?)\)"
matches = re.findall(pattern, content)
cleaned_paths = []
for match in matches:
path = match.lstrip('/')
if 'img/' in path:
path = path[path.index('img/') + 4:]
path = match.lstrip("/")
if "img/" in path:
path = path[path.index("img/") + 4 :]
# Only keep references to versioned images
parts = os.path.normpath(path).split(os.sep)
if len(parts) >= 2 and parts[0].replace('.', '').isdigit():
if len(parts) >= 2 and parts[0].replace(".", "").isdigit():
cleaned_paths.append(path)
return cleaned_paths
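# A worked illustration (hypothetical markdown content): a reference such as
#   ![shot](/talemate/img/0.29.0/world-state.png)
# matches the regex, is trimmed past "img/" to "0.29.0/world-state.png", and is
# kept because its first path component is a version number; an unversioned
# path like "img/logo.png" is filtered out.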
def scan_markdown_files(docs_dir):
"""Recursively scan all markdown files in the docs directory."""
md_files = []
for root, _, files in os.walk(docs_dir):
for file in files:
if file.endswith('.md'):
if file.endswith(".md"):
md_files.append(os.path.join(root, file))
return md_files
def find_all_images(img_dir):
"""Find all image files in version subdirectories."""
image_files = []
for root, _, files in os.walk(img_dir):
# Get the relative path from img_dir to current directory
rel_dir = os.path.relpath(root, img_dir)
# Skip if we're in the root img directory
if rel_dir == '.':
if rel_dir == ".":
continue
# Check if the immediate parent directory is a version number
parent_dir = rel_dir.split(os.sep)[0]
if not parent_dir.replace('.', '').isdigit():
if not parent_dir.replace(".", "").isdigit():
continue
for file in files:
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.svg')):
if file.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".svg")):
rel_path = os.path.relpath(os.path.join(root, file), img_dir)
image_files.append(rel_path)
return image_files
def grep_check_image(docs_dir, image_path):
"""
Check if versioned image is referenced anywhere using grep.
@@ -65,33 +68,46 @@ def grep_check_image(docs_dir, image_path):
parts = os.path.normpath(image_path).split(os.sep)
version = parts[0] # e.g., "0.29.0"
filename = parts[-1] # e.g., "world-state-suggestions-2.png"
# For versioned images, require both version and filename to match
version_pattern = f"{version}.*{filename}"
try:
result = subprocess.run(
['grep', '-r', '-l', version_pattern, docs_dir],
["grep", "-r", "-l", version_pattern, docs_dir],
capture_output=True,
text=True
text=True,
)
if result.stdout.strip():
print(f"Found reference to {image_path} with version pattern: {version_pattern}")
print(
f"Found reference to {image_path} with version pattern: {version_pattern}"
)
return True
except subprocess.CalledProcessError:
pass
except Exception as e:
print(f"Error during grep check for {image_path}: {e}")
return False
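# Illustration: for the versioned image "0.29.0/world-state-suggestions-2.png"
# this runs the equivalent of
#   grep -r -l "0.29.0.*world-state-suggestions-2.png" <docs_dir>
# so a reference counts as long as version and filename occur on the same line.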
def main():
parser = argparse.ArgumentParser(description='Find and optionally delete unused versioned images in MkDocs project')
parser.add_argument('--docs-dir', type=str, required=True, help='Path to the docs directory')
parser.add_argument('--img-dir', type=str, required=True, help='Path to the images directory')
parser.add_argument('--delete', action='store_true', help='Delete unused images')
parser.add_argument('--verbose', action='store_true', help='Show all found references and files')
parser.add_argument('--skip-grep', action='store_true', help='Skip the additional grep validation')
parser = argparse.ArgumentParser(
description="Find and optionally delete unused versioned images in MkDocs project"
)
parser.add_argument(
"--docs-dir", type=str, required=True, help="Path to the docs directory"
)
parser.add_argument(
"--img-dir", type=str, required=True, help="Path to the images directory"
)
parser.add_argument("--delete", action="store_true", help="Delete unused images")
parser.add_argument(
"--verbose", action="store_true", help="Show all found references and files"
)
parser.add_argument(
"--skip-grep", action="store_true", help="Skip the additional grep validation"
)
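# Example invocation (script path hypothetical):
#   python find_unused_images.py --docs-dir docs --img-dir docs/img --verbose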
args = parser.parse_args()
# Convert paths to absolute paths
@@ -118,7 +134,7 @@ def main():
print("\nAll versioned image references found in markdown:")
for img in sorted(used_images):
print(f"- {img}")
print("\nAll versioned images in directory:")
for img in sorted(all_images):
print(f"- {img}")
@@ -133,9 +149,11 @@ def main():
for img in unused_images:
if not grep_check_image(docs_dir, img):
actually_unused.add(img)
if len(actually_unused) != len(unused_images):
print(f"\nGrep validation found {len(unused_images) - len(actually_unused)} additional image references!")
print(
f"\nGrep validation found {len(unused_images) - len(actually_unused)} additional image references!"
)
unused_images = actually_unused
# Report findings
@@ -148,7 +166,7 @@ def main():
print("\nUnused versioned images:")
for img in sorted(unused_images):
print(f"- {img}")
if args.delete:
print("\nDeleting unused versioned images...")
for img in unused_images:
@@ -162,5 +180,6 @@ def main():
else:
print("\nNo unused versioned images found!")
if __name__ == "__main__":
main()

View File

@@ -4,12 +4,12 @@ from talemate.events import GameLoopEvent
import talemate.emit.async_signals
from talemate.emit import emit
@register()
class TestAgent(Agent):
agent_type = "test"
verbose_name = "Test"
def __init__(self, client):
self.client = client
self.is_enabled = True
@@ -20,7 +20,7 @@ class TestAgent(Agent):
description="Test",
),
}
@property
def enabled(self):
return self.is_enabled
@@ -36,7 +36,7 @@ class TestAgent(Agent):
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
async def on_game_loop(self, emission: GameLoopEvent):
"""
Called on the beginning of every game loop
@@ -45,4 +45,8 @@ class TestAgent(Agent):
if not self.enabled:
return
emit("status", status="info", message="Annoying you with a test message every game loop.")
emit(
"status",
status="info",
message="Annoying you with a test message every game loop.",
)

View File

@@ -19,14 +19,17 @@ from talemate.config import Client as BaseClientConfig
log = structlog.get_logger("talemate.client.runpod_vllm")
class Defaults(pydantic.BaseModel):
max_token_length: int = 4096
model: str = ""
runpod_id: str = ""
class ClientConfig(BaseClientConfig):
runpod_id: str = ""
@register()
class RunPodVLLMClient(ClientBase):
client_type = "runpod_vllm"
@@ -49,7 +52,6 @@ class RunPodVLLMClient(ClientBase):
)
}
def __init__(self, model=None, runpod_id=None, **kwargs):
self.model_name = model
self.runpod_id = runpod_id
@@ -59,12 +61,10 @@ class RunPodVLLMClient(ClientBase):
def experimental(self):
return False
def set_client(self, **kwargs):
log.debug("set_client", kwargs=kwargs, runpod_id=self.runpod_id)
self.runpod_id = kwargs.get("runpod_id", self.runpod_id)
def tune_prompt_parameters(self, parameters: dict, kind: str):
super().tune_prompt_parameters(parameters, kind)
@@ -88,32 +88,37 @@ class RunPodVLLMClient(ClientBase):
self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)
try:
async with aiohttp.ClientSession() as session:
endpoint = runpod.AsyncioEndpoint(self.runpod_id, session)
run_request = await endpoint.run({
"input": {
"prompt": prompt,
run_request = await endpoint.run(
{
"input": {
"prompt": prompt,
}
# "parameters": parameters
}
#"parameters": parameters
})
while (await run_request.status()) not in ["COMPLETED", "FAILED", "CANCELLED"]:
)
while (await run_request.status()) not in [
"COMPLETED",
"FAILED",
"CANCELLED",
]:
status = await run_request.status()
log.debug("generate", status=status)
await asyncio.sleep(0.1)
status = await run_request.status()
log.debug("generate", status=status)
response = await run_request.output()
log.debug("generate", response=response)
return response["choices"][0]["tokens"][0]
except Exception as e:
self.log.error("generate error", e=e)
emit(

View File

@@ -9,6 +9,7 @@ class Defaults(pydantic.BaseModel):
api_url: str = "http://localhost:1234"
max_token_length: int = 4096
@register()
class TestClient(ClientBase):
client_type = "test"
@@ -22,14 +23,13 @@ class TestClient(ClientBase):
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
def tune_prompt_parameters(self, parameters: dict, kind: str):
"""
Talemate adds a bunch of parameters to the prompt, but not all of them are valid for all clients.
This method is called before the prompt is sent to the client, and it allows the client to remove
any parameters that it doesn't support.
"""
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
@@ -41,11 +41,10 @@ class TestClient(ClientBase):
del parameters[key]
async def get_model_name(self):
"""
This should return the name of the model that is being used.
"""
return "Mock test model"
async def generate(self, prompt: str, parameters: dict, kind: str):

View File

@@ -1,6 +1,6 @@
import os
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from starlette.exceptions import HTTPException
@@ -14,6 +14,7 @@ app = FastAPI()
# Serve static files, but exclude the root path
app.mount("/", StaticFiles(directory=dist_dir, html=True), name="static")
@app.get("/", response_class=HTMLResponse)
async def serve_root():
index_path = os.path.join(dist_dir, "index.html")
@@ -24,5 +25,6 @@ async def serve_root():
else:
raise HTTPException(status_code=404, detail="index.html not found")
# This is the ASGI application
application = app
application = app
View File

@@ -65,6 +65,7 @@ dev = [
"mkdocs-material>=9.5.27",
"mkdocs-awesome-pages-plugin>=2.9.2",
"mkdocs-glightbox>=0.4.0",
"pre-commit>=2.13",
]
[project.scripts]
@@ -103,4 +104,4 @@ include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88

ruff.toml (new file, 5 lines)
View File

@@ -0,0 +1,5 @@
[lint]
# Disable automatic fix for unused imports (`F401`). We check these manually.
unfixable = ["F401"]
# Ignore E402 (module level import not at top of file)
extend-ignore = ["E402"]

View File

@@ -1,111 +1,112 @@
def game(TM):
MSG_PROCESSED_INSTRUCTIONS = "Simulation suite processed instructions"
MSG_HELP = "Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating \"Computer,\" followed by an instruction. For example ... \"Computer, i want to experience being on a derelict spaceship.\""
MSG_HELP = 'Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating "Computer," followed by an instruction. For example ... "Computer, i want to experience being on a derelict spaceship."'
PROMPT_NARRATE_ROUND = "Narrate the simulation and reveal some new details to the player in one paragraph. YOU MUST NOT ADDRESS THE COMPUTER OR THE SIMULATION."
PROMPT_STARTUP = "Narrate the computer asking the user to state the nature of their desired simulation in a synthetic and soft sounding voice. Remind the user that this is an old version of the simulation suite and they should check out version two for a more advanced experience."
CTX_PIN_UNAWARE = "Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION."
CTX_PIN_UNAWARE = (
"Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION."
)
AUTO_NARRATE_INTERVAL = 10
def parse_sim_call_arguments(call:str) -> str:
def parse_sim_call_arguments(call: str) -> str:
"""
Returns the value between the parentheses of a simulation call
Example:
call = 'change_environment("a house")'
parse_sim_call_arguments(call) -> "a house"
"""
try:
return call.split("(", 1)[1].split(")")[0]
except Exception:
return ""
class SimulationSuite:
def __init__(self):
"""
This is initialized at the beginning of each round of the simulation suite
"""
# do we update the world state at the end of the round
self.update_world_state = False
self.simulation_reset = False
# will keep track of any npcs added during the current round
self.added_npcs = []
TM.log.debug("SIMULATION SUITE INIT!", scene=TM.scene)
self.player_message = TM.scene.last_player_message
self.last_processed_call = TM.game_state.get_var("instr.lastprocessed_call", -1)
self.last_processed_call = TM.game_state.get_var(
"instr.lastprocessed_call", -1
)
# determine whether the player / user input is an instruction
# to the simulation computer
#
# we do this by checking if the message starts with "Computer,"
self.player_message_is_instruction = (
self.player_message and
self.player_message.raw.lower().startswith("computer") and
not self.player_message.hidden and
not self.last_processed_call > self.player_message.id
self.player_message
and self.player_message.raw.lower().startswith("computer")
and not self.player_message.hidden
and not self.last_processed_call > self.player_message.id
)
def run(self):
"""
Main entry point for the simulation suite
"""
if not TM.game_state.has_var("instr.simulation_stopped"):
# simulation is still running
self.simulation()
self.finalize_round()
def simulation(self):
"""
Simulation suite logic
"""
if not TM.game_state.has_var("instr.simulation_started"):
self.startup()
else:
self.simulation_calls()
if self.update_world_state:
self.run_update_world_state(force=True)
def startup(self):
"""
Scene startup logic
"""
# we are at the beginning of the simulation
TM.signals.status("busy", "Simulation suite powering up.", as_scene_message=True)
TM.signals.status(
"busy", "Simulation suite powering up.", as_scene_message=True
)
TM.game_state.set_var("instr.simulation_started", "yes", commit=False)
# add narration for the introduction
TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=PROMPT_STARTUP,
emit_message=False
emit_message=False,
)
# add narration for the instructions on how to interact with the simulation
# this is a passthrough since we don't want the AI to paraphrase this
TM.agents.narrator.action_to_narration(
action_name="passthrough",
narration=MSG_HELP
action_name="passthrough", narration=MSG_HELP
)
# create a world state entry letting the AI know that characters
# interacting in the simulation are not aware of the computer or the simulation
TM.agents.world_state.save_world_entry(
@@ -113,37 +114,43 @@ def game(TM):
text=CTX_PIN_UNAWARE,
meta={},
# this should always be pinned
pin=True
pin=True,
)
# set flag that we have started the simulation
TM.game_state.set_var("instr.simulation_started", "yes", commit=False)
# signal to the UX that the simulation suite is ready
TM.signals.status("success", "Simulation suite ready", as_scene_message=True)
TM.signals.status(
"success", "Simulation suite ready", as_scene_message=True
)
# we want to update the world state at the end of the round
self.update_world_state = True
def simulation_calls(self):
"""
Calls the simulation suite main prompt to determine the appropriate
simulation calls
"""
# we only process instructions that are not hidden and are not the last processed call
if not self.player_message_is_instruction or self.player_message.id == self.last_processed_call:
if (
not self.player_message_is_instruction
or self.player_message.id == self.last_processed_call
):
return
# First instruction?
if not TM.game_state.has_var("instr.has_issued_instructions"):
# determine the context of the simulation
context_context = TM.agents.creator.determine_content_context_for_description(
description=self.player_message.raw,
context_context = (
TM.agents.creator.determine_content_context_for_description(
description=self.player_message.raw,
)
)
TM.scene.set_content_context(context_context)
# Render the `computer` template and send it to the LLM for processing
# The LLM will return a list of calls that the simulation suite will process
# The calls are pseudo code that the simulation suite will interpret and execute
@@ -153,90 +160,98 @@ def game(TM):
player_instruction=self.player_message.raw,
scene=TM.scene,
)
self.calls = calls = calls.split("\n")
calls = self.prepare_calls(calls)
TM.log.debug("SIMULATION SUITE CALLS", callse=calls)
# calls that are processed
processed = []
for call in calls:
processed_call = self.process_call(call)
if processed_call:
processed.append(processed_call)
if processed:
TM.log.debug("SIMULATION SUITE CALLS", calls=processed)
TM.game_state.set_var("instr.has_issued_instructions", "yes", commit=False)
TM.signals.status("busy", "Simulation suite altering environment.", as_scene_message=True)
TM.game_state.set_var(
"instr.has_issued_instructions", "yes", commit=False
)
TM.signals.status(
"busy", "Simulation suite altering environment.", as_scene_message=True
)
compiled = "\n".join(processed)
if not self.simulation_reset and compiled:
# send the compiled calls to the narrator to generate a narrative based
# on them
narration = TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=f"The computer calls the following functions:\n\n```\n{compiled}\n```\n\nand the simulation adjusts the environment according to the user's wishes.\n\nWrite the narrative that describes the changes to the player in the context of the simulation starting up. YOU MUST NOT REFERENCE THE COMPUTER OR THE SIMULATION.",
emit_message=True
emit_message=True,
)
# on the first narration we update the scene description and remove any mention of the computer
# or the simulation from the previous narration
is_initial_narration = TM.game_state.get_var("instr.intro_narration", False)
is_initial_narration = TM.game_state.get_var(
"instr.intro_narration", False
)
if not is_initial_narration:
TM.scene.set_description(narration.raw)
TM.scene.set_intro(narration.raw)
TM.log.debug("SIMULATION SUITE: initial narration", intro=narration.raw)
TM.log.debug(
"SIMULATION SUITE: initial narration", intro=narration.raw
)
TM.scene.pop_history(typ="narrator", all=True, reverse=True)
TM.scene.pop_history(typ="director", all=True, reverse=True)
TM.game_state.set_var("instr.intro_narration", True, commit=False)
self.update_world_state = True
self.set_simulation_title(compiled)
def set_simulation_title(self, compiled_calls):
"""
Generates a fitting title for the simulation based on the user's instructions
"""
TM.log.debug("SIMULATION SUITE: set simulation title", name=TM.scene.title, compiled_calls=compiled_calls)
TM.log.debug(
"SIMULATION SUITE: set simulation title",
name=TM.scene.title,
compiled_calls=compiled_calls,
)
if not compiled_calls:
return
if TM.scene.title != "Simulation Suite":
# name already changed, no need to do it again
return
title = TM.agents.creator.contextual_generate_from_args(
"scene:simulation title",
"Create a fitting title for the simulated scenario that the user has requested. You response MUST be a short but exciting, descriptive title.",
length=75
length=75,
)
title = title.strip('"').strip()
TM.scene.set_title(title)
def prepare_calls(self, calls):
"""
Loops through calls and, if a `set_player_name` call and a `set_player_persona` call are both
found, ensures that the `set_player_name` call is processed first by moving it in front of the
`set_player_persona` call.
"""
set_player_name_call_exists = -1
set_player_persona_call_exists = -1
i = 0
for call in calls:
if "set_player_name" in call:
@@ -244,351 +259,445 @@ def game(TM):
elif "set_player_persona" in call:
set_player_persona_call_exists = i
i = i + 1
if set_player_name_call_exists > -1 and set_player_persona_call_exists > -1:
if set_player_name_call_exists > set_player_persona_call_exists:
calls.insert(set_player_persona_call_exists, calls.pop(set_player_name_call_exists))
TM.log.debug("SIMULATION SUITE: prepare calls - moved set_player_persona call", calls=calls)
calls.insert(
set_player_persona_call_exists,
calls.pop(set_player_name_call_exists),
)
TM.log.debug(
"SIMULATION SUITE: prepare calls - moved set_player_persona call",
calls=calls,
)
return calls
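# Illustration (values hypothetical): given
#   calls = ['set_player_persona("tall elf")', 'set_player_name("Ari")']
# the name call is popped from index 1 and re-inserted at index 0, yielding
# ['set_player_name("Ari")', 'set_player_persona("tall elf")'] so the rename runs first.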
def process_call(self, call:str) -> str:
def process_call(self, call: str) -> str:
"""
Processes a simulation call
Simulation calls are pseudo functions that are called by the simulation suite
We grab the function name by splitting against ( and taking the first element
if the SimulationSuite has a method with the name call_{function_name} then we call it
if a function name could be found but we do not have a method to call we don't do anything
but we still return it as processed as the AI can still interpret it as something later on
"""
if "(" not in call:
return None
function_name = call.split("(")[0]
if hasattr(self, f"call_{function_name}"):
TM.log.debug("SIMULATION SUITE CALL", call=call, function_name=function_name)
TM.log.debug(
"SIMULATION SUITE CALL", call=call, function_name=function_name
)
inject = f"The computer executes the function `{call}`"
return getattr(self, f"call_{function_name}")(call, inject)
return call
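# Example: process_call('change_environment("a house")') resolves to the
# call_change_environment method below; a call with no matching call_* method
# is returned unchanged so the AI can still interpret it later.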
def call_set_simulation_goal(self, call:str, inject:str) -> str:
def call_set_simulation_goal(self, call: str, inject: str) -> str:
"""
Sets the simulation goal as a permanent pin
"""
TM.signals.status("busy", "Simulation suite setting goal.", as_scene_message=True)
TM.agents.world_state.save_world_entry(
entry_id="sim.goal",
text=self.player_message.raw,
meta={},
pin=True
TM.signals.status(
"busy", "Simulation suite setting goal.", as_scene_message=True
)
TM.agents.world_state.save_world_entry(
entry_id="sim.goal", text=self.player_message.raw, meta={}, pin=True
)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description="The computer sets the goal for the simulation.",
)
return call
def call_change_environment(self, call:str, inject:str) -> str:
def call_change_environment(self, call: str, inject: str) -> str:
"""
Simulation changes the environment, this is entirely interpreted by the AI
and we don't need to do any logic on our end, so we just return the call
"""
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description="The computer changes the environment of the simulation."
action_description="The computer changes the environment of the simulation.",
)
return call
def call_answer_question(self, call:str, inject:str) -> str:
def call_answer_question(self, call: str, inject: str) -> str:
"""
The player asked the simulation a query; we need to process this and have
the AI produce an answer
"""
TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=f"The computer calls the following function:\n\n{call}\n\nand answers the player's question.",
emit_message=True
emit_message=True,
)
def call_set_player_persona(self, call:str, inject:str) -> str:
def call_set_player_persona(self, call: str, inject: str) -> str:
"""
The simulation suite is altering the player persona
"""
player_character = TM.scene.get_player_character()
TM.signals.status("busy", "Simulation suite altering user persona.", as_scene_message=True)
character_attributes = TM.agents.world_state.extract_character_sheet(
name=player_character.name, text=inject, alteration_instructions=self.player_message.raw
TM.signals.status(
"busy", "Simulation suite altering user persona.", as_scene_message=True
)
TM.scene.set_character_attributes(player_character.name, character_attributes)
character_description = TM.agents.creator.determine_character_description(player_character.name)
TM.scene.set_character_description(player_character.name, character_description)
TM.log.debug("SIMULATION SUITE: transform player", attributes=character_attributes, description=character_description)
character_attributes = TM.agents.world_state.extract_character_sheet(
name=player_character.name,
text=inject,
alteration_instructions=self.player_message.raw,
)
TM.scene.set_character_attributes(
player_character.name, character_attributes
)
character_description = TM.agents.creator.determine_character_description(
player_character.name
)
TM.scene.set_character_description(
player_character.name, character_description
)
TM.log.debug(
"SIMULATION SUITE: transform player",
attributes=character_attributes,
description=character_description,
)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description="The computer transforms the player persona."
action_description="The computer transforms the player persona.",
)
return call
def call_set_player_name(self, call:str, inject:str) -> str:
def call_set_player_name(self, call: str, inject: str) -> str:
"""
The simulation suite is altering the player name
"""
player_character = TM.scene.get_player_character()
TM.signals.status("busy", "Simulation suite adjusting user identity.", as_scene_message=True)
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits.")
TM.signals.status(
"busy",
"Simulation suite adjusting user identity.",
as_scene_message=True,
)
character_name = TM.agents.creator.determine_character_name(
instructions=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits."
)
TM.log.debug("SIMULATION SUITE: player name", character_name=character_name)
if character_name != player_character.name:
TM.scene.set_character_name(player_character.name, character_name)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description=f"The computer changes the player's identity to {character_name}."
action_description=f"The computer changes the player's identity to {character_name}.",
)
return call
def call_add_ai_character(self, call:str, inject:str) -> str:
return call
def call_add_ai_character(self, call: str, inject: str) -> str:
# sometimes the AI will call this function and pass an inanimate object as the parameter
# we need to determine if this is the case and just ignore it
is_inanimate = TM.agents.world_state.answer_query_true_or_false(f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)", call)
is_inanimate = TM.agents.world_state.answer_query_true_or_false(
f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)",
call,
)
if is_inanimate:
TM.log.debug("SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped", call=call)
TM.log.debug(
"SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped",
call=call,
)
return
# sometimes the AI will ask if the function adds a group of characters, we need to
# determine if this is the case
adds_group = TM.agents.world_state.answer_query_true_or_false(f"does the function `{call}` add MULTIPLE ai characters?", call)
adds_group = TM.agents.world_state.answer_query_true_or_false(
f"does the function `{call}` add MULTIPLE ai characters?", call
)
TM.log.debug("SIMULATION SUITE: add npc", adds_group=adds_group)
TM.signals.status("busy", "Simulation suite adding character.", as_scene_message=True)
TM.signals.status(
"busy", "Simulation suite adding character.", as_scene_message=True
)
if not adds_group:
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character to be added to the scene? If no name can be extracted from the text, extract a short descriptive name instead. Respond only with the name.")
character_name = TM.agents.creator.determine_character_name(
instructions=f"{inject} - what is the name of the character to be added to the scene? If no name can be extracted from the text, extract a short descriptive name instead. Respond only with the name."
)
else:
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can be extracted from the text, extract a short descriptive name instead. Respond only with the name.", group=True)
character_name = TM.agents.creator.determine_character_name(
instructions=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can be extracted from the text, extract a short descriptive name instead. Respond only with the name.",
group=True,
)
# sometimes add_ai_character and change_ai_character are called in the same instruction targeting
# the same character, if this happens we need to combine into a single add_ai_character call
has_change_ai_character_call = TM.agents.world_state.answer_query_true_or_false(f"Are there any calls to `change_ai_character` in the instruction for {character_name}?", "\n".join(self.calls))
has_change_ai_character_call = TM.agents.world_state.answer_query_true_or_false(
f"Are there any calls to `change_ai_character` in the instruction for {character_name}?",
"\n".join(self.calls),
)
if has_change_ai_character_call:
combined_arg = TM.prompt.request(
"combine-add-and-alter-ai-character",
dedupe_enabled=False,
calls="\n".join(self.calls),
character_name=character_name,
scene=TM.scene,
).replace("COMBINED ARGUMENT:", "").strip()
combined_arg = (
TM.prompt.request(
"combine-add-and-alter-ai-character",
dedupe_enabled=False,
calls="\n".join(self.calls),
character_name=character_name,
scene=TM.scene,
)
.replace("COMBINED ARGUMENT:", "")
.strip()
)
call = f"add_ai_character({combined_arg})"
inject = f"The computer executes the function `{call}`"
TM.signals.status("busy", f"Simulation suite adding character: {character_name}", as_scene_message=True)
TM.signals.status(
"busy",
f"Simulation suite adding character: {character_name}",
as_scene_message=True,
)
TM.log.debug("SIMULATION SUITE: add npc", name=character_name)
npc = TM.agents.director.persist_character(character_name=character_name, content=self.player_message.raw+f"\n\n{inject}", determine_name=False)
npc = TM.agents.director.persist_character(
character_name=character_name,
content=self.player_message.raw + f"\n\n{inject}",
determine_name=False,
)
self.added_npcs.append(npc.name)
TM.agents.world_state.add_detail_reinforcement(
character_name=npc.name,
detail="Goal",
instructions=f"Generate a goal for {npc.name}, based on the user's chosen simulation",
interval=25,
run_immediately=True
run_immediately=True,
)
TM.log.debug("SIMULATION SUITE: added npc", npc=npc)
TM.agents.visual.generate_character_portrait(character_name=npc.name)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description=f"The computer adds {npc.name} to the simulation."
action_description=f"The computer adds {npc.name} to the simulation.",
)
return call
####
def call_remove_ai_character(self, call:str, inject:str) -> str:
TM.signals.status("busy", "Simulation suite removing character.", as_scene_message=True)
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character being removed?", allowed_names=TM.scene.npc_character_names)
def call_remove_ai_character(self, call: str, inject: str) -> str:
TM.signals.status(
"busy", "Simulation suite removing character.", as_scene_message=True
)
character_name = TM.agents.creator.determine_character_name(
instructions=f"{inject} - what is the name of the character being removed?",
allowed_names=TM.scene.npc_character_names,
)
npc = TM.scene.get_character(character_name)
if npc:
TM.log.debug("SIMULATION SUITE: remove npc", npc=npc.name)
TM.agents.world_state.deactivate_character(action_name="deactivate_character", character_name=npc.name)
TM.agents.world_state.deactivate_character(
action_name="deactivate_character", character_name=npc.name
)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description=f"The computer removes {npc.name} from the simulation."
action_description=f"The computer removes {npc.name} from the simulation.",
)
return call
def call_change_ai_character(self, call:str, inject:str) -> str:
TM.signals.status("busy", "Simulation suite altering character.", as_scene_message=True)
character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character receiving the changes (before the change)?", allowed_names=TM.scene.npc_character_names)
def call_change_ai_character(self, call: str, inject: str) -> str:
TM.signals.status(
"busy", "Simulation suite altering character.", as_scene_message=True
)
character_name = TM.agents.creator.determine_character_name(
instructions=f"{inject} - what is the name of the character receiving the changes (before the change)?",
allowed_names=TM.scene.npc_character_names,
)
if character_name in self.added_npcs:
# we don't want to change the character if it was just added
return
character_name_after = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character receiving the changes (after the changes)?")
character_name_after = TM.agents.creator.determine_character_name(
instructions=f"{inject} - what is the name of the character receiving the changes (after the changes)?"
)
npc = TM.scene.get_character(character_name)
if npc:
TM.signals.status("busy", f"Changing {character_name} -> {character_name_after}", as_scene_message=True)
TM.signals.status(
"busy",
f"Changing {character_name} -> {character_name_after}",
as_scene_message=True,
)
TM.log.debug("SIMULATION SUITE: transform npc", npc=npc)
character_attributes = TM.agents.world_state.extract_character_sheet(
name=npc.name,
text=inject,
alteration_instructions=self.player_message.raw
alteration_instructions=self.player_message.raw,
)
TM.scene.set_character_attributes(npc.name, character_attributes)
character_description = TM.agents.creator.determine_character_description(npc.name)
character_description = (
TM.agents.creator.determine_character_description(npc.name)
)
TM.scene.set_character_description(npc.name, character_description)
TM.log.debug("SIMULATION SUITE: transform npc", attributes=character_attributes, description=character_description)
TM.log.debug(
"SIMULATION SUITE: transform npc",
attributes=character_attributes,
description=character_description,
)
if character_name_after != character_name:
TM.scene.set_character_name(npc.name, character_name_after)
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description=f"The computer transforms {npc.name}."
action_description=f"The computer transforms {npc.name}.",
)
return call
def call_end_simulation(self, call:str, inject:str) -> str:
def call_end_simulation(self, call: str, inject: str) -> str:
player_character = TM.scene.get_player_character()
explicit_command = TM.agents.world_state.answer_query_true_or_false("has the player explicitly asked to end the simulation?", self.player_message.raw)
explicit_command = TM.agents.world_state.answer_query_true_or_false(
"has the player explicitly asked to end the simulation?",
self.player_message.raw,
)
if explicit_command:
TM.signals.status("busy", "Simulation suite ending current simulation.", as_scene_message=True)
TM.signals.status(
"busy",
"Simulation suite ending current simulation.",
as_scene_message=True,
)
TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=f"Narrate the computer ending the simulation, dissolving the environment and all artificial characters, erasing all memory of it and finally returning the player to the inactive simulation suite. List of artificial characters: {', '.join(TM.scene.npc_character_names)}. The player is also transformed back to their normal, non-descript persona as the form of {player_character.name} ceases to exist.",
emit_message=True
emit_message=True,
)
TM.scene.restore()
self.simulation_reset = True
TM.game_state.unset_var("instr.has_issued_instructions")
TM.game_state.unset_var("instr.lastprocessed_call")
TM.game_state.unset_var("instr.simulation_started")
TM.agents.director.log_action(
action=parse_sim_call_arguments(call),
action_description="The computer ends the simulation."
action_description="The computer ends the simulation.",
)
def finalize_round(self):
# track rounds
rounds = TM.game_state.get_var("instr.rounds", 0)
# increase rounds
TM.game_state.set_var("instr.rounds", rounds + 1, commit=False)
has_issued_instructions = TM.game_state.has_var("instr.has_issued_instructions")
has_issued_instructions = TM.game_state.has_var(
"instr.has_issued_instructions"
)
if self.update_world_state:
self.run_update_world_state()
if self.player_message_is_instruction:
TM.scene.hide_message(self.player_message.id)
TM.game_state.set_var("instr.lastprocessed_call", self.player_message.id, commit=False)
TM.signals.status("success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True)
TM.game_state.set_var(
"instr.lastprocessed_call", self.player_message.id, commit=False
)
TM.signals.status(
"success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True
)
elif self.player_message and not has_issued_instructions:
# simulation started, player message is NOT an instruction, and player has not given
# any instructions
self.guide_player()
elif self.player_message and not TM.scene.npc_character_names:
# simulation started, player message is NOT an instruction, but there are no npcs to interact with
self.narrate_round()
elif rounds % AUTO_NARRATE_INTERVAL == 0 and rounds and TM.scene.npc_character_names and has_issued_instructions:
elif (
rounds % AUTO_NARRATE_INTERVAL == 0
and rounds
and TM.scene.npc_character_names
and has_issued_instructions
):
# every N rounds, narrate the round
self.narrate_round()
def guide_player(self):
TM.agents.narrator.action_to_narration(
action_name="paraphrase",
narration=MSG_HELP,
emit_message=True
action_name="paraphrase", narration=MSG_HELP, emit_message=True
)
def narrate_round(self):
TM.agents.narrator.action_to_narration(
action_name="progress_story",
narrative_direction=PROMPT_NARRATE_ROUND,
emit_message=True
emit_message=True,
)
def run_update_world_state(self, force=False):
TM.log.debug("SIMULATION SUITE: update world state", force=force)
TM.signals.status("busy", "Simulation suite updating world state.", as_scene_message=True)
TM.signals.status(
"busy", "Simulation suite updating world state.", as_scene_message=True
)
TM.agents.world_state.update_world_state(force=force)
TM.signals.status("success", "Simulation suite updated world state.", as_scene_message=True)
TM.signals.status(
"success",
"Simulation suite updated world state.",
as_scene_message=True,
)
SimulationSuite().run()
def on_generation_cancelled(TM, exc):
"""
Called when user pressed the cancel button during the simulation suite
loop.
"""
TM.signals.status("success", "Simulation suite instructions cancelled", as_scene_message=True)
TM.signals.status(
"success", "Simulation suite instructions cancelled", as_scene_message=True
)
rounds = TM.game_state.get_var("instr.rounds", 0)
TM.log.debug("SIMULATION SUITE: command cancelled", rounds=rounds)
TM.log.debug("SIMULATION SUITE: command cancelled", rounds=rounds)

View File

@@ -1,5 +1,5 @@
from .tale_mate import *
from .tale_mate import * # noqa: F401, F403
from .version import VERSION
__version__ = VERSION

View File

@@ -1,12 +1,12 @@
from .base import Agent
from .conversation import ConversationAgent
from .creator import CreatorAgent
from .director import DirectorAgent
from .editor import EditorAgent
from .memory import ChromaDBMemoryAgent, MemoryAgent
from .narrator import NarratorAgent
from .registry import AGENT_CLASSES, get_agent_class, register
from .summarize import SummarizeAgent
from .tts import TTSAgent
from .visual import VisualAgent
from .world_state import WorldStateAgent
from .base import Agent # noqa: F401
from .conversation import ConversationAgent # noqa: F401
from .creator import CreatorAgent # noqa: F401
from .director import DirectorAgent # noqa: F401
from .editor import EditorAgent # noqa: F401
from .memory import ChromaDBMemoryAgent, MemoryAgent # noqa: F401
from .narrator import NarratorAgent # noqa: F401
from .registry import AGENT_CLASSES, get_agent_class, register # noqa: F401
from .summarize import SummarizeAgent # noqa: F401
from .tts import TTSAgent # noqa: F401
from .visual import VisualAgent # noqa: F401
from .world_state import WorldStateAgent # noqa: F401

View File

@@ -6,11 +6,10 @@ from inspect import signature
import re
from abc import ABC
from functools import wraps
from typing import TYPE_CHECKING, Callable, List, Optional, Union
from typing import Callable, Union
import uuid
import pydantic
import structlog
from blinker import signal
import talemate.emit.async_signals
import talemate.instance as instance
@@ -39,6 +38,7 @@ __all__ = [
log = structlog.get_logger("talemate.agents.base")
class AgentActionConditional(pydantic.BaseModel):
attribute: str
value: int | float | str | bool | list[int | float | str | bool] | None = None
@@ -48,6 +48,7 @@ class AgentActionNote(pydantic.BaseModel):
type: str
text: str
class AgentActionConfig(pydantic.BaseModel):
type: str
label: str
@@ -65,7 +66,7 @@ class AgentActionConfig(pydantic.BaseModel):
condition: Union[AgentActionConditional, None] = None
title: Union[str, None] = None
value_migration: Union[Callable, None] = pydantic.Field(default=None, exclude=True)
note_on_value: dict[str, AgentActionNote] = pydantic.Field(default_factory=dict)
class Config:
@@ -85,37 +86,37 @@ class AgentAction(pydantic.BaseModel):
quick_toggle: bool = False
experimental: bool = False
class AgentDetail(pydantic.BaseModel):
value: Union[str, None] = None
description: Union[str, None] = None
icon: Union[str, None] = None
color: str = "grey"
class DynamicInstruction(pydantic.BaseModel):
title: str
content: str
def __str__(self) -> str:
return "\n".join(
[
f"<|SECTION:{self.title}|>",
self.content,
"<|CLOSE_SECTION|>"
]
[f"<|SECTION:{self.title}|>", self.content, "<|CLOSE_SECTION|>"]
)
def args_and_kwargs_to_dict(fn, args: list, kwargs: dict, filter:list[str] = None) -> dict:
def args_and_kwargs_to_dict(
fn, args: list, kwargs: dict, filter: list[str] = None
) -> dict:
"""
Takes a list of arguments and a dict of keyword arguments and returns
a dict mapping parameter names to their values.
Args:
fn: The function whose parameters we want to map
args: List of positional arguments
kwargs: Dictionary of keyword arguments
filter: List of parameter names to include in the result, if None all parameters are included
Returns:
Dict mapping parameter names to their values
"""
@@ -124,34 +125,36 @@ def args_and_kwargs_to_dict(fn, args: list, kwargs: dict, filter:list[str] = Non
bound_args.apply_defaults()
rv = dict(bound_args.arguments)
rv.pop("self", None)
if filter:
for key in list(rv.keys()):
if key not in filter:
rv.pop(key)
return rv
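# Sketch (names hypothetical): for def greet(self, name, excited=False) called
# with args=[agent, "Ari"] and kwargs={}, this returns
# {"name": "Ari", "excited": False} ("self" is dropped, defaults applied);
# with filter=["name"], only {"name": "Ari"} is kept.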
class store_context_state:
"""
Decorator flag to store a function's arguments in the agent's context state.
Argument names passed to the decorator select which arguments are stored.
If no names are passed, all arguments are stored.
Keyword arguments can be passed to store additional values in the context state.
"""
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def __call__(self, fn):
fn.store_context_state = self.args
fn.store_context_state_kwargs = self.kwargs
return fn
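# Usage, as applied later in this commit:
#   @set_processing
#   @store_context_state("instruction")
#   async def converse(self, actor, instruction=None, emit_signals=True): ...
# stores the "instruction" argument in the active agent context for that call.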
def set_processing(fn):
"""
decorator that emits the agent status as processing while the function
@@ -165,31 +168,38 @@ def set_processing(fn):
async def wrapper(self, *args, **kwargs):
with ClientContext():
scene = active_scene.get()
if scene:
scene.continue_actions()
if getattr(scene, "config", None):
set_client_context_attribute("app_config_system_prompts", scene.config.get("system_prompts", {}))
set_client_context_attribute(
"app_config_system_prompts", scene.config.get("system_prompts", {})
)
with ActiveAgent(self, fn, args, kwargs) as active_agent_context:
try:
await self.emit_status(processing=True)
# Now pass the complete args list
if getattr(fn, "store_context_state", None) is not None:
all_args = args_and_kwargs_to_dict(
fn, [self] + list(args), kwargs, getattr(fn, "store_context_state", [])
fn,
[self] + list(args),
kwargs,
getattr(fn, "store_context_state", []),
)
if getattr(fn, "store_context_state_kwargs", None) is not None:
all_args.update(getattr(fn, "store_context_state_kwargs", {}))
all_args.update(
getattr(fn, "store_context_state_kwargs", {})
)
all_args[f"fn_{fn.__name__}"] = True
active_agent_context.state_params = all_args
self.set_context_states(**all_args)
return await fn(self, *args, **kwargs)
finally:
try:
@@ -215,14 +225,16 @@ class Agent(ABC):
websocket_handler = None
essential = True
ready_check_error = None
@classmethod
def init_actions(cls, actions: dict[str, AgentAction] | None = None) -> dict[str, AgentAction]:
def init_actions(
cls, actions: dict[str, AgentAction] | None = None
) -> dict[str, AgentAction]:
if actions is None:
actions = {}
return actions
@property
def agent_details(self):
if hasattr(self, "client"):
@@ -230,10 +242,6 @@ class Agent(ABC):
return self.client.name
return None
@property
def verbose_name(self):
return self.agent_type.capitalize()
@property
def ready(self):
if not getattr(self.client, "enabled", True):
@@ -315,67 +323,68 @@ class Agent(ABC):
return {}
return {k: v.model_dump() for k, v in self.actions.items()}
# scene state
def context_fingerpint(self, extra: list[str] = []) -> str | None:
active_agent_context = active_agent.get()
if not active_agent_context:
return None
if self.scene.history:
fingerprint = f"{self.scene.history[-1].fingerprint}-{active_agent_context.first.fingerprint}"
else:
fingerprint = f"START-{active_agent_context.first.fingerprint}"
for extra_key in extra:
fingerprint += f"-{hash(extra_key)}"
return fingerprint
def get_scene_state(self, key:str, default=None):
def get_scene_state(self, key: str, default=None):
agent_state = self.scene.agent_state.get(self.agent_type, {})
return agent_state.get(key, default)
def set_scene_states(self, **kwargs):
agent_state = self.scene.agent_state.get(self.agent_type, {})
for key, value in kwargs.items():
agent_state[key] = value
self.scene.agent_state[self.agent_type] = agent_state
def dump_scene_state(self):
return self.scene.agent_state.get(self.agent_type, {})
# active agent context state
def get_context_state(self, key:str, default=None):
def get_context_state(self, key: str, default=None):
key = f"{self.agent_type}__{key}"
try:
return active_agent.get().state.get(key, default)
except AttributeError:
log.warning("get_context_state error", agent=self.agent_type, key=key)
return default
def set_context_states(self, **kwargs):
try:
items = {f"{self.agent_type}__{k}": v for k, v in kwargs.items()}
active_agent.get().state.update(items)
log.debug("set_context_states", agent=self.agent_type, state=active_agent.get().state)
log.debug(
"set_context_states",
agent=self.agent_type,
state=active_agent.get().state,
)
except AttributeError:
log.error("set_context_states error", agent=self.agent_type, kwargs=kwargs)
def dump_context_state(self):
try:
return active_agent.get().state
except AttributeError:
return {}
###
async def _handle_ready_check(self, fut: asyncio.Future):
callback_failure = getattr(self, "on_ready_check_failure", None)
if fut.cancelled():
@@ -425,41 +434,50 @@ class Agent(ABC):
if not action.config:
continue
for config_key, config in action.config.items():
for config_key, _config in action.config.items():
try:
config.value = (
_config.value = (
kwargs.get("actions", {})
.get(action_key, {})
.get("config", {})
.get(config_key, {})
.get("value", config.value)
.get("value", _config.value)
)
if config.value_migration and callable(config.value_migration):
config.value = config.value_migration(config.value)
if _config.value_migration and callable(_config.value_migration):
_config.value = _config.value_migration(_config.value)
except AttributeError:
pass
async def save_config(self, app_config: config.Config | None = None):
"""
Saves the agent config to the config file.
If no config object is provided, the config is loaded from the config file.
"""
if not app_config:
app_config:config.Config = config.load_config(as_model=True)
app_config: config.Config = config.load_config(as_model=True)
app_config.agents[self.agent_type] = config.Agent(
name=self.agent_type,
client=self.client.name if self.client else None,
enabled=self.enabled,
actions={action_key: config.AgentAction(
enabled=action.enabled,
config={config_key: config.AgentActionConfig(value=config_obj.value) for config_key, config_obj in action.config.items()}
) for action_key, action in self.actions.items()}
actions={
action_key: config.AgentAction(
enabled=action.enabled,
config={
config_key: config.AgentActionConfig(value=config_obj.value)
for config_key, config_obj in action.config.items()
},
)
for action_key, action in self.actions.items()
},
)
log.debug(
"saving agent config",
agent=self.agent_type,
config=app_config.agents[self.agent_type],
)
log.debug("saving agent config", agent=self.agent_type, config=app_config.agents[self.agent_type])
config.save_config(app_config)
async def on_game_loop_start(self, event: GameLoopStartEvent):
@@ -474,19 +492,19 @@ class Agent(ABC):
if not action.config:
continue
for _, config in action.config.items():
if config.scope == "scene":
for _, _config in action.config.items():
if _config.scope == "scene":
# if default_value is None, just use the `type` of the current
# value
if config.default_value is None:
default_value = type(config.value)()
if _config.default_value is None:
default_value = type(_config.value)()
else:
default_value = config.default_value
default_value = _config.default_value
log.debug(
"resetting config", config=config, default_value=default_value
"resetting config", config=_config, default_value=default_value
)
config.value = default_value
_config.value = default_value
await self.emit_status()
@@ -518,7 +536,9 @@ class Agent(ABC):
await asyncio.sleep(0.01)
async def _handle_background_processing(self, fut: asyncio.Future, error_handler = None):
async def _handle_background_processing(
self, fut: asyncio.Future, error_handler=None
):
try:
if fut.cancelled():
return
@@ -541,7 +561,7 @@ class Agent(ABC):
self.processing_bg -= 1
await self.emit_status()
async def set_background_processing(self, task: asyncio.Task, error_handler = None):
async def set_background_processing(self, task: asyncio.Task, error_handler=None):
log.info("set_background_processing", agent=self.agent_type)
if not hasattr(self, "processing_bg"):
self.processing_bg = 0
@@ -550,7 +570,9 @@ class Agent(ABC):
await self.emit_status()
task.add_done_callback(
lambda fut: asyncio.create_task(self._handle_background_processing(fut, error_handler))
lambda fut: asyncio.create_task(
self._handle_background_processing(fut, error_handler)
)
)
def connect(self, scene):
@@ -623,7 +645,6 @@ class Agent(ABC):
"""
return False
@set_processing
async def delegate(self, fn: Callable, *args, **kwargs):
"""
@@ -631,20 +652,22 @@ class Agent(ABC):
by the agent.
"""
return await fn(*args, **kwargs)
async def emit_message(self, header:str, message:str | list[dict], meta: dict = None, **data):
async def emit_message(
self, header: str, message: str | list[dict], meta: dict = None, **data
):
if not data:
data = {}
if not meta:
meta = {}
if "uuid" not in data:
data["uuid"] = str(uuid.uuid4())
if "agent" not in data:
data["agent"] = self.agent_type
data["header"] = header
emit(
"agent_message",
@@ -653,17 +676,22 @@ class Agent(ABC):
meta=meta,
websocket_passthrough=True,
)
@dataclasses.dataclass
class AgentEmission:
agent: Agent
@dataclasses.dataclass
class AgentTemplateEmission(AgentEmission):
template_vars: dict = dataclasses.field(default_factory=dict)
response: str = None
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
default_factory=list
)
@dataclasses.dataclass
class RagBuildSubInstructionEmission(AgentEmission):
sub_instruction: str | None = None

View File

@@ -1,13 +1,9 @@
import contextvars
import uuid
import hashlib
from typing import TYPE_CHECKING, Callable
from typing import Callable
import pydantic
if TYPE_CHECKING:
from talemate.tale_mate import Character
__all__ = [
"active_agent",
]
@@ -25,7 +21,6 @@ class ActiveAgentContext(pydantic.BaseModel):
state: dict = pydantic.Field(default_factory=dict)
state_params: dict = pydantic.Field(default_factory=dict)
previous: "ActiveAgentContext" = None
class Config:
arbitrary_types_allowed = True
@@ -35,29 +30,30 @@ class ActiveAgentContext(pydantic.BaseModel):
return self.previous.first if self.previous else self
@property
def action(self):
name = self.fn.__name__
if name == "delegate":
return self.fn_args[0].__name__
return name
@property
def fingerprint(self) -> int:
if hasattr(self, "_fingerprint"):
return self._fingerprint
self._fingerprint = hash(frozenset(self.state_params.items()))
return self._fingerprint
def __str__(self):
return f"{self.agent.verbose_name}.{self.action}"
class ActiveAgent:
def __init__(self, agent, fn, args=None, kwargs=None):
self.agent = ActiveAgentContext(agent=agent, fn=fn, fn_args=args or tuple(), fn_kwargs=kwargs or {})
self.agent = ActiveAgentContext(
agent=agent, fn=fn, fn_args=args or tuple(), fn_kwargs=kwargs or {}
)
def __enter__(self):
previous_agent = active_agent.get()
if previous_agent:
@@ -70,7 +66,7 @@ class ActiveAgent:
self.agent.agent_stack_uid = str(uuid.uuid4())
self.token = active_agent.set(self.agent)
return self.agent
def __exit__(self, *args, **kwargs):

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import dataclasses
import random
import re
from datetime import datetime
from typing import TYPE_CHECKING, Optional
@@ -10,14 +9,12 @@ import structlog
import talemate.client as client
import talemate.emit.async_signals
import talemate.instance as instance
import talemate.util as util
from talemate.client.context import (
client_context_attribute,
set_client_context_attribute,
set_conversation_context_attribute,
)
from talemate.events import GameLoopEvent
from talemate.exceptions import LLMAccuracyError
from talemate.prompts import Prompt
from talemate.scene_message import CharacterMessage, DirectorMessage
@@ -37,7 +34,7 @@ from talemate.agents.memory.rag import MemoryRAGMixin
from talemate.agents.context import active_agent
from .websocket_handler import ConversationWebsocketHandler
import talemate.agents.conversation.nodes
import talemate.agents.conversation.nodes # noqa: F401
if TYPE_CHECKING:
from talemate.tale_mate import Actor, Character
@@ -50,21 +47,20 @@ class ConversationAgentEmission(AgentEmission):
actor: Actor
character: Character
response: str
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
default_factory=list
)
talemate.emit.async_signals.register(
"agent.conversation.before_generate",
"agent.conversation.before_generate",
"agent.conversation.inject_instructions",
"agent.conversation.generated"
"agent.conversation.generated",
)
@register()
class ConversationAgent(
MemoryRAGMixin,
Agent
):
class ConversationAgent(MemoryRAGMixin, Agent):
"""
An agent that can be used to have a conversation with the AI
@@ -135,8 +131,6 @@ class ConversationAgent(
max=20,
step=1,
),
},
),
"auto_break_repetition": AgentAction(
@@ -159,7 +153,7 @@ class ConversationAgent(
description="Use the writing style selected in the scene settings",
value=True,
),
}
},
),
}
MemoryRAGMixin.add_actions(actions)
@@ -205,7 +199,6 @@ class ConversationAgent(
@property
def agent_details(self) -> dict:
details = {
"client": AgentDetail(
icon="mdi-network-outline",
@@ -231,22 +224,24 @@ class ConversationAgent(
@property
def generation_settings_actor_instructions_offset(self):
return self.actions["generation_override"].config["actor_instructions_offset"].value
return (
self.actions["generation_override"]
.config["actor_instructions_offset"]
.value
)
@property
def generation_settings_response_length(self):
return self.actions["generation_override"].config["length"].value
@property
def generation_settings_override_enabled(self):
return self.actions["generation_override"].enabled
@property
def content_use_writing_style(self) -> bool:
return self.actions["content"].config["use_writing_style"].value
def connect(self, scene):
super().connect(scene)
@@ -276,7 +271,7 @@ class ConversationAgent(
main_character = scene.main_character.character
character_names = [c.name for c in scene.characters]
if main_character:
try:
character_names.remove(main_character.name)
@@ -296,21 +291,22 @@ class ConversationAgent(
director_message = isinstance(scene_and_dialogue[-1], DirectorMessage)
except IndexError:
director_message = False
inject_instructions_emission = ConversationAgentEmission(
agent=self,
response="",
actor=None,
character=character,
response="",
actor=None,
character=character,
)
await talemate.emit.async_signals.get(
"agent.conversation.inject_instructions"
).send(inject_instructions_emission)
agent_context = active_agent.get()
agent_context.state["dynamic_instructions"] = inject_instructions_emission.dynamic_instructions
agent_context.state["dynamic_instructions"] = (
inject_instructions_emission.dynamic_instructions
)
conversation_format = self.conversation_format
prompt = Prompt.get(
f"conversation.dialogue-{conversation_format}",
@@ -319,26 +315,30 @@ class ConversationAgent(
"max_tokens": self.client.max_token_length,
"scene_and_dialogue_budget": scene_and_dialogue_budget,
"scene_and_dialogue": scene_and_dialogue,
"memory": None, # DEPRECATED VARIABLE
"memory": None, # DEPRECATED VARIABLE
"characters": list(scene.get_characters()),
"main_character": main_character,
"formatted_names": formatted_names,
"talking_character": character,
"partial_message": char_message,
"director_message": director_message,
"extra_instructions": self.generation_settings_task_instructions, #backward compatibility
"extra_instructions": self.generation_settings_task_instructions, # backward compatibility
"task_instructions": self.generation_settings_task_instructions,
"actor_instructions": self.generation_settings_actor_instructions,
"actor_instructions_offset": self.generation_settings_actor_instructions_offset,
"direct_instruction": instruction,
"decensor": self.client.decensor_enabled,
"response_length": self.generation_settings_response_length if self.generation_settings_override_enabled else None,
"response_length": self.generation_settings_response_length
if self.generation_settings_override_enabled
else None,
},
)
return str(prompt)
async def build_prompt(self, character, char_message: str = "", instruction:str = None):
async def build_prompt(
self, character, char_message: str = "", instruction: str = None
):
fn = self.build_prompt_default
return await fn(character, char_message=char_message, instruction=instruction)
@@ -376,12 +376,12 @@ class ConversationAgent(
set_client_context_attribute("nuke_repetition", nuke_repetition)
@set_processing
@store_context_state('instruction')
@store_context_state("instruction")
async def converse(
self,
actor,
instruction:str = None,
emit_signals:bool = True,
instruction: str = None,
emit_signals: bool = True,
) -> list[CharacterMessage]:
"""
Have a conversation with the AI
@@ -398,7 +398,9 @@ class ConversationAgent(
self.set_generation_overrides()
result = await self.client.send_prompt(await self.build_prompt(character, instruction=instruction))
result = await self.client.send_prompt(
await self.build_prompt(character, instruction=instruction)
)
result = self.clean_result(result, character)
@@ -454,7 +456,7 @@ class ConversationAgent(
# movie script format
# {uppercase character name}
# {dialogue}
total_result = total_result.replace(f"{character.name.upper()}\n", f"")
total_result = total_result.replace(f"{character.name.upper()}\n", "")
# chat format
# {character name}: {dialogue}
@@ -464,7 +466,7 @@ class ConversationAgent(
total_result = util.clean_dialogue(total_result, main_name=character.name)
# Check if total_result starts with character name, if not, prepend it
if not total_result.startswith(character.name+":"):
if not total_result.startswith(character.name + ":"):
total_result = f"{character.name}: {total_result}"
total_result = total_result.strip()
@@ -481,11 +483,11 @@ class ConversationAgent(
log.debug("conversation agent", response=response)
emission = ConversationAgentEmission(
agent=self,
actor=actor,
character=character,
response=response,
)
if emit_signals:
await talemate.emit.async_signals.get("agent.conversation.generated").send(
emission
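
The emission objects above are dataclasses passed through named async signals, and handlers mutate them in place (e.g. appending dynamic_instructions). An illustrative stand-in using a plain handler registry, since talemate.emit.async_signals internals are not shown in this diff:

import asyncio
from dataclasses import dataclass, field

# hypothetical handler registry standing in for talemate.emit.async_signals
handlers: dict[str, list] = {"agent.conversation.inject_instructions": []}


@dataclass
class Emission:
    response: str = ""
    dynamic_instructions: list[str] = field(default_factory=list)


async def inject(emission: Emission) -> None:
    emission.dynamic_instructions.append("stay in character")


handlers["agent.conversation.inject_instructions"].append(inject)


async def send(name: str, emission: Emission) -> None:
    for handler in handlers[name]:
        await handler(emission)  # handlers mutate the shared emission in place


emission = Emission()
asyncio.run(send("agent.conversation.inject_instructions", emission))
assert emission.dynamic_instructions == ["stay in character"]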

View File

@@ -1,74 +1,75 @@
import structlog
from typing import TYPE_CHECKING, ClassVar
from talemate.game.engine.nodes.core import Node, GraphState, UNRESOLVED
from talemate.game.engine.nodes.core import GraphState, UNRESOLVED
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentNode, AgentSettingsNode
from talemate.context import active_scene
from talemate.client.context import ConversationContext, ClientContext
import talemate.events as events
if TYPE_CHECKING:
from talemate.tale_mate import Scene, Character
log = structlog.get_logger("talemate.game.engine.nodes.agents.conversation")
@register("agents/conversation/Settings")
class ConversationSettings(AgentSettingsNode):
"""
Base node to render conversation agent settings.
"""
_agent_name:ClassVar[str] = "conversation"
_agent_name: ClassVar[str] = "conversation"
def __init__(self, title="Conversation Settings", **kwargs):
super().__init__(title=title, **kwargs)
@register("agents/conversation/Generate")
class GenerateConversation(AgentNode):
"""
Generate a conversation between two characters
"""
_agent_name:ClassVar[str] = "conversation"
_agent_name: ClassVar[str] = "conversation"
def __init__(self, title="Generate Conversation", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_input("instruction", socket_type="str", optional=True)
self.set_property("trigger_conversation_generated", True)
self.add_output("generated", socket_type="str")
self.add_output("message", socket_type="message_object")
async def run(self, state: GraphState):
character:"Character" = self.get_input_value("character")
scene:"Scene" = active_scene.get()
character: "Character" = self.get_input_value("character")
scene: "Scene" = active_scene.get()
instruction = self.get_input_value("instruction")
trigger_conversation_generated = self.get_property("trigger_conversation_generated")
trigger_conversation_generated = self.get_property(
"trigger_conversation_generated"
)
other_characters = [c.name for c in scene.characters if c != character]
conversation_context = ConversationContext(
talking_character=character.name,
other_characters=other_characters,
)
if instruction == UNRESOLVED:
instruction = None
with ClientContext(conversation=conversation_context):
messages = await self.agent.converse(
character.actor,
instruction=instruction,
emit_signals=trigger_conversation_generated,
)
message = messages[0]
self.set_output_values({
"generated": message.message,
"message": message
})
self.set_output_values({"generated": message.message, "message": message})

View File

@@ -18,23 +18,25 @@ __all__ = [
log = structlog.get_logger("talemate.server.conversation")
class RequestActorActionPayload(pydantic.BaseModel):
character:str = ""
instructions:str = ""
emit_signals:bool = True
instructions_through_director:bool = True
character: str = ""
instructions: str = ""
emit_signals: bool = True
instructions_through_director: bool = True
class ConversationWebsocketHandler(Plugin):
"""
Handles conversation actor actions
"""
router = "conversation"
@property
def agent(self) -> "ConversationAgent":
return get_agent("conversation")
@set_loading("Generating actor action", cancellable=True, as_async=True)
async def handle_request_actor_action(self, data: dict):
"""
@@ -43,38 +45,37 @@ class ConversationWebsocketHandler(Plugin):
payload = RequestActorActionPayload(**data)
character = None
actor = None
if payload.character:
character = self.scene.get_character(payload.character)
actor = character.actor
else:
actor = random.choice(list(self.scene.get_npc_characters())).actor
if not actor:
log.error("handle_request_actor_action: No actor found")
return
character = actor.character
if payload.instructions_through_director:
director_message = DirectorMessage(
payload.instructions,
source="player",
meta={"character": character.name}
meta={"character": character.name},
)
emit("director", message=director_message, character=character)
self.scene.push_history(director_message)
generated_messages = await self.agent.converse(
actor,
emit_signals=payload.emit_signals
actor, emit_signals=payload.emit_signals
)
else:
generated_messages = await self.agent.converse(
actor,
actor,
instruction=payload.instructions,
emit_signals=payload.emit_signals
emit_signals=payload.emit_signals,
)
for message in generated_messages:
self.scene.push_history(message)
emit("character", message=message, character=character)
emit("character", message=message, character=character)

View File

@@ -1,13 +1,9 @@
from __future__ import annotations
import json
import os
import talemate.client as client
from talemate.agents.base import Agent, set_processing
from talemate.agents.registry import register
from talemate.agents.memory.rag import MemoryRAGMixin
from talemate.emit import emit
from talemate.prompts import Prompt
from .assistant import AssistantMixin
@@ -16,7 +12,8 @@ from .scenario import ScenarioCreatorMixin
from talemate.agents.base import AgentAction
import talemate.agents.creator.nodes
import talemate.agents.creator.nodes # noqa: F401
@register()
class CreatorAgent(
@@ -51,7 +48,7 @@ class CreatorAgent(
@set_processing
async def generate_title(self, text: str):
title = await Prompt.request(
f"creator.generate-title",
"creator.generate-title",
self.client,
"create_short",
vars={

View File

@@ -40,6 +40,7 @@ async_signals.register(
"agent.creator.autocomplete.after",
)
@dataclasses.dataclass
class ContextualGenerateEmission(AgentTemplateEmission):
"""
@@ -48,15 +49,16 @@ class ContextualGenerateEmission(AgentTemplateEmission):
content_generation_context: "ContentGenerationContext | None" = None
character: "Character | None" = None
@property
def context_type(self) -> str:
return self.content_generation_context.computed_context[0]
@property
def context_name(self) -> str:
return self.content_generation_context.computed_context[1]
@dataclasses.dataclass
class AutocompleteEmission(AgentTemplateEmission):
"""
@@ -67,6 +69,7 @@ class AutocompleteEmission(AgentTemplateEmission):
type: str = ""
character: "Character | None" = None
class ContentGenerationContext(pydantic.BaseModel):
"""
A context for generating content.
@@ -104,7 +107,6 @@ class ContentGenerationContext(pydantic.BaseModel):
@property
def spice(self) -> str:
spice_level = self.generation_options.spice_level
if self.template and not getattr(self.template, "supports_spice", False):
@@ -148,7 +150,6 @@ class ContentGenerationContext(pydantic.BaseModel):
@property
def style(self):
if self.template and not getattr(self.template, "supports_style", False):
# template supplied that doesn't support style
return ""
@@ -165,11 +166,12 @@ class ContentGenerationContext(pydantic.BaseModel):
def get_state(self, key: str) -> str | int | float | bool | None:
return self.state.get(key)
class AssistantMixin:
"""
Creator mixin that allows quick contextual generation of content.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["autocomplete"] = AgentAction(
@@ -198,15 +200,15 @@ class AssistantMixin:
max=256,
step=16,
),
}
},
)
# property helpers
@property
def autocomplete_dialogue_suggestion_length(self):
return self.actions["autocomplete"].config["dialogue_suggestion_length"].value
@property
def autocomplete_narrative_suggestion_length(self):
return self.actions["autocomplete"].config["narrative_suggestion_length"].value
@@ -255,7 +257,7 @@ class AssistantMixin:
history_aware=history_aware,
information=information,
)
for key, value in kwargs.items():
generation_context.set_state(key, value)
@@ -279,9 +281,13 @@ class AssistantMixin:
f"Contextual generate: {context_typ} - {context_name}",
generation_context=generation_context,
)
character = self.scene.get_character(generation_context.character) if generation_context.character else None
character = (
self.scene.get_character(generation_context.character)
if generation_context.character
else None
)
template_vars = {
"scene": self.scene,
"max_tokens": self.client.max_token_length,
@@ -295,27 +301,29 @@ class AssistantMixin:
"character": character,
"template": generation_context.template,
}
emission = ContextualGenerateEmission(
agent=self,
content_generation_context=generation_context,
character=character,
template_vars=template_vars,
)
await async_signals.get("agent.creator.contextual_generate.before").send(emission)
await async_signals.get("agent.creator.contextual_generate.before").send(
emission
)
template_vars["dynamic_instructions"] = emission.dynamic_instructions
content = await Prompt.request(
f"creator.contextual-generate",
"creator.contextual-generate",
self.client,
kind,
vars=template_vars,
)
emission.response = content
if not generation_context.partial:
content = util.strip_partial_sentences(content)
@@ -329,22 +337,29 @@ class AssistantMixin:
if not content.startswith(generation_context.character + ":"):
content = generation_context.character + ": " + content
content = util.strip_partial_sentences(content)
character = self.scene.get_character(generation_context.character)
if not character:
log.warning("Character not found", character=generation_context.character)
return content
emission.response = await editor.cleanup_character_message(content, character)
await async_signals.get("agent.creator.contextual_generate.after").send(emission)
return emission.response
emission.response = content.strip().strip("*").strip()
await async_signals.get("agent.creator.contextual_generate.after").send(emission)
return emission.response
character = self.scene.get_character(generation_context.character)
if not character:
log.warning(
"Character not found", character=generation_context.character
)
return content
emission.response = await editor.cleanup_character_message(
content, character
)
await async_signals.get("agent.creator.contextual_generate.after").send(
emission
)
return emission.response
emission.response = content.strip().strip("*").strip()
await async_signals.get("agent.creator.contextual_generate.after").send(
emission
)
return emission.response
@set_processing
async def generate_character_attribute(
@@ -357,10 +372,10 @@ class AssistantMixin:
) -> str:
"""
Wrapper for contextual_generate that generates a character attribute.
"""
"""
if not generation_options:
generation_options = GenerationOptions()
return await self.contextual_generate_from_args(
context=f"character attribute:{attribute_name}",
character=character.name,
@@ -368,7 +383,7 @@ class AssistantMixin:
original=original,
**generation_options.model_dump(),
)
@set_processing
async def generate_character_detail(
self,
@@ -381,11 +396,11 @@ class AssistantMixin:
) -> str:
"""
Wrapper for contextual_generate that generates a character detail.
"""
"""
if not generation_options:
generation_options = GenerationOptions()
return await self.contextual_generate_from_args(
context=f"character detail:{detail_name}",
character=character.name,
@@ -394,7 +409,7 @@ class AssistantMixin:
length=length,
**generation_options.model_dump(),
)
@set_processing
async def generate_thematic_list(
self,
@@ -408,11 +423,11 @@ class AssistantMixin:
"""
if not generation_options:
generation_options = GenerationOptions()
i = 0
result = []
while i < iterations:
i += 1
_result = await self.contextual_generate_from_args(
@@ -420,14 +435,14 @@ class AssistantMixin:
instructions=instructions,
length=length,
original="\n".join(result) if result else None,
extend=i>1,
extend=i > 1,
**generation_options.model_dump(),
)
_result = json.loads(_result)
result = list(set(result + _result))
return result
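
generate_thematic_list accumulates JSON list batches and dedupes with list(set(...)), which also discards ordering. A compressed, runnable sketch with a stand-in generator function:

import json


def fake_generate(existing: list[str]) -> str:
    # stand-in for contextual_generate_from_args returning a JSON list
    return json.dumps(["sword", "lantern"] if not existing else ["lantern", "map"])


result: list[str] = []
for i in range(2):
    batch = json.loads(fake_generate(result))
    result = list(set(result + batch))  # dedupe, but ordering is not preserved

assert sorted(result) == ["lantern", "map", "sword"]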
@set_processing
@@ -443,34 +458,34 @@ class AssistantMixin:
"""
if not response_length:
response_length = self.autocomplete_dialogue_suggestion_length
# continuing recent character message
non_anchor, anchor = util.split_anchor_text(input, 10)
self.scene.log.debug(
"autocomplete_anchor",
anchor=anchor,
non_anchor=non_anchor,
input=input
"autocomplete_anchor", anchor=anchor, non_anchor=non_anchor, input=input
)
continuing_message = False
message = None
try:
message = self.scene.history[-1]
if isinstance(message, CharacterMessage) and message.character_name == character.name:
if (
isinstance(message, CharacterMessage)
and message.character_name == character.name
):
continuing_message = input.strip() == message.without_name.strip()
except IndexError:
pass
if input.strip().endswith('"'):
prefix = ' *'
elif input.strip().endswith('*'):
prefix = " *"
elif input.strip().endswith("*"):
prefix = ' "'
else:
prefix = ''
prefix = ""
template_vars = {
"scene": self.scene,
"max_tokens": self.client.max_token_length,
@@ -484,7 +499,7 @@ class AssistantMixin:
"non_anchor": non_anchor,
"prefix": prefix,
}
emission = AutocompleteEmission(
agent=self,
input=input,
@@ -492,13 +507,13 @@ class AssistantMixin:
character=character,
template_vars=template_vars,
)
await async_signals.get("agent.creator.autocomplete.before").send(emission)
template_vars["dynamic_instructions"] = emission.dynamic_instructions
response = await Prompt.request(
f"creator.autocomplete-dialogue",
"creator.autocomplete-dialogue",
self.client,
f"create_{response_length}",
vars=template_vars,
@@ -506,21 +521,22 @@ class AssistantMixin:
dedupe_enabled=False,
)
response = response.replace("...", "").lstrip("").rstrip().replace("END-OF-LINE", "")
response = (
response.replace("...", "").lstrip("").rstrip().replace("END-OF-LINE", "")
)
if prefix:
response = prefix + response
emission.response = response
await async_signals.get("agent.creator.autocomplete.after").send(emission)
if not response:
if emit_signal:
emit("autocomplete_suggestion", "")
return ""
response = util.strip_partial_sentences(response).replace("*", "")
if response.startswith(input):
@@ -550,14 +566,14 @@ class AssistantMixin:
# Split the input text into non-anchor and anchor parts
non_anchor, anchor = util.split_anchor_text(input, 10)
self.scene.log.debug(
"autocomplete_narrative_anchor",
"autocomplete_narrative_anchor",
anchor=anchor,
non_anchor=non_anchor,
input=input
input=input,
)
template_vars = {
"scene": self.scene,
"max_tokens": self.client.max_token_length,
@@ -567,20 +583,20 @@ class AssistantMixin:
"anchor": anchor,
"non_anchor": non_anchor,
}
emission = AutocompleteEmission(
agent=self,
input=input,
type="narrative",
template_vars=template_vars,
)
await async_signals.get("agent.creator.autocomplete.before").send(emission)
template_vars["dynamic_instructions"] = emission.dynamic_instructions
response = await Prompt.request(
f"creator.autocomplete-narrative",
"creator.autocomplete-narrative",
self.client,
f"create_{response_length}",
vars=template_vars,
@@ -593,7 +609,7 @@ class AssistantMixin:
response = response[len(input) :]
emission.response = response
await async_signals.get("agent.creator.autocomplete.after").send(emission)
self.scene.log.debug(
@@ -614,78 +630,75 @@ class AssistantMixin:
"""
Allows forking a new scene from a specific message
in the current scene.
All content after the message will be removed and the
context database will be re-imported, ensuring a clean state.
All state reinforcements will be reset to their most recent
state before the message.
"""
emit("status", "Creating scene fork ...", status="busy")
try:
if not save_name:
# build a save name
uuid_str = str(uuid.uuid4())[:8]
save_name = f"{uuid_str}-forked"
log.info(f"Forking scene", message_id=message_id, save_name=save_name)
log.info("Forking scene", message_id=message_id, save_name=save_name)
world_state = get_agent("world_state")
# does a message with the given id exist?
index = self.scene.message_index(message_id)
if index is None:
raise ValueError(f"Message with id {message_id} not found.")
# truncate scene.history keeping index as the last element
self.scene.history = self.scene.history[:index + 1]
self.scene.history = self.scene.history[: index + 1]
# truncate scene.archived_history keeping the element where `end` is < `index`
# as the last element
self.scene.archived_history = [
x for x in self.scene.archived_history if "end" not in x or x["end"] < index
x
for x in self.scene.archived_history
if "end" not in x or x["end"] < index
]
# the same needs to be done for layered history
# where each layer is truncated based on what's left in the previous layer
# using similar logic as above (checking `end` vs `index`)
# layer 0 checks archived_history
new_layered_history = []
for layer_number, layer in enumerate(self.scene.layered_history):
if layer_number == 0:
index = len(self.scene.archived_history) - 1
else:
index = len(new_layered_history[layer_number - 1]) - 1
new_layer = [
x for x in layer if x["end"] < index
]
new_layer = [x for x in layer if x["end"] < index]
new_layered_history.append(new_layer)
self.scene.layered_history = new_layered_history
# save the scene
await self.scene.save(copy_name=save_name)
log.info(f"Scene forked", save_name=save_name)
log.info("Scene forked", save_name=save_name)
# re-emit history
await self.scene.emit_history()
emit("status", f"Updating world state ...", status="busy")
emit("status", "Updating world state ...", status="busy")
# reset state reinforcements
await world_state.update_reinforcements(force = True, reset= True)
await world_state.update_reinforcements(force=True, reset=True)
# update world state
await self.scene.world_state.request_update()
emit("status", f"Scene forked", status="success")
except Exception as e:
emit("status", "Scene forked", status="success")
except Exception:
log.error("Scene fork failed", exc=traceback.format_exc())
emit("status", "Scene fork failed", status="error")

View File

@@ -7,8 +7,6 @@ import structlog
from talemate.agents.base import set_processing
from talemate.prompts import Prompt
import talemate.game.focal as focal
if TYPE_CHECKING:
from talemate.tale_mate import Character
@@ -18,14 +16,13 @@ DEFAULT_CONTENT_CONTEXT = "a fun and engaging adventure aimed at an adult audien
class CharacterCreatorMixin:
@set_processing
async def determine_content_context_for_character(
self,
character: Character,
):
content_context = await Prompt.request(
f"creator.determine-content-context",
"creator.determine-content-context",
self.client,
"create_192",
vars={
@@ -42,7 +39,7 @@ class CharacterCreatorMixin:
information: str = "",
):
instructions = await Prompt.request(
f"creator.determine-character-dialogue-instructions",
"creator.determine-character-dialogue-instructions",
self.client,
"create_concise",
vars={
@@ -63,7 +60,7 @@ class CharacterCreatorMixin:
character: Character,
):
attributes = await Prompt.request(
f"creator.determine-character-attributes",
"creator.determine-character-attributes",
self.client,
"analyze_long",
vars={
@@ -81,7 +78,7 @@ class CharacterCreatorMixin:
instructions: str = "",
) -> str:
name = await Prompt.request(
f"creator.determine-character-name",
"creator.determine-character-name",
self.client,
"analyze_freeform_short",
vars={
@@ -97,14 +94,14 @@ class CharacterCreatorMixin:
@set_processing
async def determine_character_description(
self,
character: Character,
text: str = "",
instructions: str = "",
information: str = "",
):
description = await Prompt.request(
f"creator.determine-character-description",
"creator.determine-character-description",
self.client,
"create",
vars={
@@ -125,7 +122,7 @@ class CharacterCreatorMixin:
goal_instructions: str,
):
goals = await Prompt.request(
f"creator.determine-character-goals",
"creator.determine-character-goals",
self.client,
"create",
vars={
@@ -141,4 +138,4 @@ class CharacterCreatorMixin:
log.debug("determine_character_goals", goals=goals, character=character)
await character.set_detail("goals", goals.strip())
return goals.strip()

View File

@@ -1,145 +1,153 @@
import structlog
from typing import ClassVar, TYPE_CHECKING
from typing import ClassVar
from talemate.context import active_scene
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
from talemate.game.engine.nodes.core import (
GraphState,
PropertyField,
UNRESOLVED,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
if TYPE_CHECKING:
from talemate.tale_mate import Scene
log = structlog.get_logger("talemate.game.engine.nodes.agents.creator")
@register("agents/creator/Settings")
class CreatorSettings(AgentSettingsNode):
"""
Base node to render creator agent settings.
"""
_agent_name:ClassVar[str] = "creator"
_agent_name: ClassVar[str] = "creator"
def __init__(self, title="Creator Settings", **kwargs):
super().__init__(title=title, **kwargs)
@register("agents/creator/DetermineContentContext")
class DetermineContentContext(AgentNode):
"""
Determines the context for the content creation.
"""
_agent_name:ClassVar[str] = "creator"
_agent_name: ClassVar[str] = "creator"
class Fields:
description = PropertyField(
name="description",
description="Description of the context",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
def __init__(self, title="Determine Content Context", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("description", socket_type="str", optional=True)
self.set_property("description", UNRESOLVED)
self.add_output("content_context", socket_type="str")
async def run(self, state: GraphState):
context = await self.agent.determine_content_context_for_description(
self.require_input("description")
)
self.set_output_values({
"content_context": context
})
self.set_output_values({"content_context": context})
@register("agents/creator/DetermineCharacterDescription")
class DetermineCharacterDescription(AgentNode):
"""
Determines the description for a character.
Inputs:
- state: The current state of the graph
- character: The character to determine the description for
- extra_context: Extra context to use in determining the description
Outputs:
- description: The determined description
"""
_agent_name:ClassVar[str] = "creator"
_agent_name: ClassVar[str] = "creator"
def __init__(self, title="Determine Character Description", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_input("extra_context", socket_type="str", optional=True)
self.add_output("description", socket_type="str")
async def run(self, state: GraphState):
character = self.require_input("character")
extra_context = self.get_input_value("extra_context")
if extra_context is UNRESOLVED:
extra_context = ""
description = await self.agent.determine_character_description(character, extra_context)
self.set_output_values({
"description": description
})
description = await self.agent.determine_character_description(
character, extra_context
)
self.set_output_values({"description": description})
@register("agents/creator/DetermineCharacterDialogueInstructions")
class DetermineCharacterDialogueInstructions(AgentNode):
"""
Determines the dialogue instructions for a character.
"""
_agent_name:ClassVar[str] = "creator"
_agent_name: ClassVar[str] = "creator"
class Fields:
instructions = PropertyField(
name="instructions",
description="Any additional instructions to use in determining the dialogue instructions",
type="text",
default=""
default="",
)
def __init__(self, title="Determine Character Dialogue Instructions", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_input("instructions", socket_type="str", optional=True)
self.set_property("instructions", "")
self.add_output("dialogue_instructions", socket_type="str")
async def run(self, state: GraphState):
character = self.require_input("character")
instructions = self.normalized_input_value("instructions")
dialogue_instructions = await self.agent.determine_character_dialogue_instructions(character, instructions)
self.set_output_values({
"dialogue_instructions": dialogue_instructions
})
dialogue_instructions = (
await self.agent.determine_character_dialogue_instructions(
character, instructions
)
)
self.set_output_values({"dialogue_instructions": dialogue_instructions})
@register("agents/creator/ContextualGenerate")
class ContextualGenerate(AgentNode):
"""
Generates text based on the given context.
Inputs:
- state: The current state of the graph
- context_type: The type of context to use in generating the text
- context_name: The name of the context to use in generating the text
@@ -150,9 +158,9 @@ class ContextualGenerate(AgentNode):
- partial: The partial text to use in generating the text
- uid: The uid to use in generating the text
- generation_options: The generation options to use in generating the text
Properties:
- context_type: The type of context to use in generating the text
- context_name: The name of the context to use in generating the text
- instructions: The instructions to use in generating the text
@@ -161,82 +169,82 @@ class ContextualGenerate(AgentNode):
- uid: The uid to use in generating the text
- context_aware: Whether to use the context in generating the text
- history_aware: Whether to use the history in generating the text
Outputs:
- state: The updated state of the graph
- text: The generated text
"""
_agent_name:ClassVar[str] = "creator"
_agent_name: ClassVar[str] = "creator"
class Fields:
context_type = PropertyField(
name="context_type",
description="The type of context to use in generating the text",
type="str",
choices=[
"character attribute",
"character detail",
"character dialogue",
"scene intro",
"scene intent",
"scene phase intent",
"character attribute",
"character detail",
"character dialogue",
"scene intro",
"scene intent",
"scene phase intent",
"scene type description",
"scene type instructions",
"general",
"list",
"scene",
"scene type instructions",
"general",
"list",
"scene",
"world context",
],
default="general"
default="general",
)
context_name = PropertyField(
name="context_name",
description="The name of the context to use in generating the text",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
instructions = PropertyField(
name="instructions",
description="The instructions to use in generating the text",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
length = PropertyField(
name="length",
description="The length of the text to generate",
type="int",
default=100
default=100,
)
character = PropertyField(
name="character",
description="The character to generate the text for",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
uid = PropertyField(
name="uid",
description="The uid to use in generating the text",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
context_aware = PropertyField(
name="context_aware",
description="Whether to use the context in generating the text",
type="bool",
default=True
default=True,
)
history_aware = PropertyField(
name="history_aware",
description="Whether to use the history in generating the text",
type="bool",
default=True
default=True,
)
def __init__(self, title="Contextual Generate", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("context_type", socket_type="str", optional=True)
@@ -247,8 +255,10 @@ class ContextualGenerate(AgentNode):
self.add_input("original", socket_type="str", optional=True)
self.add_input("partial", socket_type="str", optional=True)
self.add_input("uid", socket_type="str", optional=True)
self.add_input("generation_options", socket_type="generation_options", optional=True)
self.add_input(
"generation_options", socket_type="generation_options", optional=True
)
self.set_property("context_type", "general")
self.set_property("context_name", UNRESOLVED)
self.set_property("instructions", UNRESOLVED)
@@ -257,10 +267,10 @@ class ContextualGenerate(AgentNode):
self.set_property("uid", UNRESOLVED)
self.set_property("context_aware", True)
self.set_property("history_aware", True)
self.add_output("state")
self.add_output("text", socket_type="str")
async def run(self, state: GraphState):
scene = active_scene.get()
context_type = self.require_input("context_type")
@@ -274,12 +284,12 @@ class ContextualGenerate(AgentNode):
generation_options = self.normalized_input_value("generation_options")
context_aware = self.normalized_input_value("context_aware")
history_aware = self.normalized_input_value("history_aware")
context = f"{context_type}:{context_name}" if context_name else context_type
if isinstance(character, scene.Character):
character = character.name
text = await self.agent.contextual_generate_from_args(
context=context,
instructions=instructions,
@@ -288,31 +298,32 @@ class ContextualGenerate(AgentNode):
original=original,
partial=partial or "",
uid=uid,
writing_style=generation_options.writing_style if generation_options else None,
writing_style=generation_options.writing_style
if generation_options
else None,
spices=generation_options.spices if generation_options else None,
spice_level=generation_options.spice_level if generation_options else 0.0,
context_aware=context_aware,
history_aware=history_aware
history_aware=history_aware,
)
self.set_output_values({
"state": state,
"text": text
})
self.set_output_values({"state": state, "text": text})
@register("agents/creator/GenerateThematicList")
class GenerateThematicList(AgentNode):
"""
Generates a list of thematic items based on the instructions.
"""
_agent_name:ClassVar[str] = "creator"
_agent_name: ClassVar[str] = "creator"
class Fields:
instructions = PropertyField(
name="instructions",
description="The instructions to use in generating the list",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
iterations = PropertyField(
name="iterations",
@@ -323,27 +334,24 @@ class GenerateThematicList(AgentNode):
min=1,
max=10,
)
def __init__(self, title="Generate Thematic List", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("instructions", socket_type="str", optional=True)
self.set_property("instructions", UNRESOLVED)
self.set_property("iterations", 1)
self.add_output("state")
self.add_output("list", socket_type="list")
async def run(self, state: GraphState):
instructions = self.normalized_input_value("instructions")
iterations = self.require_number_input("iterations")
list = await self.agent.generate_thematic_list(instructions, iterations)
self.set_output_values({
"state": state,
"list": list
})
self.set_output_values({"state": state, "list": list})

View File

@@ -10,7 +10,7 @@ class ScenarioCreatorMixin:
@set_processing
async def determine_scenario_description(self, text: str):
description = await Prompt.request(
f"creator.determine-scenario-description",
"creator.determine-scenario-description",
self.client,
"analyze_long",
vars={
@@ -25,7 +25,7 @@ class ScenarioCreatorMixin:
description: str,
):
content_context = await Prompt.request(
f"creator.determine-content-context",
"creator.determine-content-context",
self.client,
"create_short",
vars={

View File

@@ -1,16 +1,12 @@
from __future__ import annotations
import random
from typing import TYPE_CHECKING, List
import structlog
import traceback
import talemate.emit.async_signals
import talemate.instance as instance
from talemate.agents.conversation import ConversationAgentEmission
from talemate.emit import emit
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage
from talemate.util import random_color
from talemate.character import deactivate_character
@@ -27,7 +23,7 @@ from .legacy_scene_instructions import LegacySceneInstructionsMixin
from .auto_direct import AutoDirectMixin
from .websocket_handler import DirectorWebsocketHandler
import talemate.agents.director.nodes
import talemate.agents.director.nodes # noqa: F401
if TYPE_CHECKING:
from talemate import Character, Scene
@@ -42,7 +38,7 @@ class DirectorAgent(
GenerateChoicesMixin,
AutoDirectMixin,
LegacySceneInstructionsMixin,
Agent
Agent,
):
agent_type = "director"
verbose_name = "Director"
@@ -74,7 +70,7 @@ class DirectorAgent(
],
),
},
),
}
MemoryRAGMixin.add_actions(actions)
GenerateChoicesMixin.add_actions(actions)
@@ -112,7 +108,6 @@ class DirectorAgent(
created_characters = []
for character_name in self.scene.world_state.characters.keys():
if exclude and character_name.lower() in exclude:
continue
@@ -148,13 +143,12 @@ class DirectorAgent(
memory = instance.get_agent("memory")
scene: "Scene" = self.scene
any_attribute_templates = False
loading_status = LoadingStatus(max_steps=None, cancellable=True)
# Start of character creation
log.debug("persist_character", name=name)
# Determine the character's name (or clarify if it's already set)
if determine_name:
loading_status("Determining character name")
@@ -162,7 +156,7 @@ class DirectorAgent(
log.debug("persist_character", adjusted_name=name)
# Create the blank character
character:Character = self.scene.Character(name=name)
character: Character = self.scene.Character(name=name)
# Add the character to the scene
character.color = random_color()
@@ -170,32 +164,44 @@ class DirectorAgent(
character=character, agent=instance.get_agent("conversation")
)
await self.scene.add_actor(actor)
try:
# Apply any character generation templates
if templates:
loading_status("Applying character generation templates")
templates = scene.world_state_manager.template_collection.collect_all(templates)
templates = scene.world_state_manager.template_collection.collect_all(
templates
)
log.debug("persist_character", applying_templates=templates)
await scene.world_state_manager.apply_templates(
templates.values(),
character_name=character.name,
information=content
information=content,
)
# if any of the templates are attribute templates, then we no longer need to
# generate a character sheet
any_attribute_templates = any(template.template_type == "character_attribute" for template in templates.values())
log.debug("persist_character", any_attribute_templates=any_attribute_templates)
if any_attribute_templates and augment_attributes and generate_attributes:
log.debug("persist_character", augmenting_attributes=augment_attributes)
any_attribute_templates = any(
template.template_type == "character_attribute"
for template in templates.values()
)
log.debug(
"persist_character", any_attribute_templates=any_attribute_templates
)
if (
any_attribute_templates
and augment_attributes
and generate_attributes
):
log.debug(
"persist_character", augmenting_attributes=augment_attributes
)
loading_status("Augmenting character attributes")
additional_attributes = await world_state.extract_character_sheet(
name=name,
text=content,
augmentation_instructions=augment_attributes
augmentation_instructions=augment_attributes,
)
character.base_attributes.update(additional_attributes)
@@ -212,25 +218,26 @@ class DirectorAgent(
log.debug("persist_character", attributes=attributes)
character.base_attributes = attributes
# Generate a description for the character
if not description:
loading_status("Generating character description")
description = await creator.determine_character_description(character, information=content)
description = await creator.determine_character_description(
character, information=content
)
character.description = description
log.debug("persist_character", description=description)
# Generate a dialogue instructions for the character
loading_status("Generating acting instructions")
dialogue_instructions = await creator.determine_character_dialogue_instructions(
character,
information=content
dialogue_instructions = (
await creator.determine_character_dialogue_instructions(
character, information=content
)
)
character.dialogue_instructions = dialogue_instructions
log.debug(
"persist_character", dialogue_instructions=dialogue_instructions
)
log.debug("persist_character", dialogue_instructions=dialogue_instructions)
# Narrate the character's entry if the option is selected
if active and narrate_entry:
loading_status("Narrating character entry")
@@ -240,24 +247,26 @@ class DirectorAgent(
"narrate_character_entry",
emit_message=True,
character=character,
narrative_direction=narrate_entry_direction
narrative_direction=narrate_entry_direction,
)
# Deactivate the character if not active
if not active:
await deactivate_character(scene, character)
# Commit the character's details to long term memory
await character.commit_to_memory(memory)
self.scene.emit_status()
self.scene.world_state.emit()
loading_status.done(message=f"{character.name} added to scene", status="success")
loading_status.done(
message=f"{character.name} added to scene", status="success"
)
return character
except GenerationCancelled:
loading_status.done(message="Character creation cancelled", status="idle")
await scene.remove_actor(actor)
except Exception as e:
except Exception:
loading_status.done(message="Character creation failed", status="error")
await scene.remove_actor(actor)
log.error("Error persisting character", error=traceback.format_exc())
@@ -276,4 +285,4 @@ class DirectorAgent(
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
return False

View File

@@ -1,23 +1,22 @@
from typing import TYPE_CHECKING
import structlog
import pydantic
from talemate.agents.base import (
set_processing,
AgentAction,
AgentActionConfig,
AgentEmission,
AgentTemplateEmission,
)
from talemate.status import set_loading
import talemate.game.focal as focal
from talemate.prompts import Prompt
import talemate.emit.async_signals as async_signals
from talemate.scene_message import CharacterMessage, TimePassageMessage, DirectorMessage, NarratorMessage
from talemate.scene.schema import ScenePhase, SceneType, SceneIntent
from talemate.scene_message import (
CharacterMessage,
TimePassageMessage,
NarratorMessage,
)
from talemate.scene.schema import ScenePhase, SceneType
from talemate.scene.intent import set_scene_phase
from talemate.world_state.manager import WorldStateManager
from talemate.world_state.templates.scene import SceneType as TemplateSceneType
import talemate.agents.director.auto_direct_nodes
import talemate.agents.director.auto_direct_nodes # noqa: F401
from talemate.world_state.templates.base import TypedCollection
if TYPE_CHECKING:
@@ -26,15 +25,15 @@ if TYPE_CHECKING:
log = structlog.get_logger("talemate.agents.conversation.direct")
#talemate.emit.async_signals.register(
#)
# talemate.emit.async_signals.register(
# )
class AutoDirectMixin:
"""
Director agent mixin for automatic scene direction.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["auto_direct"] = AgentAction(
@@ -108,113 +107,117 @@ class AutoDirectMixin:
),
},
)
# config property helpers
@property
def auto_direct_enabled(self) -> bool:
return self.actions["auto_direct"].enabled
@property
def auto_direct_max_auto_turns(self) -> int:
return self.actions["auto_direct"].config["max_auto_turns"].value
@property
def auto_direct_max_idle_turns(self) -> int:
return self.actions["auto_direct"].config["max_idle_turns"].value
@property
def auto_direct_max_repeat_turns(self) -> int:
return self.actions["auto_direct"].config["max_repeat_turns"].value
@property
def auto_direct_instruct_actors(self) -> bool:
return self.actions["auto_direct"].config["instruct_actors"].value
@property
def auto_direct_instruct_narrator(self) -> bool:
return self.actions["auto_direct"].config["instruct_narrator"].value
@property
def auto_direct_instruct_frequency(self) -> int:
return self.actions["auto_direct"].config["instruct_frequency"].value
@property
def auto_direct_evaluate_scene_intention(self) -> int:
return self.actions["auto_direct"].config["evaluate_scene_intention"].value
@property
def auto_direct_instruct_any(self) -> bool:
"""
Will check whether actor or narrator instructions are enabled.
For narrator instructions to be enabled, instruct_narrator needs to be enabled as well.
Returns:
bool: True if either actor or narrator instructions are enabled.
"""
return self.auto_direct_instruct_actors or self.auto_direct_instruct_narrator
# signal connect
def connect(self, scene):
super().connect(scene)
async_signals.get("game_loop").connect(self.on_game_loop)
async def on_game_loop(self, event):
if not self.auto_direct_enabled:
return
if self.auto_direct_evaluate_scene_intention > 0:
evaluation_due = self.get_scene_state("evaluated_scene_intention", 0)
if evaluation_due == 0:
await self.auto_direct_set_scene_intent()
self.set_scene_states(evaluated_scene_intention=self.auto_direct_evaluate_scene_intention)
self.set_scene_states(
evaluated_scene_intention=self.auto_direct_evaluate_scene_intention
)
else:
self.set_scene_states(evaluated_scene_intention=evaluation_due - 1)
# helpers
def auto_direct_is_due_for_instruction(self, actor_name:str) -> bool:
def auto_direct_is_due_for_instruction(self, actor_name: str) -> bool:
"""
Check if the actor is due for instruction.
"""
if self.auto_direct_instruct_frequency == 1:
return True
messages_since_last_instruction = 0
def count_messages(message):
nonlocal messages_since_last_instruction
if message.typ in ["character", "narrator"]:
messages_since_last_instruction += 1
last_instruction = self.scene.last_message_of_type(
"director",
character_name=actor_name,
max_iterations=25,
on_iterate=count_messages,
)
log.debug("auto_direct_is_due_for_instruction", messages_since_last_instruction=messages_since_last_instruction, last_instruction=last_instruction.id if last_instruction else None)
log.debug(
"auto_direct_is_due_for_instruction",
messages_since_last_instruction=messages_since_last_instruction,
last_instruction=last_instruction.id if last_instruction else None,
)
if not last_instruction:
return True
return messages_since_last_instruction >= self.auto_direct_instruct_frequency
def auto_direct_candidates(self) -> list["Character"]:
"""
Returns a list of characters who are valid candidates to speak next,
based on the max_idle_turns, max_repeat_turns, and the most recent character.
"""
scene:"Scene" = self.scene
scene: "Scene" = self.scene
most_recent_character = None
repeat_count = 0
last_player_turn = None
@@ -223,89 +226,105 @@ class AutoDirectMixin:
active_charcters = list(scene.characters)
active_character_names = [character.name for character in active_charcters]
instruct_narrator = self.auto_direct_instruct_narrator
# if there is only one character then they are the only candidate
if len(active_charcters) == 1:
return active_charcters
BACKLOG_LIMIT = 50
player_character_active = scene.player_character_exists
# check the last BACKLOG_LIMIT entries in the scene history and collect into
# a dictionary of character names and the number of turns since they last spoke.
len_history = len(scene.history) - 1
num = 0
for idx in range(len_history, -1, -1):
message = scene.history[idx]
turns = len_history - idx
num += 1
if num > BACKLOG_LIMIT:
break
if isinstance(message, TimePassageMessage):
break
if not isinstance(message, (CharacterMessage, NarratorMessage)):
continue
# if character message but character is not in the active characters list then skip
if isinstance(message, CharacterMessage) and message.character_name not in active_character_names:
if (
isinstance(message, CharacterMessage)
and message.character_name not in active_character_names
):
continue
if isinstance(message, NarratorMessage):
if not instruct_narrator:
continue
character = scene.narrator_character_object
else:
character = scene.get_character(message.character_name)
if not character:
continue
if character.is_player and last_player_turn is None:
last_player_turn = turns
elif not character.is_player and last_player_turn is None:
consecutive_auto_turns += 1
if not most_recent_character:
most_recent_character = character
repeat_count += 1
elif character == most_recent_character:
repeat_count += 1
if character.name not in candidates:
candidates[character.name] = turns
# add any characters that have not spoken yet
for character in active_charcters:
if character.name not in candidates:
candidates[character.name] = 0
# explicitly add narrator if enabled and not already in candidates
if instruct_narrator and scene.narrator_character_object:
narrator = scene.narrator_character_object
if narrator.name not in candidates:
candidates[narrator.name] = 0
log.debug(f"auto_direct_candidates: {candidates}", most_recent_character=most_recent_character, repeat_count=repeat_count, last_player_turn=last_player_turn, consecutive_auto_turns=consecutive_auto_turns)
log.debug(
f"auto_direct_candidates: {candidates}",
most_recent_character=most_recent_character,
repeat_count=repeat_count,
last_player_turn=last_player_turn,
consecutive_auto_turns=consecutive_auto_turns,
)
if not most_recent_character:
log.debug("auto_direct_candidates: No most recent character found.")
return list(scene.characters)
# if player has not spoken in a while then they are favored
if player_character_active and consecutive_auto_turns >= self.auto_direct_max_auto_turns:
log.debug("auto_direct_candidates: User controlled character has not spoken in a while.")
if (
player_character_active
and consecutive_auto_turns >= self.auto_direct_max_auto_turns
):
log.debug(
"auto_direct_candidates: User controlled character has not spoken in a while."
)
return [scene.get_player_character()]
# check if most recent character has spoken too many times in a row
# if so then remove from candidates
if repeat_count >= self.auto_direct_max_repeat_turns:
log.debug("auto_direct_candidates: Most recent character has spoken too many times in a row.", most_recent_character=most_recent_character
log.debug(
"auto_direct_candidates: Most recent character has spoken too many times in a row.",
most_recent_character=most_recent_character,
)
candidates.pop(most_recent_character.name, None)
@@ -314,27 +333,34 @@ class AutoDirectMixin:
favored_candidates = []
for name, turns in candidates.items():
if turns > self.auto_direct_max_idle_turns:
log.debug("auto_direct_candidates: Character has gone too long without speaking.", character_name=name, turns=turns)
log.debug(
"auto_direct_candidates: Character has gone too long without speaking.",
character_name=name,
turns=turns,
)
favored_candidates.append(scene.get_character(name))
if favored_candidates:
return favored_candidates
return [scene.get_character(character_name) for character_name in candidates.keys()]
return [
scene.get_character(character_name) for character_name in candidates.keys()
]
# actions
@set_processing
async def auto_direct_set_scene_intent(self, require:bool=False) -> ScenePhase | None:
async def set_scene_intention(type:str, intention:str) -> ScenePhase:
async def auto_direct_set_scene_intent(
self, require: bool = False
) -> ScenePhase | None:
async def set_scene_intention(type: str, intention: str) -> ScenePhase:
await set_scene_phase(self.scene, type, intention)
self.scene.emit_status()
return self.scene.intent_state.phase
async def do_nothing(*args, **kwargs) -> None:
return None
focal_handler = focal.Focal(
self.client,
callbacks=[
@@ -355,57 +381,63 @@ class AutoDirectMixin:
],
max_calls=1,
scene=self.scene,
scene_type_ids=", ".join([f'"{scene_type.id}"' for scene_type in self.scene.intent_state.scene_types.values()]),
scene_type_ids=", ".join(
[
f'"{scene_type.id}"'
for scene_type in self.scene.intent_state.scene_types.values()
]
),
retries=1,
require=require,
)
await focal_handler.request(
"director.direct-determine-scene-intent",
)
return self.scene.intent_state.phase
@set_processing
async def auto_direct_generate_scene_types(
self,
instructions:str,
max_scene_types:int=1,
self,
instructions: str,
max_scene_types: int = 1,
):
world_state_manager:WorldStateManager = self.scene.world_state_manager
scene_type_templates:TypedCollection = await world_state_manager.get_templates(types=["scene_type"])
async def add_from_template(id:str) -> SceneType:
template:TemplateSceneType | None = scene_type_templates.find_by_name(id)
world_state_manager: WorldStateManager = self.scene.world_state_manager
scene_type_templates: TypedCollection = await world_state_manager.get_templates(
types=["scene_type"]
)
async def add_from_template(id: str) -> SceneType:
template: TemplateSceneType | None = scene_type_templates.find_by_name(id)
if not template:
log.warning("auto_direct_generate_scene_types: Template not found.", name=id)
log.warning(
"auto_direct_generate_scene_types: Template not found.", name=id
)
return None
return template.apply_to_scene(self.scene)
async def generate_scene_type(
id:str = None,
name:str = None,
description:str = None,
instructions:str = None,
id: str = None,
name: str = None,
description: str = None,
instructions: str = None,
) -> SceneType:
if not id or not name:
return None
scene_type = SceneType(
id=id,
name=name,
description=description,
instructions=instructions,
)
self.scene.intent_state.scene_types[id] = scene_type
return scene_type
focal_handler = focal.Focal(
self.client,
callbacks=[
@@ -435,7 +467,7 @@ class AutoDirectMixin:
instructions=instructions,
scene_type_templates=scene_type_templates.templates,
)
await focal_handler.request(
"director.generate-scene-types",
)
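
auto_direct_candidates walks the history backwards, records how many turns ago each character last spoke, and favors anyone idle past max_idle_turns. A compressed sketch of just that bookkeeping (narrator handling, repeat caps, and player favoring omitted):

MAX_IDLE_TURNS = 3
history = ["Ana", "Ben", "Ana", "Ana"]  # speakers, most recent last

candidates: dict[str, int] = {}
len_history = len(history) - 1
for idx in range(len_history, -1, -1):
    # first (most recent) sighting wins; earlier turns don't overwrite it
    candidates.setdefault(history[idx], len_history - idx)
for name in ["Ana", "Ben", "Cyn"]:  # characters yet to speak default to 0
    candidates.setdefault(name, 0)

favored = [n for n, turns in candidates.items() if turns > MAX_IDLE_TURNS]
speakers = favored or list(candidates)
print(candidates, speakers)  # {'Ana': 0, 'Ben': 2, 'Cyn': 0} ['Ana', 'Ben', 'Cyn']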

View File

@@ -1,17 +1,13 @@
import structlog
from typing import TYPE_CHECKING, ClassVar
from typing import ClassVar
from talemate.game.engine.nodes.core import GraphState, PropertyField
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentNode
from talemate.scene.schema import ScenePhase
from talemate.context import active_scene
if TYPE_CHECKING:
from talemate.tale_mate import Scene
from talemate.agents.director import DirectorAgent
log = structlog.get_logger("talemate.game.engine.nodes.agents.director")
@register("agents/director/auto-direct/Candidates")
class AutoDirectCandidates(AgentNode):
"""
@@ -19,52 +15,51 @@ class AutoDirectCandidates(AgentNode):
next action, based on the director's auto-direct settings and
the recent scene history.
"""
_agent_name:ClassVar[str] = "director"
_agent_name: ClassVar[str] = "director"
def __init__(self, title="Auto Direct Candidates", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_output("characters", socket_type="list")
async def run(self, state: GraphState):
candidates = self.agent.auto_direct_candidates()
self.set_output_values({
"characters": candidates
})
self.set_output_values({"characters": candidates})
@register("agents/director/auto-direct/DetermineSceneIntent")
class DetermineSceneIntent(AgentNode):
"""
Determines the scene intent based on the current scene state.
"""
_agent_name:ClassVar[str] = "director"
_agent_name: ClassVar[str] = "director"
def __init__(self, title="Determine Scene Intent", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_output("state")
self.add_output("scene_phase", socket_type="scene_intent/scene_phase")
async def run(self, state: GraphState):
phase:ScenePhase = await self.agent.auto_direct_set_scene_intent()
self.set_output_values({
"state": state,
"scene_phase": phase
})
phase: ScenePhase = await self.agent.auto_direct_set_scene_intent()
self.set_output_values({"state": state, "scene_phase": phase})
@register("agents/director/auto-direct/GenerateSceneTypes")
class GenerateSceneTypes(AgentNode):
"""
Generates scene types based on the current scene state.
"""
_agent_name:ClassVar[str] = "director"
_agent_name: ClassVar[str] = "director"
class Fields:
instructions = PropertyField(
name="instructions",
@@ -72,17 +67,17 @@ class GenerateSceneTypes(AgentNode):
description="The instructions for the scene types",
default="",
)
max_scene_types = PropertyField(
name="max_scene_types",
type="int",
description="The maximum number of scene types to generate",
default=1,
)
def __init__(self, title="Generate Scene Types", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("instructions", socket_type="str", optional=True)
@@ -90,29 +85,26 @@ class GenerateSceneTypes(AgentNode):
self.set_property("instructions", "")
self.set_property("max_scene_types", 1)
self.add_output("state")
async def run(self, state: GraphState):
instructions = self.normalized_input_value("instructions")
max_scene_types = self.normalized_input_value("max_scene_types")
scene_types = await self.agent.auto_direct_generate_scene_types(
instructions=instructions,
max_scene_types=max_scene_types
instructions=instructions, max_scene_types=max_scene_types
)
self.set_output_values({
"state": state,
"scene_types": scene_types
})
self.set_output_values({"state": state, "scene_types": scene_types})
@register("agents/director/auto-direct/IsDueForInstruction")
class IsDueForInstruction(AgentNode):
"""
Checks if the actor is due for instruction based on the auto-direct settings.
"""
_agent_name:ClassVar[str] = "director"
_agent_name: ClassVar[str] = "director"
class Fields:
actor_name = PropertyField(
name="actor_name",
@@ -120,24 +112,21 @@ class IsDueForInstruction(AgentNode):
description="The name of the actor to check instruction timing for",
default="",
)
def __init__(self, title="Is Due For Instruction", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("actor_name", socket_type="str")
self.set_property("actor_name", "")
self.add_output("is_due", socket_type="bool")
self.add_output("actor_name", socket_type="str")
async def run(self, state: GraphState):
actor_name = self.require_input("actor_name")
is_due = self.agent.auto_direct_is_due_for_instruction(actor_name)
self.set_output_values({
"is_due": is_due,
"actor_name": actor_name
})
self.set_output_values({"is_due": is_due, "actor_name": actor_name})

View File

@@ -1,14 +1,12 @@
from typing import TYPE_CHECKING
import random
import structlog
from functools import wraps
import dataclasses
from talemate.agents.base import (
set_processing,
AgentAction,
AgentActionConfig,
AgentTemplateEmission,
DynamicInstruction,
)
from talemate.events import GameLoopStartEvent
from talemate.scene_message import NarratorMessage, CharacterMessage
@@ -24,7 +22,7 @@ __all__ = [
log = structlog.get_logger()
talemate.emit.async_signals.register(
"agent.director.generate_choices.before_generate",
"agent.director.generate_choices.before_generate",
"agent.director.generate_choices.inject_instructions",
"agent.director.generate_choices.generated",
)
@@ -32,18 +30,19 @@ talemate.emit.async_signals.register(
if TYPE_CHECKING:
from talemate.tale_mate import Character
@dataclasses.dataclass
class GenerateChoicesEmission(AgentTemplateEmission):
character: "Character | None" = None
choices: list[str] = dataclasses.field(default_factory=list)
class GenerateChoicesMixin:
"""
Director agent mixin that provides functionality for automatically guiding
the actors or the narrator during the scene progression.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["_generate_choices"] = AgentAction(
@@ -65,7 +64,6 @@ class GenerateChoicesMixin:
max=1,
step=0.1,
),
"num_choices": AgentActionConfig(
type="number",
label="Number of Actions",
@@ -75,33 +73,31 @@ class GenerateChoicesMixin:
max=10,
step=1,
),
"never_auto_progress": AgentActionConfig(
type="bool",
label="Never Auto Progress on Action Selection",
description="If enabled, the scene will not auto progress after you select an action.",
value=False,
),
"instructions": AgentActionConfig(
type="blob",
label="Instructions",
description="Provide some instructions to the director for generating actions.",
value="",
),
}
},
)
# config property helpers
@property
def generate_choices_enabled(self):
return self.actions["_generate_choices"].enabled
@property
def generate_choices_chance(self):
return self.actions["_generate_choices"].config["chance"].value
@property
def generate_choices_num_choices(self):
return self.actions["_generate_choices"].config["num_choices"].value
@@ -115,24 +111,25 @@ class GenerateChoicesMixin:
return self.actions["_generate_choices"].config["instructions"].value
# signal connect
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("player_turn_start").connect(self.on_player_turn_start)
talemate.emit.async_signals.get("player_turn_start").connect(
self.on_player_turn_start
)
async def on_player_turn_start(self, event: GameLoopStartEvent):
if not self.enabled:
return
if self.generate_choices_enabled:
# look backwards through history and abort if we encounter
# a character message with source "player" before either
# a character message with a different source or a narrator message
#
# this is so choices aren't generated when the player message was
# the most recent content in the scene
for i in range(len(self.scene.history) - 1, -1, -1):
message = self.scene.history[i]
if isinstance(message, NarratorMessage):
@@ -141,12 +138,11 @@ class GenerateChoicesMixin:
if message.source == "player":
return
break
if random.random() < self.generate_choices_chance:
await self.generate_choices()
await self.generate_choices()
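The backward scan above can be isolated into a small predicate. The Message type and its fields below are hypothetical stand-ins for talemate's NarratorMessage/CharacterMessage; the logic mirrors the comment: no choices when the player's own message is the newest content, otherwise roll against the configured chance.

import random
from dataclasses import dataclass

@dataclass
class Message:
    kind: str            # "narrator" or "character" (hypothetical field)
    source: str = ""     # "player" for player-authored character messages

def should_generate_choices(history: list[Message], chance: float) -> bool:
    for message in reversed(history):
        if message.kind == "narrator":
            break
        if message.kind == "character":
            if message.source == "player":
                return False  # player spoke last, skip generation
            break
    return random.random() < chance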
# methods
@set_processing
async def generate_choices(
@@ -154,20 +150,23 @@ class GenerateChoicesMixin:
instructions: str = None,
character: "Character | str | None" = None,
):
emission: GenerateChoicesEmission = GenerateChoicesEmission(agent=self)
if isinstance(character, str):
character = self.scene.get_character(character)
if not character:
character = self.scene.get_player_character()
emission.character = character
await talemate.emit.async_signals.get("agent.director.generate_choices.before_generate").send(emission)
await talemate.emit.async_signals.get("agent.director.generate_choices.inject_instructions").send(emission)
await talemate.emit.async_signals.get(
"agent.director.generate_choices.before_generate"
).send(emission)
await talemate.emit.async_signals.get(
"agent.director.generate_choices.inject_instructions"
).send(emission)
response = await Prompt.request(
"director.generate-choices",
self.client,
@@ -178,7 +177,9 @@ class GenerateChoicesMixin:
"character": character,
"num_choices": self.generate_choices_num_choices,
"instructions": instructions or self.generate_choices_instructions,
"dynamic_instructions": emission.dynamic_instructions if emission else None,
"dynamic_instructions": emission.dynamic_instructions
if emission
else None,
},
)
@@ -187,10 +188,10 @@ class GenerateChoicesMixin:
choices = util.extract_list(choice_text)
# strip quotes
choices = [choice.strip().strip('"') for choice in choices]
# limit to num_choices
choices = choices[:self.generate_choices_num_choices]
choices = choices[: self.generate_choices_num_choices]
except Exception as e:
log.error("generate_choices failed", error=str(e), response=response)
return
@@ -198,15 +199,17 @@ class GenerateChoicesMixin:
emit(
"player_choice",
response,
data = {
data={
"choices": choices,
"character": character.name,
},
websocket_passthrough=True
websocket_passthrough=True,
)
emission.response = response
emission.choices = choices
await talemate.emit.async_signals.get("agent.director.generate_choices.generated").send(emission)
return emission.response
await talemate.emit.async_signals.get(
"agent.director.generate_choices.generated"
).send(emission)
return emission.response

View File

@@ -17,12 +17,11 @@ from talemate.util import strip_partial_sentences
if TYPE_CHECKING:
from talemate.tale_mate import Character
from talemate.agents.summarize.analyze_scene import SceneAnalysisEmission
from talemate.agents.editor.revision import RevisionAnalysisEmission
log = structlog.get_logger()
talemate.emit.async_signals.register(
"agent.director.guide.before_generate",
"agent.director.guide.before_generate",
"agent.director.guide.inject_instructions",
"agent.director.guide.generated",
)
@@ -32,6 +31,7 @@ talemate.emit.async_signals.register(
class DirectorGuidanceEmission(AgentTemplateEmission):
pass
def set_processing(fn):
"""
Custom decorator that emits the agent status as processing while the function
@@ -42,29 +42,33 @@ def set_processing(fn):
@wraps(fn)
async def wrapper(self, *args, **kwargs):
emission: DirectorGuidanceEmission = DirectorGuidanceEmission(agent=self)
await talemate.emit.async_signals.get("agent.director.guide.before_generate").send(emission)
await talemate.emit.async_signals.get("agent.director.guide.inject_instructions").send(emission)
await talemate.emit.async_signals.get(
"agent.director.guide.before_generate"
).send(emission)
await talemate.emit.async_signals.get(
"agent.director.guide.inject_instructions"
).send(emission)
agent_context = active_agent.get()
agent_context.state["dynamic_instructions"] = emission.dynamic_instructions
response = await fn(self, *args, **kwargs)
emission.response = response
await talemate.emit.async_signals.get("agent.director.guide.generated").send(emission)
await talemate.emit.async_signals.get("agent.director.guide.generated").send(
emission
)
return emission.response
return wrapper
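The custom set_processing above layers before/inject/generated signal emissions around the wrapped coroutine. A generic, self-contained version of that decorator shape, using a plain dict of async handlers rather than talemate.emit:

import functools

SIGNALS: dict[str, list] = {"before": [], "generated": []}

async def send(signal: str, payload: dict):
    # deliver the payload to every registered async handler
    for handler in SIGNALS[signal]:
        await handler(payload)

def with_emissions(fn):
    @functools.wraps(fn)
    async def wrapper(self, *args, **kwargs):
        emission = {"agent": self, "response": None}
        await send("before", emission)
        emission["response"] = await fn(self, *args, **kwargs)
        await send("generated", emission)
        return emission["response"]
    return wrapper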
class GuideSceneMixin:
"""
Director agent mixin that provides functionality for automatically guiding
the actors or the narrator during the scene progression.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["guide_scene"] = AgentAction(
@@ -81,13 +85,13 @@ class GuideSceneMixin:
type="bool",
label="Guide actors",
description="Guide the actors in the scene. This happens during every actor turn.",
value=True
value=True,
),
"guide_narrator": AgentActionConfig(
type="bool",
label="Guide narrator",
description="Guide the narrator during the scene. This happens during the narrator's turn.",
value=True
value=True,
),
"guidance_length": AgentActionConfig(
type="text",
@@ -101,7 +105,7 @@ class GuideSceneMixin:
{"label": "Medium (512)", "value": "512"},
{"label": "Medium Long (768)", "value": "768"},
{"label": "Long (1024)", "value": "1024"},
]
],
),
"cache_guidance": AgentActionConfig(
type="bool",
@@ -109,57 +113,57 @@ class GuideSceneMixin:
description="Will not regenerate the guidance until the scene moves forward or the analysis changes.",
value=False,
quick_toggle=True,
)
}
),
},
)
# config property helpers
@property
def guide_scene(self) -> bool:
return self.actions["guide_scene"].enabled
@property
def guide_actors(self) -> bool:
return self.actions["guide_scene"].config["guide_actors"].value
@property
def guide_narrator(self) -> bool:
return self.actions["guide_scene"].config["guide_narrator"].value
@property
def guide_scene_guidance_length(self) -> int:
return int(self.actions["guide_scene"].config["guidance_length"].value)
@property
def guide_scene_cache_guidance(self) -> bool:
return self.actions["guide_scene"].config["cache_guidance"].value
# signal connect
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.summarization.scene_analysis.after").connect(
self.on_summarization_scene_analysis_after
)
talemate.emit.async_signals.get("agent.summarization.scene_analysis.cached").connect(
self.on_summarization_scene_analysis_after
)
talemate.emit.async_signals.get("agent.editor.revision-analysis.before").connect(
self.on_editor_revision_analysis_before
)
async def on_summarization_scene_analysis_after(self, emission: "SceneAnalysisEmission"):
talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.after"
).connect(self.on_summarization_scene_analysis_after)
talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.cached"
).connect(self.on_summarization_scene_analysis_after)
talemate.emit.async_signals.get(
"agent.editor.revision-analysis.before"
).connect(self.on_editor_revision_analysis_before)
async def on_summarization_scene_analysis_after(
self, emission: "SceneAnalysisEmission"
):
if not self.guide_scene:
return
guidance = None
cached_guidance = await self.get_cached_guidance(emission.response)
if emission.analysis_type == "narration" and self.guide_narrator:
if cached_guidance:
guidance = cached_guidance
else:
@@ -167,15 +171,14 @@ class GuideSceneMixin:
emission.response,
response_length=self.guide_scene_guidance_length,
)
if not guidance:
log.warning("director.guide_scene.narration: Empty resonse")
return
self.set_context_states(narrator_guidance=guidance)
elif emission.analysis_type == "conversation" and self.guide_actors:
elif emission.analysis_type == "conversation" and self.guide_actors:
if cached_guidance:
guidance = cached_guidance
else:
@@ -184,94 +187,110 @@ class GuideSceneMixin:
emission.template_vars.get("character"),
response_length=self.guide_scene_guidance_length,
)
if not guidance:
log.warning("director.guide_scene.conversation: Empty resonse")
return
self.set_context_states(actor_guidance=guidance)
if guidance:
await self.set_cached_guidance(
emission.response,
guidance,
emission.analysis_type,
emission.template_vars.get("character")
emission.response,
guidance,
emission.analysis_type,
emission.template_vars.get("character"),
)
async def on_editor_revision_analysis_before(self, emission: AgentTemplateEmission):
cached_guidance = await self.get_cached_guidance(emission.response)
if cached_guidance:
emission.dynamic_instructions.append(DynamicInstruction(
title="Guidance",
content=cached_guidance
))
emission.dynamic_instructions.append(
DynamicInstruction(title="Guidance", content=cached_guidance)
)
# helpers
def _cache_key(self) -> str:
return f"cached_guidance"
async def get_cached_guidance(self, analysis:str | None = None) -> str | None:
return "cached_guidance"
async def get_cached_guidance(self, analysis: str | None = None) -> str | None:
"""
Returns the cached guidance for the given analysis.
If analysis is not provided, it will return the cached guidance for the last analysis regardless
of the fingerprint.
"""
if not self.guide_scene_cache_guidance:
return None
key = self._cache_key()
cached_guidance = self.get_scene_state(key)
if cached_guidance:
if not analysis:
return cached_guidance.get("guidance")
elif cached_guidance.get("fp") == self.context_fingerpint(extra=[analysis]):
return cached_guidance.get("guidance")
return None
async def set_cached_guidance(self, analysis:str, guidance: str, analysis_type: str, character: "Character | None" = None):
async def set_cached_guidance(
self,
analysis: str,
guidance: str,
analysis_type: str,
character: "Character | None" = None,
):
"""
Sets the cached guidance for the given analysis.
"""
key = self._cache_key()
self.set_scene_states(**{
key: {
"fp": self.context_fingerpint(extra=[analysis]),
"guidance": guidance,
"analysis_type": analysis_type,
"character": character.name if character else None,
self.set_scene_states(
**{
key: {
"fp": self.context_fingerpint(extra=[analysis]),
"guidance": guidance,
"analysis_type": analysis_type,
"character": character.name if character else None,
}
}
})
)
async def get_cached_character_guidance(self, character_name: str) -> str | None:
"""
Returns the cached guidance for the given character.
"""
key = self._cache_key()
cached_guidance = self.get_scene_state(key)
if not cached_guidance:
return None
if cached_guidance.get("character") == character_name and cached_guidance.get("analysis_type") == "conversation":
if (
cached_guidance.get("character") == character_name
and cached_guidance.get("analysis_type") == "conversation"
):
return cached_guidance.get("guidance")
return None
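The cache helpers above amount to a fingerprint-guarded single-slot cache: store the guidance alongside a hash of the analysis, and return it only while the hash still matches. A simplified stand-in (the real code derives the fingerprint from broader scene context via context_fingerpint; MD5 of the analysis alone is an assumption here):

import hashlib

_cache: dict = {}

def _fp(analysis: str) -> str:
    return hashlib.md5(analysis.encode()).hexdigest()

def set_cached_guidance(analysis: str, guidance: str) -> None:
    _cache["cached_guidance"] = {"fp": _fp(analysis), "guidance": guidance}

def get_cached_guidance(analysis: str | None = None) -> str | None:
    entry = _cache.get("cached_guidance")
    if not entry:
        return None
    # no analysis given: return whatever was cached last, regardless of fingerprint
    if analysis is None or entry["fp"] == _fp(analysis):
        return entry["guidance"]
    return None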
# methods
@set_processing
async def guide_actor_off_of_scene_analysis(self, analysis: str, character: "Character", response_length: int = 256):
async def guide_actor_off_of_scene_analysis(
self, analysis: str, character: "Character", response_length: int = 256
):
"""
Guides the actor based on the scene analysis.
"""
log.debug("director.guide_actor_off_of_scene_analysis", analysis=analysis, character=character)
log.debug(
"director.guide_actor_off_of_scene_analysis",
analysis=analysis,
character=character,
)
response = await Prompt.request(
"director.guide-conversation",
self.client,
@@ -285,19 +304,17 @@ class GuideSceneMixin:
},
)
return strip_partial_sentences(response).strip()
@set_processing
async def guide_narrator_off_of_scene_analysis(
self,
analysis: str,
response_length: int = 256
self, analysis: str, response_length: int = 256
):
"""
Guides the narrator based on the scene analysis.
"""
log.debug("director.guide_narrator_off_of_scene_analysis", analysis=analysis)
response = await Prompt.request(
"director.guide-narration",
self.client,
@@ -309,4 +326,4 @@ class GuideSceneMixin:
"max_tokens": self.client.max_token_length,
},
)
return strip_partial_sentences(response).strip()
return strip_partial_sentences(response).strip()

View File

@@ -1,4 +1,3 @@
from typing import TYPE_CHECKING
import structlog
from talemate.agents.base import (
set_processing,
@@ -9,24 +8,27 @@ from talemate.events import GameLoopActorIterEvent, SceneStateEvent
log = structlog.get_logger("talemate.agents.conversation.legacy_scene_instructions")
class LegacySceneInstructionsMixin(
GameInstructionsMixin,
):
"""
Legacy support for scoped api instructions in scenes.
This is being replaced by node based instructions, but kept for backwards compatibility.
THIS WILL BE DEPRECATED IN THE FUTURE.
"""
# signal connect
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.LSI_on_player_dialog)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(
self.LSI_on_player_dialog
)
talemate.emit.async_signals.get("scene_init").connect(self.LSI_on_scene_init)
async def LSI_on_scene_init(self, event: SceneStateEvent):
"""
LEGACY: If game state instructions specify to be run at the start of the game loop
@@ -57,8 +59,10 @@ class LegacySceneInstructionsMixin(
if not event.actor.character.is_player:
return
log.warning(f"LSI_on_player_dialog is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.")
log.warning(
"LSI_on_player_dialog is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future."
)
if event.game_loop.had_passive_narration:
log.debug(
@@ -77,17 +81,17 @@ class LegacySceneInstructionsMixin(
not self.scene.npc_character_names
or self.scene.game_state.ops.always_direct
)
log.warning(f"LSI_direct is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.", always_direct=always_direct)
log.warning(
"LSI_direct is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.",
always_direct=always_direct,
)
next_direct = self.next_direct_scene
TURNS = 5
if (
next_direct % TURNS != 0
or next_direct == 0
):
if next_direct % TURNS != 0 or next_direct == 0:
if not always_direct:
log.info("direct", skip=True, next_direct=next_direct)
self.next_direct_scene += 1
@@ -112,8 +116,10 @@ class LegacySceneInstructionsMixin(
async def LSI_direct_scene(self):
"""
LEGACY: Direct the scene based scoped api scene instructions.
This is being replaced by node based instructions, but kept for
This is being replaced by node based instructions, but kept for
backwards compatibility.
"""
log.warning(f"Direct python scene instructions are being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.")
await self.run_scene_instructions(self.scene)
log.warning(
"Direct python scene instructions are being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future."
)
await self.run_scene_instructions(self.scene)
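The gating in LSI_direct runs direction only on every TURNS-th positive turn unless always_direct forces it. Reduced to a pure function for clarity (a sketch, not the talemate API):

TURNS = 5

def should_direct(next_direct: int, always_direct: bool) -> bool:
    # direct on turns 5, 10, 15, ...; always_direct overrides the cadence
    if next_direct % TURNS != 0 or next_direct == 0:
        return always_direct
    return True

assert should_direct(5, False) and not should_direct(3, False)
assert should_direct(0, True)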

View File

@@ -1,29 +1,34 @@
import structlog
from typing import ClassVar, TYPE_CHECKING
from talemate.context import active_scene
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
from typing import ClassVar
from talemate.game.engine.nodes.core import (
GraphState,
PropertyField,
TYPE_CHOICES,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
if TYPE_CHECKING:
from talemate.tale_mate import Scene
TYPE_CHOICES.extend([
"director/direction",
])
TYPE_CHOICES.extend(
[
"director/direction",
]
)
log = structlog.get_logger("talemate.game.engine.nodes.agents.director")
@register("agents/director/Settings")
class DirectorSettings(AgentSettingsNode):
"""
Base node to render director agent settings.
"""
_agent_name:ClassVar[str] = "director"
_agent_name: ClassVar[str] = "director"
def __init__(self, title="Director Settings", **kwargs):
super().__init__(title=title, **kwargs)
@register("agents/director/PersistCharacter")
class PersistCharacter(AgentNode):
@@ -31,45 +36,44 @@ class PersistCharacter(AgentNode):
Persists a character that currently only exists as part of the given context
as a real character that can actively participate in the scene.
"""
_agent_name:ClassVar[str] = "director"
_agent_name: ClassVar[str] = "director"
class Fields:
determine_name = PropertyField(
name="determine_name",
type="bool",
description="Whether to determine the name of the character",
default=True
default=True,
)
def __init__(self, title="Persist Character", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character_name", socket_type="str")
self.add_input("context", socket_type="str", optional=True)
self.add_input("attributes", socket_type="dict,str", optional=True)
self.set_property("determine_name", True)
self.add_output("state")
self.add_output("character", socket_type="character")
async def run(self, state: GraphState):
character_name = self.get_input_value("character_name")
context = self.normalized_input_value("context")
attributes = self.normalized_input_value("attributes")
determine_name = self.normalized_input_value("determine_name")
character = await self.agent.persist_character(
name=character_name,
content=context,
attributes="\n".join([f"{k}: {v}" for k, v in attributes.items()]) if attributes else None,
determine_name=determine_name
attributes="\n".join([f"{k}: {v}" for k, v in attributes.items()])
if attributes
else None,
determine_name=determine_name,
)
self.set_output_values({
"state": state,
"character": character
})
self.set_output_values({"state": state, "character": character})

View File

@@ -1,7 +1,6 @@
import pydantic
import asyncio
import structlog
import traceback
from typing import TYPE_CHECKING
from talemate.instance import get_agent
@@ -18,27 +17,31 @@ __all__ = [
log = structlog.get_logger("talemate.server.director")
class InstructionPayload(pydantic.BaseModel):
instructions:str = ""
instructions: str = ""
class SelectChoicePayload(pydantic.BaseModel):
choice: str
character:str = ""
character: str = ""
class CharacterPayload(InstructionPayload):
character:str = ""
character: str = ""
class PersistCharacterPayload(pydantic.BaseModel):
name: str
templates: list[str] | None = None
narrate_entry: bool = True
narrate_entry_direction: str = ""
active: bool = True
determine_name: bool = True
augment_attributes: str = ""
generate_attributes: bool = True
content: str = ""
description: str = ""
@@ -47,13 +50,13 @@ class DirectorWebsocketHandler(Plugin):
"""
Handles director actions
"""
router = "director"
@property
def director(self):
return get_agent("director")
@set_loading("Generating dynamic actions", cancellable=True, as_async=True)
async def handle_request_dynamic_choices(self, data: dict):
"""
@@ -61,21 +64,21 @@ class DirectorWebsocketHandler(Plugin):
"""
payload = CharacterPayload(**data)
await self.director.generate_choices(**payload.model_dump())
async def handle_select_choice(self, data: dict):
payload = SelectChoicePayload(**data)
log.debug("selecting choice", payload=payload)
if payload.character:
character = self.scene.get_character(payload.character)
else:
character = self.scene.get_player_character()
if not character:
log.error("handle_select_choice: could not find character", payload=payload)
return
# hijack the interaction state
try:
interaction_state = interaction.get()
@@ -83,20 +86,23 @@ class DirectorWebsocketHandler(Plugin):
# no interaction state
log.error("handle_select_choice: no interaction state", payload=payload)
return
interaction_state.from_choice = payload.choice
interaction_state.act_as = character.name if not character.is_player else None
interaction_state.input = f"@{payload.choice}"
async def handle_persist_character(self, data: dict):
payload = PersistCharacterPayload(**data)
scene: "Scene" = self.scene
if not payload.content:
payload.content = scene.snapshot(lines=15)
# add as asyncio task
task = asyncio.create_task(self.director.persist_character(**payload.model_dump()))
task = asyncio.create_task(
self.director.persist_character(**payload.model_dump())
)
async def handle_task_done(task):
if task.exception():
log.error("Error persisting character", error=task.exception())
@@ -112,4 +118,3 @@ class DirectorWebsocketHandler(Plugin):
await self.signal_operation_done()
task.add_done_callback(lambda task: asyncio.create_task(handle_task_done(task)))
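handle_persist_character uses a fire-and-forget task whose done-callback spawns a second task to inspect failures. A runnable miniature of that pattern; persist_character below is a dummy coroutine, not the director's method:

import asyncio

async def persist_character():
    raise RuntimeError("boom")  # dummy failure to exercise the callback

async def handle_task_done(task: asyncio.Task):
    if task.exception():
        print("Error persisting character:", task.exception())

async def main():
    task = asyncio.create_task(persist_character())
    task.add_done_callback(lambda t: asyncio.create_task(handle_task_done(t)))
    await asyncio.sleep(0.1)  # keep the loop alive so the callback's task runs

asyncio.run(main())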

View File

@@ -40,7 +40,7 @@ class EditorAgent(
agent_type = "editor"
verbose_name = "Editor"
websocket_handler = EditorWebsocketHandler
@classmethod
def init_actions(cls) -> dict[str, AgentAction]:
actions = {
@@ -56,9 +56,9 @@ class EditorAgent(
description="The formatting to use for exposition.",
value="novel",
choices=[
{"label": "Chat RP: \"Speech\" *narration*", "value": "chat"},
{"label": "Novel: \"Speech\" narration", "value": "novel"},
]
{"label": 'Chat RP: "Speech" *narration*', "value": "chat"},
{"label": 'Novel: "Speech" narration', "value": "novel"},
],
),
"narrator": AgentActionConfig(
type="bool",
@@ -81,7 +81,7 @@ class EditorAgent(
description="Attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.",
),
}
MemoryRAGMixin.add_actions(actions)
RevisionMixin.add_actions(actions)
return actions
@@ -90,7 +90,7 @@ class EditorAgent(
self.client = client
self.is_enabled = True
self.actions = EditorAgent.init_actions()
@property
def enabled(self):
return self.is_enabled
@@ -102,11 +102,11 @@ class EditorAgent(
@property
def experimental(self):
return True
@property
def fix_exposition_enabled(self):
return self.actions["fix_exposition"].enabled
@property
def fix_exposition_formatting(self):
return self.actions["fix_exposition"].config["formatting"].value
@@ -114,11 +114,10 @@ class EditorAgent(
@property
def fix_exposition_narrator(self):
return self.actions["fix_exposition"].config["narrator"].value
@property
def fix_exposition_user_input(self):
return self.actions["fix_exposition"].config["user_input"].value
def connect(self, scene):
super().connect(scene)
@@ -134,7 +133,7 @@ class EditorAgent(
formatting = "md"
else:
formatting = None
if self.fix_exposition_formatting == "chat":
text = text.replace("**", "*")
text = text.replace("[", "*").replace("]", "*")
@@ -143,15 +142,14 @@ class EditorAgent(
text = text.replace("*", "")
text = text.replace("[", "").replace("]", "")
text = text.replace("(", "").replace(")", "")
cleaned = util.ensure_dialog_format(
text,
talking_character=character.name if character else None,
formatting=formatting
)
return cleaned
cleaned = util.ensure_dialog_format(
text,
talking_character=character.name if character else None,
formatting=formatting,
)
return cleaned
async def on_conversation_generated(self, emission: ConversationAgentEmission):
"""
@@ -181,11 +179,13 @@ class EditorAgent(
emission.response = edit
@set_processing
async def cleanup_character_message(self, content: str, character: Character, force: bool = False):
async def cleanup_character_message(
self, content: str, character: Character, force: bool = False
):
"""
Edits a text to make sure all narrative exposition and emotes are encased in *
"""
# if no content was generated, return it as is
if not content:
return content
@@ -210,14 +210,14 @@ class EditorAgent(
content = util.clean_dialogue(content, main_name=character.name)
content = util.strip_partial_sentences(content)
# if there are uneven quotation marks, fix them by adding a closing quote
if '"' in content and content.count('"') % 2 != 0:
content += '"'
if not self.fix_exposition_enabled and not exposition_fixed:
return content
content = self.fix_exposition_in_text(content, character)
return content
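The uneven-quote repair above is a simple parity check: an odd number of double quotes means one was left open, so a closing quote is appended. For example:

content = 'She said, "wait'
if '"' in content and content.count('"') % 2 != 0:
    content += '"'
print(content)  # She said, "wait"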
@@ -225,7 +225,7 @@ class EditorAgent(
@set_processing
async def clean_up_narration(self, content: str, force: bool = False):
content = util.strip_partial_sentences(content)
if (self.fix_exposition_enabled and self.fix_exposition_narrator or force):
if self.fix_exposition_enabled and self.fix_exposition_narrator or force:
content = self.fix_exposition_in_text(content, None)
if self.fix_exposition_formatting == "chat":
if '"' not in content and "*" not in content:
@@ -234,25 +234,27 @@ class EditorAgent(
return content
@set_processing
async def cleanup_user_input(self, text: str, as_narration: bool = False, force: bool = False):
async def cleanup_user_input(
self, text: str, as_narration: bool = False, force: bool = False
):
# special prefix characters - when found, never edit
PREFIX_CHARACTERS = ("!", "@", "/")
if text.startswith(PREFIX_CHARACTERS):
return text
if (not self.fix_exposition_user_input or not self.fix_exposition_enabled) and not force:
if (
not self.fix_exposition_user_input or not self.fix_exposition_enabled
) and not force:
return text
if not as_narration:
if self.fix_exposition_formatting == "chat":
if '"' not in text and "*" not in text:
text = f'"{text}"'
else:
return await self.clean_up_narration(text)
return self.fix_exposition_in_text(text)
@set_processing
async def add_detail(self, content: str, character: Character):
@@ -279,4 +281,4 @@ class EditorAgent(
response = util.clean_dialogue(response, main_name=character.name)
response = util.strip_partial_sentences(response)
return response
return response

View File

@@ -1,70 +1,37 @@
import structlog
from typing import ClassVar, TYPE_CHECKING
from talemate.context import active_scene
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
from talemate.game.engine.nodes.core import (
GraphState,
PropertyField,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
if TYPE_CHECKING:
from talemate.tale_mate import Scene
from talemate.agents.editor import EditorAgent
log = structlog.get_logger("talemate.game.engine.nodes.agents.editor")
@register("agents/editor/Settings")
class EditorSettings(AgentSettingsNode):
"""
Base node to render editor agent settings.
"""
_agent_name:ClassVar[str] = "editor"
_agent_name: ClassVar[str] = "editor"
def __init__(self, title="Editor Settings", **kwargs):
super().__init__(title=title, **kwargs)
@register("agents/editor/CleanUpUserInput")
class CleanUpUserInput(AgentNode):
"""
Cleans up user input.
"""
_agent_name:ClassVar[str] = "editor"
class Fields:
force = PropertyField(
name="force",
description="Force clean up",
type="bool",
default=False,
)
def __init__(self, title="Clean Up User Input", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("user_input", socket_type="str")
self.add_input("as_narration", socket_type="bool")
self.set_property("force", False)
self.add_output("cleaned_user_input", socket_type="str")
async def run(self, state: GraphState):
editor:"EditorAgent" = self.agent
user_input = self.get_input_value("user_input")
force = self.get_property("force")
as_narration = self.get_input_value("as_narration")
cleaned_user_input = await editor.cleanup_user_input(user_input, as_narration=as_narration, force=force)
self.set_output_values({
"cleaned_user_input": cleaned_user_input,
})
@register("agents/editor/CleanUpNarration")
class CleanUpNarration(AgentNode):
"""
Cleans up narration.
"""
_agent_name:ClassVar[str] = "editor"
_agent_name: ClassVar[str] = "editor"
class Fields:
force = PropertyField(
@@ -73,31 +40,41 @@ class CleanUpNarration(AgentNode):
type="bool",
default=False,
)
def __init__(self, title="Clean Up Narration", **kwargs):
def __init__(self, title="Clean Up User Input", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("narration", socket_type="str")
self.add_input("user_input", socket_type="str")
self.add_input("as_narration", socket_type="bool")
self.set_property("force", False)
self.add_output("cleaned_narration", socket_type="str")
self.add_output("cleaned_user_input", socket_type="str")
async def run(self, state: GraphState):
editor:"EditorAgent" = self.agent
narration = self.get_input_value("narration")
editor: "EditorAgent" = self.agent
user_input = self.get_input_value("user_input")
force = self.get_property("force")
cleaned_narration = await editor.cleanup_narration(narration, force=force)
self.set_output_values({
"cleaned_narration": cleaned_narration,
})
@register("agents/editor/CleanUoCharacterMessage")
class CleanUpCharacterMessage(AgentNode):
as_narration = self.get_input_value("as_narration")
cleaned_user_input = await editor.cleanup_user_input(
user_input, as_narration=as_narration, force=force
)
self.set_output_values(
{
"cleaned_user_input": cleaned_user_input,
}
)
@register("agents/editor/CleanUpNarration")
class CleanUpNarration(AgentNode):
"""
Cleans up character message.
Cleans up narration.
"""
_agent_name:ClassVar[str] = "editor"
_agent_name: ClassVar[str] = "editor"
class Fields:
force = PropertyField(
name="force",
@@ -105,22 +82,62 @@ class CleanUpCharacterMessage(AgentNode):
type="bool",
default=False,
)
def __init__(self, title="Clean Up Narration", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("narration", socket_type="str")
self.set_property("force", False)
self.add_output("cleaned_narration", socket_type="str")
async def run(self, state: GraphState):
editor: "EditorAgent" = self.agent
narration = self.get_input_value("narration")
force = self.get_property("force")
cleaned_narration = await editor.cleanup_narration(narration, force=force)
self.set_output_values(
{
"cleaned_narration": cleaned_narration,
}
)
@register("agents/editor/CleanUoCharacterMessage")
class CleanUpCharacterMessage(AgentNode):
"""
Cleans up character message.
"""
_agent_name: ClassVar[str] = "editor"
class Fields:
force = PropertyField(
name="force",
description="Force clean up",
type="bool",
default=False,
)
def __init__(self, title="Clean Up Character Message", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("text", socket_type="str")
self.add_input("character", socket_type="character")
self.set_property("force", False)
self.add_output("cleaned_character_message", socket_type="str")
async def run(self, state: GraphState):
editor:"EditorAgent" = self.agent
editor: "EditorAgent" = self.agent
text = self.get_input_value("text")
force = self.get_property("force")
character = self.get_input_value("character")
cleaned_character_message = await editor.cleanup_character_message(text, character, force=force)
self.set_output_values({
"cleaned_character_message": cleaned_character_message,
})
cleaned_character_message = await editor.cleanup_character_message(
text, character, force=force
)
self.set_output_values(
{
"cleaned_character_message": cleaned_character_message,
}
)

File diff suppressed because it is too large

View File

@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING
from talemate.instance import get_agent
from talemate.server.websocket_plugin import Plugin
from talemate.status import set_loading
from talemate.scene_message import CharacterMessage
from talemate.agents.editor.revision import RevisionContext, RevisionInformation
@@ -17,47 +16,49 @@ __all__ = [
log = structlog.get_logger("talemate.server.editor")
class RevisionPayload(pydantic.BaseModel):
message_id: int
class EditorWebsocketHandler(Plugin):
"""
Handles editor actions
"""
router = "editor"
@property
def editor(self):
return get_agent("editor")
async def handle_request_revision(self, data: dict):
"""
Generate clickable actions for the user
"""
editor = self.editor
scene:"Scene" = self.scene
scene: "Scene" = self.scene
if not editor.revision_enabled:
raise Exception("Revision is not enabled")
payload = RevisionPayload(**data)
message = scene.get_message(payload.message_id)
character = None
if isinstance(message, CharacterMessage):
character = scene.get_character(message.character_name)
if not message:
raise Exception("Message not found")
with RevisionContext(message.id):
info = RevisionInformation(
text=message.message,
character=character,
)
revised = await editor.revision_revise(info)
scene.edit_message(message.id, revised)
scene.edit_message(message.id, revised)

View File

@@ -75,10 +75,9 @@ class MemoryAgent(Agent):
@classmethod
def init_actions(cls, presets: list[dict] | None = None) -> dict[str, AgentAction]:
if presets is None:
presets = []
actions = {
"_config": AgentAction(
enabled=True,
@@ -101,23 +100,25 @@ class MemoryAgent(Agent):
choices=[
{"value": "cpu", "label": "CPU"},
{"value": "cuda", "label": "CUDA"},
]
],
),
},
),
}
return actions
def __init__(self, scene, **kwargs):
self.db = None
self.scene = scene
self.memory_tracker = {}
self.config = load_config()
self._ready_to_add = False
handlers["config_saved"].connect(self.on_config_saved)
async_signals.get("client.embeddings_available").connect(self.on_client_embeddings_available)
async_signals.get("client.embeddings_available").connect(
self.on_client_embeddings_available
)
self.actions = MemoryAgent.init_actions(presets=self.get_presets)
@property
@@ -132,30 +133,32 @@ class MemoryAgent(Agent):
@property
def db_name(self):
raise NotImplementedError()
@property
def get_presets(self):
def _label(embedding:dict):
prefix = embedding['client'] if embedding['client'] else embedding['embeddings']
if embedding['model']:
def _label(embedding: dict):
prefix = (
embedding["client"] if embedding["client"] else embedding["embeddings"]
)
if embedding["model"]:
return f"{prefix}: {embedding['model']}"
else:
return f"{prefix}"
return [
{"value": k, "label": _label(v)} for k,v in self.config.get("presets", {}).get("embeddings", {}).items()
{"value": k, "label": _label(v)}
for k, v in self.config.get("presets", {}).get("embeddings", {}).items()
]
@property
def embeddings_config(self):
_embeddings = self.actions["_config"].config["embeddings"].value
return self.config.get("presets", {}).get("embeddings", {}).get(_embeddings, {})
@property
def embeddings(self):
return self.embeddings_config.get("embeddings", "sentence-transformer")
@property
def using_openai_embeddings(self):
return self.embeddings == "openai"
@@ -163,39 +166,34 @@ class MemoryAgent(Agent):
@property
def using_instructor_embeddings(self):
return self.embeddings == "instructor"
@property
def using_sentence_transformer_embeddings(self):
return self.embeddings == "default" or self.embeddings == "sentence-transformer"
@property
def using_client_api_embeddings(self):
return self.embeddings == "client-api"
@property
def using_local_embeddings(self):
return self.embeddings in [
"instructor",
"sentence-transformer",
"default"
]
return self.embeddings in ["instructor", "sentence-transformer", "default"]
@property
def embeddings_client(self):
return self.embeddings_config.get("client")
@property
def max_distance(self) -> float:
distance = float(self.embeddings_config.get("distance", 1.0))
distance_mod = float(self.embeddings_config.get("distance_mod", 1.0))
return distance * distance_mod
@property
def model(self):
return self.embeddings_config.get("model")
@property
def distance_function(self):
return self.embeddings_config.get("distance_function", "l2")
@@ -213,25 +211,28 @@ class MemoryAgent(Agent):
"""
Returns a unique fingerprint for the current configuration
"""
model_name = self.model.replace('/','-') if self.model else "none"
return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()
model_name = self.model.replace("/", "-") if self.model else "none"
return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()
async def apply_config(self, *args, **kwargs):
_fingerprint = self.fingerprint
await super().apply_config(*args, **kwargs)
fingerprint_changed = _fingerprint != self.fingerprint
# have embeddings or device changed?
if fingerprint_changed:
log.warning("memory agent", status="embedding function changed", old=_fingerprint, new=self.fingerprint)
log.warning(
"memory agent",
status="embedding function changed",
old=_fingerprint,
new=self.fingerprint,
)
await self.handle_embeddings_change()
@set_processing
async def handle_embeddings_change(self):
scene = active_scene.get()
@@ -239,10 +240,10 @@ class MemoryAgent(Agent):
# if sentence-transformer and no model-name, set embeddings to default
if self.using_sentence_transformer_embeddings and not self.model:
self.actions["_config"].config["embeddings"].value = "default"
if not scene or not scene.get_helper("memory"):
return
self.close_db(scene)
emit("status", "Re-importing context database", status="busy")
await scene.commit_to_memory()
@@ -257,38 +258,49 @@ class MemoryAgent(Agent):
def on_config_saved(self, event):
loop = asyncio.get_running_loop()
openai_key = self.openai_api_key
fingerprint = self.fingerprint
old_presets = self.actions["_config"].config["embeddings"].choices.copy()
self.config = load_config()
new_presets = self.sync_presets()
if fingerprint != self.fingerprint:
log.warning("memory agent", status="embedding function changed", old=fingerprint, new=self.fingerprint)
log.warning(
"memory agent",
status="embedding function changed",
old=fingerprint,
new=self.fingerprint,
)
loop.run_until_complete(self.handle_embeddings_change())
emit_status = False
if openai_key != self.openai_api_key:
emit_status = True
if old_presets != new_presets:
emit_status = True
if emit_status:
loop.run_until_complete(self.emit_status())
async def on_client_embeddings_available(self, event: "ClientEmbeddingsStatus"):
current_embeddings = self.actions["_config"].config["embeddings"].value
if current_embeddings == event.client.embeddings_identifier:
return
if not self.using_client_api_embeddings or not self.ready:
log.warning("memory agent - client embeddings available", status="changing embeddings", old=current_embeddings, new=event.client.embeddings_identifier)
self.actions["_config"].config["embeddings"].value = event.client.embeddings_identifier
log.warning(
"memory agent - client embeddings available",
status="changing embeddings",
old=current_embeddings,
new=event.client.embeddings_identifier,
)
self.actions["_config"].config[
"embeddings"
].value = event.client.embeddings_identifier
await self.emit_status()
await self.handle_embeddings_change()
await self.save_config()
@@ -301,13 +313,16 @@ class MemoryAgent(Agent):
except EmbeddingsModelLoadError:
raise
except Exception as e:
log.error("memory agent", error="failed to set db", details=traceback.format_exc())
if "torchvision::nms does not exist" in str(e):
raise SetDBError("The embeddings you are trying to use require the `torchvision` package to be installed")
raise SetDBError(str(e))
log.error(
"memory agent", error="failed to set db", details=traceback.format_exc()
)
if "torchvision::nms does not exist" in str(e):
raise SetDBError(
"The embeddings you are trying to use require the `torchvision` package to be installed"
)
raise SetDBError(str(e))
def close_db(self):
raise NotImplementedError()
@@ -426,9 +441,9 @@ class MemoryAgent(Agent):
with MemoryRequest(query=text, query_params=query) as active_memory_request:
active_memory_request.max_distance = self.max_distance
return await asyncio.to_thread(self._get, text, character, **query)
#return await loop.run_in_executor(
# return await loop.run_in_executor(
# None, functools.partial(self._get, text, character, **query)
#)
# )
def _get(self, text, character=None, **query):
raise NotImplementedError()
@@ -528,9 +543,7 @@ class MemoryAgent(Agent):
continue
# Fetch potential memories for this query.
raw_results = await self.get(
formatter(query), limit=limit, **where
)
raw_results = await self.get(formatter(query), limit=limit, **where)
# Apply filter and respect the `iterate` limit for this query.
accepted: list[str] = []
@@ -591,9 +604,9 @@ class MemoryAgent(Agent):
Returns a dictionary with 'cosine_similarity' and 'euclidean_distance'.
"""
embed_fn = self.embedding_function
# Embed the two strings
vec1 = np.array(embed_fn([string1])[0])
vec2 = np.array(embed_fn([string2])[0])
@@ -604,17 +617,14 @@ class MemoryAgent(Agent):
# Compute Euclidean distance
euclidean_dist = np.linalg.norm(vec1 - vec2)
return {
"cosine_similarity": cosine_sim,
"euclidean_distance": euclidean_dist
}
return {"cosine_similarity": cosine_sim, "euclidean_distance": euclidean_dist}
async def compare_string_lists(
self,
list_a: list[str],
list_b: list[str],
similarity_threshold: float = None,
distance_threshold: float = None
distance_threshold: float = None,
) -> dict:
"""
Compare two lists of strings using the current embedding function without touching the database.
@@ -625,15 +635,21 @@ class MemoryAgent(Agent):
- 'similarity_matches': list of (i, j, score) (filtered if threshold set, otherwise all)
- 'distance_matches': list of (i, j, distance) (filtered if threshold set, otherwise all)
"""
if not self.db or not hasattr(self.db, "_embedding_function") or self.db._embedding_function is None:
raise RuntimeError("Embedding function is not initialized. Make sure the database is set.")
if (
not self.db
or not hasattr(self.db, "_embedding_function")
or self.db._embedding_function is None
):
raise RuntimeError(
"Embedding function is not initialized. Make sure the database is set."
)
if not list_a or not list_b:
return {
"cosine_similarity_matrix": np.array([]),
"euclidean_distance_matrix": np.array([]),
"similarity_matches": [],
"distance_matches": []
"distance_matches": [],
}
embed_fn = self.db._embedding_function
@@ -653,21 +669,29 @@ class MemoryAgent(Agent):
cosine_similarity_matrix = np.dot(vecs_a_norm, vecs_b_norm.T)
# Euclidean distance matrix
a_squared = np.sum(vecs_a ** 2, axis=1).reshape(-1, 1)
b_squared = np.sum(vecs_b ** 2, axis=1).reshape(1, -1)
euclidean_distance_matrix = np.sqrt(a_squared + b_squared - 2 * np.dot(vecs_a, vecs_b.T))
a_squared = np.sum(vecs_a**2, axis=1).reshape(-1, 1)
b_squared = np.sum(vecs_b**2, axis=1).reshape(1, -1)
euclidean_distance_matrix = np.sqrt(
a_squared + b_squared - 2 * np.dot(vecs_a, vecs_b.T)
)
# Prepare matches
similarity_matches = []
distance_matches = []
# Populate similarity matches
sim_indices = np.argwhere(cosine_similarity_matrix >= (similarity_threshold if similarity_threshold is not None else -np.inf))
sim_indices = np.argwhere(
cosine_similarity_matrix
>= (similarity_threshold if similarity_threshold is not None else -np.inf)
)
for i, j in sim_indices:
similarity_matches.append((i, j, cosine_similarity_matrix[i, j]))
# Populate distance matches
dist_indices = np.argwhere(euclidean_distance_matrix <= (distance_threshold if distance_threshold is not None else np.inf))
dist_indices = np.argwhere(
euclidean_distance_matrix
<= (distance_threshold if distance_threshold is not None else np.inf)
)
for i, j in dist_indices:
distance_matches.append((i, j, euclidean_distance_matrix[i, j]))
@@ -675,9 +699,10 @@ class MemoryAgent(Agent):
"cosine_similarity_matrix": cosine_similarity_matrix,
"euclidean_distance_matrix": euclidean_distance_matrix,
"similarity_matches": similarity_matches,
"distance_matches": distance_matches
"distance_matches": distance_matches,
}
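compare_string_lists vectorizes the same two measures: cosine similarity is a normalized dot product, and the Euclidean matrix uses the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, exactly as in the diff. A small worked instance with made-up vectors:

import numpy as np

vecs_a = np.array([[1.0, 0.0], [0.0, 1.0]])
vecs_b = np.array([[1.0, 1.0]])

a_norm = vecs_a / np.linalg.norm(vecs_a, axis=1, keepdims=True)
b_norm = vecs_b / np.linalg.norm(vecs_b, axis=1, keepdims=True)
cosine = a_norm @ b_norm.T                      # shape (2, 1), both entries ~0.707

a_sq = np.sum(vecs_a**2, axis=1).reshape(-1, 1)
b_sq = np.sum(vecs_b**2, axis=1).reshape(1, -1)
distance = np.sqrt(a_sq + b_sq - 2 * (vecs_a @ vecs_b.T))  # both entries 1.0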
@register(condition=lambda: chromadb is not None)
class ChromaDBMemoryAgent(MemoryAgent):
requires_llm_client = False
@@ -690,32 +715,34 @@ class ChromaDBMemoryAgent(MemoryAgent):
if getattr(self, "db_client", None):
return True
return False
@property
def client_api_ready(self) -> bool:
if self.using_client_api_embeddings:
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
embeddings_client: ClientBase | None = instance.get_client(
self.embeddings_client
)
if not embeddings_client:
return False
if not embeddings_client.supports_embeddings:
return False
if not embeddings_client.embeddings_status:
return False
if embeddings_client.current_status not in ["idle", "busy"]:
return False
return True
return False
@property
def status(self):
if self.using_client_api_embeddings and not self.client_api_ready:
return "error"
if self.ready:
return "active" if not getattr(self, "processing", False) else "busy"
@@ -726,7 +753,6 @@ class ChromaDBMemoryAgent(MemoryAgent):
@property
def agent_details(self):
details = {
"backend": AgentDetail(
icon="mdi-server-outline",
@@ -738,23 +764,22 @@ class ChromaDBMemoryAgent(MemoryAgent):
value=self.embeddings,
description="The embeddings type.",
).model_dump(),
}
if self.model:
details["model"] = AgentDetail(
icon="mdi-brain",
value=self.model,
description="The embeddings model.",
).model_dump()
if self.embeddings_client:
details["client"] = AgentDetail(
icon="mdi-network-outline",
value=self.embeddings_client,
description="The client to use for embeddings.",
).model_dump()
if self.using_local_embeddings:
details["device"] = AgentDetail(
icon="mdi-memory",
@@ -770,9 +795,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
"description": "You must provide an OpenAI API key to use OpenAI embeddings",
"color": "error",
}
if self.using_client_api_embeddings:
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
embeddings_client: ClientBase | None = instance.get_client(
self.embeddings_client
)
if not embeddings_client:
details["error"] = {
@@ -784,7 +811,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
return details
client_name = embeddings_client.name
if not embeddings_client.supports_embeddings:
error_message = f"{client_name} does not support embeddings"
elif embeddings_client.current_status not in ["idle", "busy"]:
@@ -793,7 +820,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
error_message = f"{client_name} has no embeddings model loaded"
else:
error_message = None
if error_message:
details["error"] = {
"icon": "mdi-alert",
@@ -814,8 +841,14 @@ class ChromaDBMemoryAgent(MemoryAgent):
@property
def embedding_function(self) -> Callable:
if not self.db or not hasattr(self.db, "_embedding_function") or self.db._embedding_function is None:
raise RuntimeError("Embedding function is not initialized. Make sure the database is set.")
if (
not self.db
or not hasattr(self.db, "_embedding_function")
or self.db._embedding_function is None
):
raise RuntimeError(
"Embedding function is not initialized. Make sure the database is set."
)
embed_fn = self.db._embedding_function
return embed_fn
@@ -823,18 +856,18 @@ class ChromaDBMemoryAgent(MemoryAgent):
def make_collection_name(self, scene) -> str:
# generate plain text collection name
collection_name = f"{self.fingerprint}"
# chromadb collection names have the following rules:
# Expected collection name that (1) contains 3-63 characters, (2) starts and ends with an alphanumeric character, (3) otherwise contains only alphanumeric characters, underscores or hyphens (-), (4) contains no two consecutive periods (..) and (5) is not a valid IPv4 address
# Step 1: Hash the input string using MD5
md5_hash = hashlib.md5(collection_name.encode()).hexdigest()
# Step 2: Ensure the result is exactly 32 characters long
hashed_collection_name = md5_hash[:32]
return f"{scene.memory_id}-tm-{hashed_collection_name}"
async def count(self):
await asyncio.sleep(0)
return self.db.count()
@@ -853,14 +886,17 @@ class ChromaDBMemoryAgent(MemoryAgent):
self.collection_name = collection_name = self.make_collection_name(self.scene)
log.info(
"chromadb agent", status="setting up db", collection_name=collection_name, embeddings=self.embeddings
"chromadb agent",
status="setting up db",
collection_name=collection_name,
embeddings=self.embeddings,
)
distance_function = self.distance_function
collection_metadata = {"hnsw:space": distance_function}
device = self.actions["_config"].config["device"].value
device = self.actions["_config"].config["device"].value
model_name = self.model
if self.using_openai_embeddings:
if not openai_key:
raise ValueError(
@@ -878,7 +914,9 @@ class ChromaDBMemoryAgent(MemoryAgent):
model_name=model_name,
)
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=openai_ef, metadata=collection_metadata
collection_name,
embedding_function=openai_ef,
metadata=collection_metadata,
)
elif self.using_client_api_embeddings:
log.info(
@@ -886,20 +924,26 @@ class ChromaDBMemoryAgent(MemoryAgent):
embeddings="Client API",
client=self.embeddings_client,
)
embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
embeddings_client: ClientBase | None = instance.get_client(
self.embeddings_client
)
if not embeddings_client:
raise ValueError(f"Client API embeddings client {self.embeddings_client} not found")
raise ValueError(
f"Client API embeddings client {self.embeddings_client} not found"
)
if not embeddings_client.supports_embeddings:
raise ValueError(f"Client API embeddings client {self.embeddings_client} does not support embeddings")
raise ValueError(
f"Client API embeddings client {self.embeddings_client} does not support embeddings"
)
ef = embeddings_client.embeddings_function
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=ef, metadata=collection_metadata
)
elif self.using_instructor_embeddings:
log.info(
"chromadb",
@@ -909,7 +953,9 @@ class ChromaDBMemoryAgent(MemoryAgent):
)
ef = embedding_functions.InstructorEmbeddingFunction(
model_name=model_name, device=device, instruction="Represent the document for retrieval:"
model_name=model_name,
device=device,
instruction="Represent the document for retrieval:",
)
log.info("chromadb", status="embedding function ready")
@@ -919,25 +965,26 @@ class ChromaDBMemoryAgent(MemoryAgent):
log.info("chromadb", status="instructor db ready")
else:
log.info(
"chromadb",
embeddings="SentenceTransformer",
"chromadb",
embeddings="SentenceTransformer",
model=model_name,
device=device,
distance_function=distance_function
distance_function=distance_function,
)
try:
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=model_name,
trust_remote_code=self.trust_remote_code,
device=device
device=device,
)
except ValueError as e:
if "`trust_remote_code=True` to remove this error" in str(e):
raise EmbeddingsModelLoadError(model_name, "Model requires `Trust remote code` to be enabled")
raise EmbeddingsModelLoadError(
model_name, "Model requires `Trust remote code` to be enabled"
)
raise EmbeddingsModelLoadError(model_name, str(e))
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=ef, metadata=collection_metadata
)
@@ -989,9 +1036,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
try:
self.db_client.delete_collection(collection_name)
except chromadb.errors.NotFoundError as exc:
log.error(
"chromadb agent", error="collection not found", details=exc
)
log.error("chromadb agent", error="collection not found", details=exc)
except ValueError as exc:
log.error(
"chromadb agent", error="failed to delete collection", details=exc
@@ -1049,16 +1094,18 @@ class ChromaDBMemoryAgent(MemoryAgent):
if not objects:
return
# track seen documents by id
seen_ids = set()
for obj in objects:
if obj["id"] in seen_ids:
log.warning("chromadb agent", status="duplicate id discarded", id=obj["id"])
log.warning(
"chromadb agent", status="duplicate id discarded", id=obj["id"]
)
continue
seen_ids.add(obj["id"])
documents.append(obj["text"])
meta = obj.get("meta", {})
source = meta.get("source", "talemate")
@@ -1071,7 +1118,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
metadatas.append(meta)
uid = obj.get("id", f"{character}-{self.memory_tracker[character]}")
ids.append(uid)
self.db.upsert(documents=documents, metadatas=metadatas, ids=ids)
def _delete(self, meta: dict):
@@ -1081,11 +1128,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
return
where = {"$and": [{k: v} for k, v in meta.items()]}
# if there is only one item in $and reduce it to the key value pair
if len(where["$and"]) == 1:
where = where["$and"][0]
self.db.delete(where=where)
log.debug("chromadb agent delete", meta=meta, where=where)
@@ -1122,16 +1169,16 @@ class ChromaDBMemoryAgent(MemoryAgent):
log.error("chromadb agent", error="failed to query", details=e)
return []
#import json
#print(json.dumps(_results["ids"], indent=2))
#print(json.dumps(_results["distances"], indent=2))
# import json
# print(json.dumps(_results["ids"], indent=2))
# print(json.dumps(_results["distances"], indent=2))
results = []
max_distance = self.max_distance
closest = None
active_memory_request = memory_request.get()
for i in range(len(_results["distances"][0])):
@@ -1139,7 +1186,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
doc = _results["documents"][0][i]
meta = _results["metadatas"][0][i]
active_memory_request.add_result(doc, distance, meta)
if not meta:
@@ -1150,7 +1197,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
# skip pin_only entries
if meta.get("pin_only", False):
continue
if closest is None:
closest = {"distance": distance, "doc": doc}
elif distance < closest["distance"]:
@@ -1172,13 +1219,13 @@ class ChromaDBMemoryAgent(MemoryAgent):
if len(results) > limit:
break
log.debug("chromadb agent get", closest=closest, max_distance=max_distance)
self.last_query = {
"query": text,
"closest": closest,
}
return results
def convert_ts_to_date_prefix(self, ts):
@@ -1197,10 +1244,9 @@ class ChromaDBMemoryAgent(MemoryAgent):
return None
def _get_document(self, id) -> dict:
if not id:
return {}
result = self.db.get(ids=[id] if isinstance(id, str) else id)
documents = {}

View File

@@ -1,5 +1,5 @@
"""
Context manager that collects and tracks memory agent requests
for profiling and debugging purposes
"""
@@ -11,91 +11,118 @@ import time
from talemate.emit import emit
from talemate.agents.context import active_agent
__all__ = ["MemoryRequest"]
log = structlog.get_logger()
DEBUG_MEMORY_REQUESTS = False
class MemoryRequestResult(pydantic.BaseModel):
doc: str
distance: float
meta: dict = pydantic.Field(default_factory=dict)
class MemoryRequestState(pydantic.BaseModel):
query: str
results: list[MemoryRequestResult] = pydantic.Field(default_factory=list)
accepted_results: list[MemoryRequestResult] = pydantic.Field(default_factory=list)
query_params: dict = pydantic.Field(default_factory=dict)
closest_distance: float | None = None
furthest_distance: float | None = None
max_distance: float | None = None
def add_result(self, doc:str, distance:float, meta:dict):
def add_result(self, doc: str, distance: float, meta: dict):
if doc is None:
return
self.results.append(MemoryRequestResult(doc=doc, distance=distance, meta=meta))
self.closest_distance = (
min(self.closest_distance, distance)
if self.closest_distance is not None
else distance
)
self.furthest_distance = (
max(self.furthest_distance, distance)
if self.furthest_distance is not None
else distance
)
def accept_result(self, doc: str, distance: float, meta: dict):
if doc is None:
return
self.accepted_results.append(
MemoryRequestResult(doc=doc, distance=distance, meta=meta)
)
@property
def closest_text(self):
return str(self.results[0].doc) if self.results else None
memory_request = contextvars.ContextVar("memory_request", default=None)
class MemoryRequest:
def __init__(self, query: str, query_params: dict = None):
self.query = query
self.query_params = query_params
def __enter__(self):
self.state = MemoryRequestState(
query=self.query, query_params=self.query_params
)
self.token = memory_request.set(self.state)
self.time_start = time.time()
return self.state
def __exit__(self, *args):
self.time_end = time.time()
if DEBUG_MEMORY_REQUESTS:
max_length = 50
query = self.state.query[:max_length]+"..." if len(self.state.query) > max_length else self.state.query
log.debug("MemoryRequest", number_of_results=len(self.state.results), query=query)
log.debug("MemoryRequest", number_of_accepted_results=len(self.state.accepted_results), query=query)
query = (
self.state.query[:max_length] + "..."
if len(self.state.query) > max_length
else self.state.query
)
log.debug(
"MemoryRequest", number_of_results=len(self.state.results), query=query
)
log.debug(
"MemoryRequest",
number_of_accepted_results=len(self.state.accepted_results),
query=query,
)
for result in self.state.results:
# distance to 2 decimal places
log.debug("MemoryRequest RESULT", distance=f"{result.distance:.2f}", doc=result.doc[:max_length]+"...")
log.debug(
"MemoryRequest RESULT",
distance=f"{result.distance:.2f}",
doc=result.doc[:max_length] + "...",
)
agent_context = active_agent.get()
emit("memory_request", data=self.state.model_dump(), meta={
"agent_stack": agent_context.agent_stack if agent_context else [],
"agent_stack_uid": agent_context.agent_stack_uid if agent_context else None,
"duration": self.time_end - self.time_start,
}, websocket_passthrough=True)
emit(
"memory_request",
data=self.state.model_dump(),
meta={
"agent_stack": agent_context.agent_stack if agent_context else [],
"agent_stack_uid": agent_context.agent_stack_uid
if agent_context
else None,
"duration": self.time_end - self.time_start,
},
websocket_passthrough=True,
)
memory_request.reset(self.token)
return False
# decorator that opens a memory request context
async def start_memory_request(query):
@@ -103,5 +130,7 @@ async def start_memory_request(query):
async def wrapper(*args, **kwargs):
with MemoryRequest(query):
return await fn(*args, **kwargs)
return wrapper
return decorator
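For orientation, a minimal usage sketch of the MemoryRequest context manager above (hypothetical query, params, and metadata, not taken from this commit; query_params is passed explicitly because MemoryRequestState expects a dict):
with MemoryRequest("What does Alice know about the vault?", query_params={"limit": 10}) as state:
    # a retrieval backend would report each candidate and each accepted hit
    state.add_result("Alice found the vault key.", distance=0.31, meta={"source": "talemate"})
    state.accept_result("Alice found the vault key.", distance=0.31, meta={"source": "talemate"})
state.closest_text      # "Alice found the vault key."
state.closest_distance  # 0.31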

View File

@@ -1,18 +1,17 @@
__all__ = ["EmbeddingsModelLoadError", "MemoryAgentError", "SetDBError"]
class MemoryAgentError(Exception):
pass
class SetDBError(OSError, MemoryAgentError):
def __init__(self, details: str):
super().__init__(f"Memory Agent - Failed to set up the database: {details}")
class EmbeddingsModelLoadError(ValueError, MemoryAgentError):
def __init__(self, model_name: str, details: str):
super().__init__(
f"Memory Agent - Failed to load embeddings model {model_name}: {details}"
)

View File

@@ -4,7 +4,6 @@ from talemate.agents.base import (
AgentAction,
AgentActionConfig,
)
from talemate.emit import emit
import talemate.instance as instance
if TYPE_CHECKING:
@@ -14,11 +13,10 @@ __all__ = ["MemoryRAGMixin"]
log = structlog.get_logger()
class MemoryRAGMixin:
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["use_long_term_memory"] = AgentAction(
enabled=True,
container=True,
@@ -44,7 +42,7 @@ class MemoryRAGMixin:
{
"label": "AI compiled question and answers (slow)",
"value": "questions",
},
],
),
"number_of_queries": AgentActionConfig(
@@ -65,7 +63,7 @@ class MemoryRAGMixin:
{"label": "Short (256)", "value": "256"},
{"label": "Medium (512)", "value": "512"},
{"label": "Long (1024)", "value": "1024"},
],
),
"cache": AgentActionConfig(
type="bool",
@@ -73,16 +71,16 @@ class MemoryRAGMixin:
description="Cache the long term memory for faster retrieval.",
note="This is a cross-agent cache, assuming they use the same options.",
value=True,
),
},
)
# config property helpers
@property
def long_term_memory_enabled(self):
return self.actions["use_long_term_memory"].enabled
@property
def long_term_memory_retrieval_method(self):
return self.actions["use_long_term_memory"].config["retrieval_method"].value
@@ -90,60 +88,60 @@ class MemoryRAGMixin:
@property
def long_term_memory_number_of_queries(self):
return self.actions["use_long_term_memory"].config["number_of_queries"].value
@property
def long_term_memory_answer_length(self):
return int(self.actions["use_long_term_memory"].config["answer_length"].value)
@property
def long_term_memory_cache(self):
return self.actions["use_long_term_memory"].config["cache"].value
@property
def long_term_memory_cache_key(self):
"""
Build the key from the various options
"""
parts = [
self.long_term_memory_retrieval_method,
self.long_term_memory_number_of_queries,
self.long_term_memory_answer_length,
]
return "-".join(map(str, parts))
def connect(self, scene):
super().connect(scene)
# new scene, reset cache
scene.rag_cache = {}
# methods
async def rag_set_cache(self, content: list[str]):
self.scene.rag_cache[self.long_term_memory_cache_key] = {
"content": content,
"fingerprint": self.scene.history[-1].fingerprint if self.scene.history else 0
"fingerprint": self.scene.history[-1].fingerprint
if self.scene.history
else 0,
}
async def rag_get_cache(self) -> list[str] | None:
if not self.long_term_memory_cache:
return None
fingerprint = self.scene.history[-1].fingerprint if self.scene.history else 0
cache = self.scene.rag_cache.get(self.long_term_memory_cache_key)
if cache and cache["fingerprint"] == fingerprint:
return cache["content"]
return None
async def rag_build(
self,
character: "Character | None" = None,
prompt: str = "",
sub_instruction: str = "",
) -> list[str]:
@@ -153,37 +151,41 @@ class MemoryRAGMixin:
if not self.long_term_memory_enabled:
return []
cached = await self.rag_get_cache()
if cached:
log.debug(f"Using cached long term memory", agent=self.agent_type, key=self.long_term_memory_cache_key)
log.debug(
"Using cached long term memory",
agent=self.agent_type,
key=self.long_term_memory_cache_key,
)
return cached
memory_context = ""
retrieval_method = self.long_term_memory_retrieval_method
if not sub_instruction:
if character:
sub_instruction = f"continue the scene as {character.name}"
elif hasattr(self, "rag_build_sub_instruction"):
sub_instruction = await self.rag_build_sub_instruction()
if not sub_instruction:
sub_instruction = "continue the scene"
if retrieval_method != "direct":
world_state = instance.get_agent("world_state")
if not prompt:
prompt = self.scene.context_history(
keep_director=False,
budget=int(self.client.max_token_length * 0.75),
)
if isinstance(prompt, list):
prompt = "\n".join(prompt)
log.debug(
"memory_rag_mixin.build_prompt_default_memory",
direct=False,
@@ -193,20 +195,21 @@ class MemoryRAGMixin:
if retrieval_method == "questions":
memory_context = (
await world_state.analyze_text_and_extract_context(
prompt,
sub_instruction,
include_character_context=True,
response_length=self.long_term_memory_answer_length,
num_queries=self.long_term_memory_number_of_queries,
)
).split("\n")
elif retrieval_method == "queries":
memory_context = (
await world_state.analyze_text_and_extract_context_via_queries(
prompt,
sub_instruction,
include_character_context=True,
response_length=self.long_term_memory_answer_length,
num_queries=self.long_term_memory_number_of_queries,
)
)
@@ -223,4 +226,4 @@ class MemoryRAGMixin:
await self.rag_set_cache(memory_context)
return memory_context

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import dataclasses
import random
from functools import wraps
from inspect import signature
from typing import TYPE_CHECKING
import structlog
@@ -50,15 +49,18 @@ log = structlog.get_logger("talemate.agents.narrator")
class NarratorAgentEmission(AgentEmission):
generation: list[str] = dataclasses.field(default_factory=list)
response: str = dataclasses.field(default="")
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
default_factory=list
)
talemate.emit.async_signals.register(
"agent.narrator.before_generate",
"agent.narrator.before_generate",
"agent.narrator.inject_instructions",
"agent.narrator.generated",
)
def set_processing(fn):
"""
Custom decorator that emits the agent status as processing while the function
@@ -74,11 +76,15 @@ def set_processing(fn):
if self.content_use_writing_style:
self.set_context_states(writing_style=self.scene.writing_style)
await talemate.emit.async_signals.get("agent.narrator.before_generate").send(emission)
await talemate.emit.async_signals.get("agent.narrator.inject_instructions").send(emission)
await talemate.emit.async_signals.get("agent.narrator.before_generate").send(
emission
)
await talemate.emit.async_signals.get(
"agent.narrator.inject_instructions"
).send(emission)
agent_context.state["dynamic_instructions"] = emission.dynamic_instructions
response = await fn(self, *args, **kwargs)
emission.response = response
await talemate.emit.async_signals.get("agent.narrator.generated").send(emission)
@@ -88,10 +94,7 @@ def set_processing(fn):
@register()
class NarratorAgent(MemoryRAGMixin, Agent):
"""
Handles narration of the story
"""
@@ -99,7 +102,7 @@ class NarratorAgent(
agent_type = "narrator"
verbose_name = "Narrator"
set_processing = set_processing
websocket_handler = NarratorWebsocketHandler
@classmethod
@@ -117,7 +120,7 @@ class NarratorAgent(
min=32,
max=1024,
step=32,
),
"instructions": AgentActionConfig(
type="text",
label="Instructions",
@@ -154,7 +157,7 @@ class NarratorAgent(
description="Use the writing style selected in the scene settings",
value=True,
),
},
),
"narrate_time_passage": AgentAction(
enabled=True,
@@ -201,7 +204,7 @@ class NarratorAgent(
},
),
}
MemoryRAGMixin.add_actions(actions)
return actions
@@ -232,28 +235,33 @@ class NarratorAgent(
if self.actions["generation_override"].enabled:
return self.actions["generation_override"].config["length"].value
return 128
@property
def narrate_time_passage_enabled(self) -> bool:
return self.actions["narrate_time_passage"].enabled
@property
def narrate_dialogue_enabled(self) -> bool:
return self.actions["narrate_dialogue"].enabled
@property
def narrate_dialogue_ai_chance(self) -> float:
return self.actions["narrate_dialogue"].config["ai_dialog"].value
@property
def narrate_dialogue_player_chance(self) -> float:
return self.actions["narrate_dialogue"].config["player_dialog"].value
@property
def content_use_writing_style(self) -> bool:
return self.actions["content"].config["use_writing_style"].value
def clean_result(
self,
result: str,
ensure_dialog_format: bool = True,
force_narrative: bool = True,
) -> str:
"""
Cleans the result of a narration
"""
@@ -264,15 +272,14 @@ class NarratorAgent(
cleaned = []
for line in result.split("\n"):
# skip lines that start with a #
if line.startswith("#"):
continue
log.debug("clean_result", line=line)
character_dialogue_detected = False
for character_name in character_names:
if not character_name:
continue
@@ -280,25 +287,24 @@ class NarratorAgent(
character_dialogue_detected = True
elif line.startswith(f"{character_name.upper()}"):
character_dialogue_detected = True
if character_dialogue_detected:
break
if character_dialogue_detected:
break
cleaned.append(line)
result = "\n".join(cleaned)
result = util.strip_partial_sentences(result)
editor = get_agent("editor")
if ensure_dialog_format or force_narrative:
if editor.fix_exposition_enabled and editor.fix_exposition_narrator:
result = editor.fix_exposition_in_text(result)
return result
def connect(self, scene):
@@ -324,16 +330,16 @@ class NarratorAgent(
event.duration, event.human_duration, event.narrative
)
narrator_message = NarratorMessage(
response,
meta={
"agent": "narrator",
"function": "narrate_time_passage",
"arguments": {
"duration": event.duration,
"time_passed": event.human_duration,
"narrative_direction": event.narrative,
},
},
)
emit("narrator", narrator_message)
self.scene.push_history(narrator_message)
@@ -345,7 +351,7 @@ class NarratorAgent(
if not self.narrate_dialogue_enabled:
return
if event.game_loop.had_passive_narration:
log.debug(
"narrate on dialog",
@@ -375,14 +381,14 @@ class NarratorAgent(
response = await self.narrate_after_dialogue(event.actor.character)
narrator_message = NarratorMessage(
response,
meta={
"agent": "narrator",
"function": "narrate_after_dialogue",
"arguments": {
"character": event.actor.character.name,
},
},
)
emit("narrator", narrator_message)
self.scene.push_history(narrator_message)
@@ -390,7 +396,7 @@ class NarratorAgent(
event.game_loop.had_passive_narration = True
@set_processing
@store_context_state("narrative_direction", visual_narration=True)
async def narrate_scene(self, narrative_direction: str | None = None):
"""
Narrate the scene
@@ -413,13 +419,13 @@ class NarratorAgent(
return response
@set_processing
@store_context_state("narrative_direction")
async def progress_story(self, narrative_direction: str | None = None):
"""
Narrate scene progression, moving the plot forward.
Arguments:
- narrative_direction: A string describing the direction the narrative should take. If not provided, will attempt to subtly move the story forward.
"""
@@ -431,10 +437,8 @@ class NarratorAgent(
if narrative_direction is None:
narrative_direction = "Slightly move the current scene forward."
log.debug("narrative_direction", narrative_direction=narrative_direction)
response = await Prompt.request(
"narrator.narrate-progress",
self.client,
@@ -453,13 +457,17 @@ class NarratorAgent(
log.debug("progress_story", response=response)
response = self.clean_result(response.strip())
return response
@set_processing
@store_context_state("query", query_narration=True)
async def narrate_query(
self,
query: str,
at_the_end: bool = False,
as_narrative: bool = True,
extra_context: str = None,
):
"""
Narrate a specific query
@@ -479,20 +487,20 @@ class NarratorAgent(
},
)
response = self.clean_result(
response.strip(), ensure_dialog_format=False, force_narrative=as_narrative
)
return response
@set_processing
@store_context_state("character", "narrative_direction", visual_narration=True)
async def narrate_character(
self, character: "Character", narrative_direction: str = None
):
"""
Narrate a specific character
"""
response = await Prompt.request(
"narrator.narrate-character",
self.client,
@@ -506,12 +514,14 @@ class NarratorAgent(
},
)
response = self.clean_result(
response.strip(), ensure_dialog_format=False, force_narrative=True
)
return response
@set_processing
@store_context_state("narrative_direction", time_narration=True)
async def narrate_time_passage(
self, duration: str, time_passed: str, narrative_direction: str
):
@@ -528,7 +538,7 @@ class NarratorAgent(
"max_tokens": self.client.max_token_length,
"duration": duration,
"time_passed": time_passed,
"narrative": narrative_direction, # backwards compatibility
"narrative": narrative_direction, # backwards compatibility
"narrative_direction": narrative_direction,
"extra_instructions": self.extra_instructions,
},
@@ -541,7 +551,7 @@ class NarratorAgent(
return response
@set_processing
@store_context_state("narrative_direction", sensory_narration=True)
async def narrate_after_dialogue(
self,
character: Character,
@@ -572,16 +582,16 @@ class NarratorAgent(
async def narrate_environment(self, narrative_direction: str = None):
"""
Narrate the environment
Wraps narrate_after_dialogue with the player character
as the perspective character
"""
pc = self.scene.get_player_character()
return await self.narrate_after_dialogue(pc, narrative_direction)
@set_processing
@store_context_state("narrative_direction", "character")
async def narrate_character_entry(
self, character: Character, narrative_direction: str = None
):
@@ -607,11 +617,9 @@ class NarratorAgent(
return response
@set_processing
@store_context_state("narrative_direction", "character")
async def narrate_character_exit(
self, character: Character, narrative_direction: str = None
):
"""
Narrate a character exiting the scene
@@ -679,8 +687,10 @@ class NarratorAgent(
later on
"""
args = parameters.copy()
if args.get("character") and isinstance(args["character"], self.scene.Character):
if args.get("character") and isinstance(
args["character"], self.scene.Character
):
args["character"] = args["character"].name
return {
@@ -701,7 +711,9 @@ class NarratorAgent(
fn = getattr(self, action_name)
narration = await fn(**kwargs)
narrator_message = NarratorMessage(
narration, meta=self.action_to_meta(action_name, kwargs)
)
self.scene.push_history(narrator_message)
if emit_message:
@@ -720,20 +732,22 @@ class NarratorAgent(
kind=kind,
agent_function_name=agent_function_name,
)
# depending on conversation format in the context, stopping strings
# for character names may change format
conversation_agent = get_agent("conversation")
if conversation_agent.conversation_format == "movie_script":
character_names = [f"\n{c.name.upper()}\n" for c in self.scene.get_characters()]
else:
character_names = [
f"\n{c.name.upper()}\n" for c in self.scene.get_characters()
]
else:
character_names = [f"\n{c.name}:" for c in self.scene.get_characters()]
if prompt_param.get("extra_stopping_strings") is None:
prompt_param["extra_stopping_strings"] = []
prompt_param["extra_stopping_strings"] += character_names
self.set_generation_overrides(prompt_param)
def allow_repetition_break(
@@ -748,7 +762,9 @@ class NarratorAgent(
if not self.actions["generation_override"].enabled:
return
prompt_param["max_tokens"] = min(prompt_param.get("max_tokens", 256), self.max_generation_length)
prompt_param["max_tokens"] = min(
prompt_param.get("max_tokens", 256), self.max_generation_length
)
if self.jiggle > 0.0:
nuke_repetition = client_context_attribute("nuke_repetition")
@@ -756,4 +772,4 @@ class NarratorAgent(
# we only apply the agent override if some other mechanism isn't already
# setting the nuke_repetition value
nuke_repetition = self.jiggle
set_client_context_attribute("nuke_repetition", nuke_repetition)

View File

@@ -8,15 +8,16 @@ from talemate.util import iso8601_duration_to_human
log = structlog.get_logger("talemate.game.engine.nodes.agents.narrator")
class GenerateNarrationBase(AgentNode):
"""
Generate a narration message
"""
_agent_name: ClassVar[str] = "narrator"
_action_name: ClassVar[str] = ""
_title: ClassVar[str] = "Generate Narration"
class Fields:
narrative_direction = PropertyField(
name="narrative_direction",
@@ -24,151 +25,170 @@ class GenerateNarrationBase(AgentNode):
default="",
type="str",
)
def __init__(self, **kwargs):
if "title" not in kwargs:
kwargs["title"] = self._title
super().__init__(**kwargs)
def setup(self):
self.add_input("state")
self.add_input("narrative_direction", socket_type="str", optional=True)
self.add_output("generated", socket_type="str")
self.add_output("message", socket_type="message_object")
async def prepare_input_values(self) -> dict:
input_values = self.get_input_values()
input_values.pop("state", None)
return input_values
async def run(self, state: GraphState):
input_values = await self.prepare_input_values()
try:
agent_fn = getattr(self.agent, self._action_name)
except AttributeError:
raise InputValueError(self, "_action_name", f"Agent does not have a function named {self._action_name}")
raise InputValueError(
self,
"_action_name",
f"Agent does not have a function named {self._action_name}",
)
narration = await agent_fn(**input_values)
message = NarratorMessage(
message=narration,
meta=self.agent.action_to_meta(self._action_name, input_values),
)
self.set_output_values({"generated": narration, "message": message})
@register("agents/narrator/GenerateProgress")
class GenerateProgressNarration(GenerateNarrationBase):
"""
Generate a progress narration message
"""
_action_name:ClassVar[str] = "progress_story"
_title:ClassVar[str] = "Generate Progress Narration"
"""
_action_name: ClassVar[str] = "progress_story"
_title: ClassVar[str] = "Generate Progress Narration"
@register("agents/narrator/GenerateSceneNarration")
class GenerateSceneNarration(GenerateNarrationBase):
"""
Generate a scene narration message
"""
_action_name:ClassVar[str] = "narrate_scene"
_title:ClassVar[str] = "Generate Scene Narration"
"""
@register("agents/narrator/GenerateAfterDialogNarration")
_action_name: ClassVar[str] = "narrate_scene"
_title: ClassVar[str] = "Generate Scene Narration"
@register("agents/narrator/GenerateAfterDialogNarration")
class GenerateAfterDialogNarration(GenerateNarrationBase):
"""
Generate an after dialog narration message
"""
_action_name:ClassVar[str] = "narrate_after_dialogue"
_title:ClassVar[str] = "Generate After Dialog Narration"
"""
_action_name: ClassVar[str] = "narrate_after_dialogue"
_title: ClassVar[str] = "Generate After Dialog Narration"
def setup(self):
super().setup()
self.add_input("character", socket_type="character")
@register("agents/narrator/GenerateEnvironmentNarration")
class GenerateEnvironmentNarration(GenerateNarrationBase):
"""
Generate an environment narration message
"""
_action_name:ClassVar[str] = "narrate_environment"
_title:ClassVar[str] = "Generate Environment Narration"
"""
_action_name: ClassVar[str] = "narrate_environment"
_title: ClassVar[str] = "Generate Environment Narration"
@register("agents/narrator/GenerateQueryNarration")
class GenerateQueryNarration(GenerateNarrationBase):
"""
Generate a query narration message
"""
_action_name:ClassVar[str] = "narrate_query"
_title:ClassVar[str] = "Generate Query Narration"
"""
_action_name: ClassVar[str] = "narrate_query"
_title: ClassVar[str] = "Generate Query Narration"
def setup(self):
super().setup()
self.add_input("query", socket_type="str")
self.add_input("extra_context", socket_type="str", optional=True)
self.remove_input("narrative_direction")
@register("agents/narrator/GenerateCharacterNarration")
class GenerateCharacterNarration(GenerateNarrationBase):
"""
Generate a character narration message
"""
_action_name:ClassVar[str] = "narrate_character"
_title:ClassVar[str] = "Generate Character Narration"
"""
_action_name: ClassVar[str] = "narrate_character"
_title: ClassVar[str] = "Generate Character Narration"
def setup(self):
super().setup()
self.add_input("character", socket_type="character")
@register("agents/narrator/GenerateTimeNarration")
class GenerateTimeNarration(GenerateNarrationBase):
"""
Generate a time narration message
"""
_action_name:ClassVar[str] = "narrate_time_passage"
_title:ClassVar[str] = "Generate Time Narration"
"""
_action_name: ClassVar[str] = "narrate_time_passage"
_title: ClassVar[str] = "Generate Time Narration"
def setup(self):
super().setup()
self.add_input("duration", socket_type="str")
self.set_property("duration", "P0T1S")
async def prepare_input_values(self) -> dict:
input_values = await super().prepare_input_values()
input_values["time_passed"] = iso8601_duration_to_human(input_values["duration"])
input_values["time_passed"] = iso8601_duration_to_human(
input_values["duration"]
)
return input_values
@register("agents/narrator/GenerateCharacterEntryNarration")
class GenerateCharacterEntryNarration(GenerateNarrationBase):
"""
Generate a character entry narration message
"""
_action_name:ClassVar[str] = "narrate_character_entry"
_title:ClassVar[str] = "Generate Character Entry Narration"
"""
_action_name: ClassVar[str] = "narrate_character_entry"
_title: ClassVar[str] = "Generate Character Entry Narration"
def setup(self):
super().setup()
self.add_input("character", socket_type="character")
@register("agents/narrator/GenerateCharacterExitNarration")
class GenerateCharacterExitNarration(GenerateNarrationBase):
"""
Generate a character exit narration message
"""
_action_name:ClassVar[str] = "narrate_character_exit"
_title:ClassVar[str] = "Generate Character Exit Narration"
"""
_action_name: ClassVar[str] = "narrate_character_exit"
_title: ClassVar[str] = "Generate Character Exit Narration"
def setup(self):
super().setup()
self.add_input("character", socket_type="character")
@register("agents/narrator/UnpackSource")
class UnpackSource(AgentNode):
"""
@@ -176,25 +196,19 @@ class UnpackSource(AgentNode):
into action name and arguments
DEPRECATED
"""
_agent_name:ClassVar[str] = "narrator"
_agent_name: ClassVar[str] = "narrator"
def __init__(self, title="Unpack Source", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("source", socket_type="str")
self.add_output("action_name", socket_type="str")
self.add_output("arguments", socket_type="dict")
async def run(self, state: GraphState):
source = self.get_input_value("source")
action_name = ""
arguments = {}
self.set_output_values({"action_name": action_name, "arguments": arguments})

View File

@@ -15,27 +15,31 @@ __all__ = [
log = structlog.get_logger("talemate.server.narrator")
class QueryPayload(pydantic.BaseModel):
query: str
at_the_end: bool = True
class NarrativeDirectionPayload(pydantic.BaseModel):
narrative_direction:str = ""
narrative_direction: str = ""
class CharacterPayload(NarrativeDirectionPayload):
character:str = ""
character: str = ""
class NarratorWebsocketHandler(Plugin):
"""
Handles narrator actions
"""
router = "narrator"
@property
def narrator(self):
return get_agent("narrator")
@set_loading("Progressing the story", cancellable=True, as_async=True)
async def handle_progress(self, data: dict):
"""
@@ -47,7 +51,7 @@ class NarratorWebsocketHandler(Plugin):
narrative_direction=payload.narrative_direction,
emit_message=True,
)
@set_loading("Narrating the environment", cancellable=True, as_async=True)
async def handle_narrate_environment(self, data: dict):
"""
@@ -59,8 +63,7 @@ class NarratorWebsocketHandler(Plugin):
narrative_direction=payload.narrative_direction,
emit_message=True,
)
@set_loading("Working on a query", cancellable=True, as_async=True)
async def handle_query(self, data: dict):
"""
@@ -68,56 +71,55 @@ class NarratorWebsocketHandler(Plugin):
message.
"""
payload = QueryPayload(**data)
narration = await self.narrator.narrate_query(**payload.model_dump())
message: ContextInvestigationMessage = ContextInvestigationMessage(
narration, sub_type="query"
narration, sub_type="query"
)
message.set_source("narrator", "narrate_query", **payload.model_dump())
emit("context_investigation", message=message)
self.scene.push_history(message)
@set_loading("Looking at the scene", cancellable=True, as_async=True)
async def handle_look_at_scene(self, data: dict):
"""
Look at the scene (optionally to a specific direction)
This will result in a context investigation message.
"""
payload = NarrativeDirectionPayload(**data)
narration = await self.narrator.narrate_scene(
narrative_direction=payload.narrative_direction
)
message: ContextInvestigationMessage = ContextInvestigationMessage(
narration, sub_type="visual-scene"
)
message.set_source("narrator", "narrate_scene", **payload.model_dump())
emit("context_investigation", message=message)
self.scene.push_history(message)
@set_loading("Looking at a character", cancellable=True, as_async=True)
async def handle_look_at_character(self, data: dict):
"""
Look at a character (optionally to a specific direction)
This will result in a context investigation message.
"""
payload = CharacterPayload(**data)
narration = await self.narrator.narrate_character(
character=self.scene.get_character(payload.character),
narrative_direction=payload.narrative_direction,
)
message: ContextInvestigationMessage = ContextInvestigationMessage(
narration, sub_type="visual-character"
)
message.set_source("narrator", "narrate_character", **payload.model_dump())
emit("context_investigation", message=message)
self.scene.push_history(message)

View File

@@ -24,4 +24,4 @@ def get_agent_class(name):
def get_agent_types() -> list[str]:
return list(AGENT_CLASSES.keys())

View File

@@ -7,26 +7,21 @@ import structlog
from typing import TYPE_CHECKING, Literal
import talemate.emit.async_signals
import talemate.util as util
from talemate.emit import emit
from talemate.events import GameLoopEvent
from talemate.prompts import Prompt
from talemate.scene_message import (
DirectorMessage,
TimePassageMessage,
ContextInvestigationMessage,
ReinforcementMessage,
)
from talemate.world_state.templates import GenerationOptions
from talemate.instance import get_agent
from talemate.exceptions import GenerationCancelled
import talemate.game.focal as focal
import talemate.emit.async_signals
from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConfig,
set_processing,
AgentEmission,
AgentTemplateEmission,
RagBuildSubInstructionEmission,
@@ -53,10 +48,12 @@ talemate.emit.async_signals.register(
"agent.summarization.summarize.after",
)
@dataclasses.dataclass
class BuildArchiveEmission(AgentEmission):
generation_options: GenerationOptions | None = None
@dataclasses.dataclass
class SummarizeEmission(AgentTemplateEmission):
text: str = ""
@@ -66,6 +63,7 @@ class SummarizeEmission(AgentTemplateEmission):
summarization_history: list[str] | None = None
summarization_type: Literal["dialogue", "events"] = "dialogue"
@register()
class SummarizeAgent(
MemoryRAGMixin,
@@ -73,7 +71,7 @@ class SummarizeAgent(
ContextInvestigationMixin,
# Needs to be after ContextInvestigationMixin so signals are connected in the right order
SceneAnalyzationMixin,
Agent,
):
"""
An agent that can be used to summarize text
@@ -82,7 +80,7 @@ class SummarizeAgent(
agent_type = "summarizer"
verbose_name = "Summarizer"
auto_squish = False
@classmethod
def init_actions(cls) -> dict[str, AgentAction]:
actions = {
@@ -148,11 +146,11 @@ class SummarizeAgent(
@property
def archive_threshold(self):
return self.actions["archive"].config["threshold"].value
@property
def archive_method(self):
return self.actions["archive"].config["method"].value
@property
def archive_include_previous(self):
return self.actions["archive"].config["include_previous"].value
@@ -179,7 +177,7 @@ class SummarizeAgent(
return result
# RAG HELPERS
async def rag_build_sub_instruction(self):
# Fire event to get the sub instruction from mixins
emission = RagBuildSubInstructionEmission(
@@ -188,37 +186,42 @@ class SummarizeAgent(
await talemate.emit.async_signals.get(
"agent.summarization.rag_build_sub_instruction"
).send(emission)
return emission.sub_instruction
# SUMMARIZATION HELPERS
async def previous_summaries(self, entry: ArchiveEntry) -> list[str]:
num_previous = self.archive_include_previous
# find entry by .id
entry_index = next(
(
i
for i, e in enumerate(self.scene.archived_history)
if e["id"] == entry.id
),
None,
)
if entry_index is None:
raise ValueError("Entry not found")
end = entry_index - 1
previous_summaries = []
if entry and num_previous > 0:
if self.layered_history_available:
previous_summaries = self.compile_layered_history(
include_base_layer=True, base_layer_end_id=entry.id
)[-num_previous:]
else:
previous_summaries = [
entry.text
for entry in self.scene.archived_history[end - num_previous : end]
]
return previous_summaries
# SUMMARIZE
@set_processing
@@ -226,13 +229,15 @@ class SummarizeAgent(
self, scene, generation_options: GenerationOptions | None = None
):
end = None
emission = BuildArchiveEmission(
agent=self,
generation_options=generation_options,
)
await talemate.emit.async_signals.get("agent.summarization.before_build_archive").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.before_build_archive"
).send(emission)
if not self.actions["archive"].enabled:
return
@@ -260,7 +265,7 @@ class SummarizeAgent(
extra_context = [
entry["text"] for entry in scene.archived_history[-num_previous:]
]
else:
extra_context = None
@@ -283,7 +288,10 @@ class SummarizeAgent(
# log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")
if isinstance(
dialogue,
(DirectorMessage, ContextInvestigationMessage, ReinforcementMessage),
):
# these messages are not part of the dialogue and should not be summarized
if i == start:
start += 1
@@ -292,7 +300,7 @@ class SummarizeAgent(
if isinstance(dialogue, TimePassageMessage):
log.debug("build_archive", time_passage_message=dialogue)
ts = util.iso8601_add(ts, dialogue.ts)
if i == start:
log.debug(
"build_archive",
@@ -347,23 +355,28 @@ class SummarizeAgent(
if str(line) in terminating_line:
break
adjusted_dialogue.append(line)
# if difference start and end is less than 4, ignore the termination
if len(adjusted_dialogue) > 4:
dialogue_entries = adjusted_dialogue
end = start + len(dialogue_entries) - 1
else:
log.debug("build_archive", message="Ignoring termination", start=start, end=end, adjusted_dialogue=adjusted_dialogue)
log.debug(
"build_archive",
message="Ignoring termination",
start=start,
end=end,
adjusted_dialogue=adjusted_dialogue,
)
if dialogue_entries:
if not extra_context:
# prepend scene intro to dialogue
dialogue_entries.insert(0, scene.intro)
summarized = None
retries = 5
while not summarized and retries > 0:
summarized = await self.summarize(
"\n".join(map(str, dialogue_entries)),
@@ -371,7 +384,7 @@ class SummarizeAgent(
generation_options=generation_options,
)
retries -= 1
if not summarized:
raise IOError("Failed to summarize dialogue", dialogue=dialogue_entries)
@@ -382,13 +395,17 @@ class SummarizeAgent(
# determine the appropriate timestamp for the summarization
await scene.push_archive(
ArchiveEntry(text=summarized, start=start, end=end, ts=ts)
)
scene.ts = ts
scene.emit_status()
await talemate.emit.async_signals.get("agent.summarization.after_build_archive").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.after_build_archive"
).send(emission)
return True
@set_processing
@@ -408,23 +425,23 @@ class SummarizeAgent(
return response
@set_processing
async def find_natural_scene_termination(
self, event_chunks: list[str]
) -> list[list[str]]:
"""
Will analyze a list of events and return a list of events that
has been separated at a natural scene termination points.
"""
# scan through event chunks and split into paragraphs
rebuilt_chunks = []
for chunk in event_chunks:
paragraphs = [p.strip() for p in chunk.split("\n") if p.strip()]
rebuilt_chunks.extend(paragraphs)
event_chunks = rebuilt_chunks
response = await Prompt.request(
"summarizer.find-natural-scene-termination-events",
self.client,
@@ -436,38 +453,42 @@ class SummarizeAgent(
},
)
response = response.strip()
items = util.extract_list(response)
# will be a list of
# ["Progress 1", "Progress 12", "Progress 323", ...]
# convert to a list of just numbers
numbers = []
for item in items:
match = re.match(r"Progress (\d+)", item.strip())
if match:
numbers.append(int(match.group(1)))
# make sure its unique and sorted
numbers = sorted(list(set(numbers)))
result = []
prev_number = 0
for number in numbers:
result.append(event_chunks[prev_number : number + 1])
prev_number = number + 1
# result = {
# "selected": event_chunks[:number+1],
# "remaining": event_chunks[number+1:]
# }
log.debug(
"find_natural_scene_termination",
response=response,
result=result,
numbers=numbers,
)
return result
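A worked example of the grouping loop above, under hypothetical data where the model answered "Progress 1" and "Progress 3" for five paragraphs:
event_chunks = ["p0", "p1", "p2", "p3", "p4"]
numbers = [1, 3]
result, prev_number = [], 0
for number in numbers:
    result.append(event_chunks[prev_number : number + 1])
    prev_number = number + 1
result  # [["p0", "p1"], ["p2", "p3"]] -- "p4" falls outside every group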
@set_processing
async def summarize(
@@ -481,9 +502,9 @@ class SummarizeAgent(
"""
Summarize the given text
"""
response_length = 1024
template_vars = {
"dialogue": text,
"scene": self.scene,
@@ -500,7 +521,7 @@ class SummarizeAgent(
"analyze_chunks": self.layered_history_analyze_chunks,
"response_length": response_length,
}
emission = SummarizeEmission(
agent=self,
text=text,
@@ -511,43 +532,46 @@ class SummarizeAgent(
summarization_history=extra_context or [],
summarization_type="dialogue",
)
await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.summarize.before"
).send(emission)
template_vars["dynamic_instructions"] = emission.dynamic_instructions
response = await Prompt.request(
f"summarizer.summarize-dialogue",
"summarizer.summarize-dialogue",
self.client,
f"summarize_{response_length}",
vars=template_vars,
dedupe_enabled=False,
)
log.debug(
"summarize", dialogue_length=len(text), summarized_length=len(response)
)
try:
summary = response.split("SUMMARY:")[1].strip()
except Exception as e:
log.error("summarize failed", response=response, exc=e)
return ""
# capitalize first letter
try:
summary = summary[0].upper() + summary[1:]
except IndexError:
pass
emission.response = self.clean_result(summary)
await talemate.emit.async_signals.get("agent.summarization.summarize.after").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.summarize.after"
).send(emission)
summary = emission.response
return self.clean_result(summary)
@set_processing
async def summarize_events(
@@ -563,15 +587,14 @@ class SummarizeAgent(
"""
Summarize the given text
"""
if not extra_context:
extra_context = ""
mentioned_characters: list["Character"] = self.scene.parse_characters_from_text(
text + extra_context, exclude_active=True
)
template_vars = {
"dialogue": text,
"scene": self.scene,
@@ -585,7 +608,7 @@ class SummarizeAgent(
"response_length": response_length,
"mentioned_characters": mentioned_characters,
}
emission = SummarizeEmission(
agent=self,
text=text,
@@ -596,46 +619,62 @@ class SummarizeAgent(
summarization_history=[extra_context] if extra_context else [],
summarization_type="events",
)
await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.summarize.before"
).send(emission)
template_vars["dynamic_instructions"] = emission.dynamic_instructions
response = await Prompt.request(
f"summarizer.summarize-events",
"summarizer.summarize-events",
self.client,
f"summarize_{response_length}",
vars=template_vars,
dedupe_enabled=False,
)
response = response.strip()
response = response.replace('"', "")
log.debug(
"layered_history_summarize", original_length=len(text), summarized_length=len(response)
"layered_history_summarize",
original_length=len(text),
summarized_length=len(response),
)
# clean up analyzation (remove analyzation text)
if self.layered_history_analyze_chunks:
# remove all lines that begin with "ANALYSIS OF CHUNK \d+:"
response = "\n".join([line for line in response.split("\n") if not line.startswith("ANALYSIS OF CHUNK")])
response = "\n".join(
[
line
for line in response.split("\n")
if not line.startswith("ANALYSIS OF CHUNK")
]
)
# strip all occurrences of "CHUNK \d+: " from the summary
response = re.sub(r"(CHUNK|CHAPTER) \d+:\s+", "", response)
# capitalize first letter
try:
response = response[0].upper() + response[1:]
except IndexError:
pass
emission.response = self.clean_result(response)
await talemate.emit.async_signals.get("agent.summarization.summarize.after").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.summarize.after"
).send(emission)
response = emission.response
log.debug("summarize_events", original_length=len(text), summarized_length=len(response))
log.debug(
"summarize_events",
original_length=len(text),
summarized_length=len(response),
)
return self.clean_result(response)
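To make the two cleanup steps above concrete, a hypothetical response and its transformation (assuming the analyze-chunks output shape):
import re
response = "ANALYSIS OF CHUNK 1: pacing is slow.\nCHUNK 1: Alice leaves the manor."
response = "\n".join(
    line for line in response.split("\n") if not line.startswith("ANALYSIS OF CHUNK")
)
response = re.sub(r"(CHUNK|CHAPTER) \d+:\s+", "", response)
response  # "Alice leaves the manor."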

View File

@@ -28,13 +28,15 @@ talemate.emit.async_signals.register(
"agent.summarization.scene_analysis.after",
"agent.summarization.scene_analysis.cached",
"agent.summarization.scene_analysis.before_deep_analysis",
"agent.summarization.scene_analysis.after_deep_analysis",
"agent.summarization.scene_analysis.after_deep_analysis",
)
@dataclasses.dataclass
class SceneAnalysisEmission(AgentTemplateEmission):
analysis_type: str | None = None
@dataclasses.dataclass
class SceneAnalysisDeepAnalysisEmission(AgentEmission):
analysis: str
@@ -42,15 +44,16 @@ class SceneAnalysisDeepAnalysisEmission(AgentEmission):
analysis_sub_type: str | None = None
max_content_investigations: int = 1
character: "Character" = None
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
default_factory=list
)
class SceneAnalyzationMixin:
"""
Summarizer agent mixin that provides functionality for scene analyzation.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["analyze_scene"] = AgentAction(
@@ -71,8 +74,8 @@ class SceneAnalyzationMixin:
choices=[
{"label": "Short (256)", "value": "256"},
{"label": "Medium (512)", "value": "512"},
{"label": "Long (1024)", "value": "1024"}
]
{"label": "Long (1024)", "value": "1024"},
],
),
"for_conversation": AgentActionConfig(
type="bool",
@@ -109,196 +112,204 @@ class SceneAnalyzationMixin:
value=True,
quick_toggle=True,
),
},
)
# config property helpers
@property
def analyze_scene(self) -> bool:
return self.actions["analyze_scene"].enabled
@property
def analysis_length(self) -> int:
return int(self.actions["analyze_scene"].config["analysis_length"].value)
@property
def cache_analysis(self) -> bool:
return self.actions["analyze_scene"].config["cache_analysis"].value
@property
def deep_analysis(self) -> bool:
return self.actions["analyze_scene"].config["deep_analysis"].value
@property
def deep_analysis_max_context_investigations(self) -> int:
return self.actions["analyze_scene"].config["deep_analysis_max_context_investigations"].value
return (
self.actions["analyze_scene"]
.config["deep_analysis_max_context_investigations"]
.value
)
@property
def analyze_scene_for_conversation(self) -> bool:
return self.actions["analyze_scene"].config["for_conversation"].value
@property
def analyze_scene_for_narration(self) -> bool:
return self.actions["analyze_scene"].config["for_narration"].value
# signal connect
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.conversation.inject_instructions").connect(
self.on_inject_instructions
)
talemate.emit.async_signals.get(
"agent.conversation.inject_instructions"
).connect(self.on_inject_instructions)
talemate.emit.async_signals.get("agent.narrator.inject_instructions").connect(
self.on_inject_instructions
)
talemate.emit.async_signals.get("agent.summarization.rag_build_sub_instruction").connect(
self.on_rag_build_sub_instruction
)
talemate.emit.async_signals.get("agent.editor.revision-analysis.before").connect(
self.on_editor_revision_analysis_before
)
talemate.emit.async_signals.get(
"agent.summarization.rag_build_sub_instruction"
).connect(self.on_rag_build_sub_instruction)
talemate.emit.async_signals.get(
"agent.editor.revision-analysis.before"
).connect(self.on_editor_revision_analysis_before)
async def on_inject_instructions(
self,
emission: ConversationAgentEmission | NarratorAgentEmission,
):
"""
Injects instructions into the conversation.
"""
if isinstance(emission, ConversationAgentEmission):
emission_type = "conversation"
elif isinstance(emission, NarratorAgentEmission):
emission_type = "narration"
else:
raise ValueError("Invalid emission type.")
if not self.analyze_scene:
return
analyze_scene_for_type = getattr(self, f"analyze_scene_for_{emission_type}")
if not analyze_scene_for_type:
return
analysis = None
# self.set_scene_states and self.get_scene_state to store
# cached analysis in scene states
if self.cache_analysis:
analysis = await self.get_cached_analysis(emission_type)
if analysis:
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.cached").send(
await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.cached"
).send(
SceneAnalysisEmission(
agent=self,
analysis_type=emission_type,
response=analysis,
template_vars={
"character": emission.character if hasattr(emission, "character") else None,
"character": emission.character
if hasattr(emission, "character")
else None,
},
dynamic_instructions=emission.dynamic_instructions,
)
)
if not analysis and self.analyze_scene:
# analyze the scene for the next action
analysis = await self.analyze_scene_for_next_action(
emission_type,
emission.character if hasattr(emission, "character") else None,
self.analysis_length,
)
await self.set_cached_analysis(emission_type, analysis)
if not analysis:
return
emission.dynamic_instructions.append(
DynamicInstruction(title="Scene Analysis", content=analysis)
)
async def on_rag_build_sub_instruction(self, emission:"RagBuildSubInstructionEmission"):
async def on_rag_build_sub_instruction(
self, emission: "RagBuildSubInstructionEmission"
):
"""
Injects the sub instruction into the analysis.
"""
sub_instruction = await self.analyze_scene_rag_build_sub_instruction()
if sub_instruction:
emission.sub_instruction = sub_instruction
async def on_editor_revision_analysis_before(self, emission: AgentTemplateEmission):
last_analysis = self.get_scene_state("scene_analysis")
if last_analysis:
emission.dynamic_instructions.append(
DynamicInstruction(title="Scene Analysis", content=last_analysis)
)
# helpers
async def get_cached_analysis(self, typ: str) -> str | None:
"""
Returns the cached analysis for the given type.
"""
cached_analysis = self.get_scene_state(f"cached_analysis_{typ}")
if not cached_analysis:
return None
fingerprint = self.context_fingerpint()
if cached_analysis.get("fp") == fingerprint:
return cached_analysis["guidance"]
return None
async def set_cached_analysis(self, typ: str, analysis: str):
"""
Sets the cached analysis for the given type.
"""
fingerprint = self.context_fingerpint()
self.set_scene_states(
**{f"cached_analysis_{typ}": {
"fp": fingerprint,
"guidance": analysis,
}}
**{
f"cached_analysis_{typ}": {
"fp": fingerprint,
"guidance": analysis,
}
}
)
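So each analysis type occupies its own scene-state slot, guarded by a context fingerprint; the stored shape is, illustratively:
# after set_cached_analysis("narration", analysis), scene state holds:
# {
#     "cached_analysis_narration": {
#         "fp": "<context fingerprint>",
#         "guidance": "<analysis text>",
#     }
# }
# get_cached_analysis("narration") returns "guidance" only while the fingerprint still matches.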
async def analyze_scene_sub_type(self, analysis_type: str) -> str:
"""
Analyzes the active agent context to figure out the appropriate sub type
"""
fn = getattr(self, f"analyze_scene_{analysis_type}_sub_type", None)
if fn:
return await fn()
return ""
async def analyze_scene_narration_sub_type(self) -> str:
"""
Analyzes the active agent context to figure out the appropriate sub type
for narration analysis. (progress, query etc.)
"""
active_agent_context = active_agent.get()
if not active_agent_context:
return "progress"
state = active_agent_context.state
if state.get("narrator__query_narration"):
return "query"
if state.get("narrator__sensory_narration"):
return "sensory"
@@ -306,70 +317,85 @@ class SceneAnalyzationMixin:
if state.get("narrator__character"):
return "visual-character"
return "visual"
if state.get("narrator__fn_narrate_character_entry"):
return "progress-character-entry"
if state.get("narrator__fn_narrate_character_exit"):
return "progress-character-exit"
return "progress"
async def analyze_scene_rag_build_sub_instruction(self):
"""
Analyzes the active agent context to figure out the appropriate sub type
for rag build sub instruction.
"""
active_agent_context = active_agent.get()
if not active_agent_context:
return ""
state = active_agent_context.state
if state.get("narrator__query_narration"):
query = state["narrator__query"]
if query.endswith("?"):
return "Answer the following question: " + query
else:
return query
narrative_direction = state.get("narrator__narrative_direction")
if state.get("narrator__sensory_narration") and narrative_direction:
return "Collect information that aids in describing the following sensory experience: " + narrative_direction
return (
"Collect information that aids in describing the following sensory experience: "
+ narrative_direction
)
if state.get("narrator__visual_narration") and narrative_direction:
return "Collect information that aids in describing the following visual experience: " + narrative_direction
return (
"Collect information that aids in describing the following visual experience: "
+ narrative_direction
)
if state.get("narrator__fn_narrate_character_entry") and narrative_direction:
return "Collect information that aids in describing the following character entry: " + narrative_direction
return (
"Collect information that aids in describing the following character entry: "
+ narrative_direction
)
if state.get("narrator__fn_narrate_character_exit") and narrative_direction:
return "Collect information that aids in describing the following character exit: " + narrative_direction
return (
"Collect information that aids in describing the following character exit: "
+ narrative_direction
)
if state.get("narrator__fn_narrate_progress"):
return "Collect information that aids in progressing the story: " + narrative_direction
return (
"Collect information that aids in progressing the story: "
+ narrative_direction
)
return ""
# actions
@set_processing
async def analyze_scene_for_next_action(
self, typ: str, character: "Character" = None, length: int = 1024
) -> str:
"""
Analyzes the current scene progress and gives a suggestion for the next action.
taken by the given actor.
"""
# deep analysis is only available if the scene has a layered history
# and context investigation is enabled
deep_analysis = self.deep_analysis and self.context_investigation_available
analysis_sub_type = await self.analyze_scene_sub_type(typ)
template_vars = {
"max_tokens": self.client.max_token_length,
"scene": self.scene,
@@ -381,58 +407,60 @@ class SceneAnalyzationMixin:
"analysis_type": typ,
"analysis_sub_type": analysis_sub_type,
}
emission = SceneAnalysisEmission(
agent=self, template_vars=template_vars, analysis_type=typ
)
await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.before"
).send(emission)
template_vars["dynamic_instructions"] = emission.dynamic_instructions
response = await Prompt.request(
f"summarizer.analyze-scene-for-next-{typ}",
self.client,
f"investigate_{length}",
vars=template_vars,
)
response = strip_partial_sentences(response)
if not response.strip():
return response
if deep_analysis:
emission = SceneAnalysisDeepAnalysisEmission(
agent=self,
analysis=response,
analysis_type=typ,
analysis_sub_type=analysis_sub_type,
character=character,
max_content_investigations=self.deep_analysis_max_context_investigations,
)
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.before_deep_analysis").send(
emission
)
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.after_deep_analysis").send(
emission
)
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.after").send(
await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.before_deep_analysis"
).send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.after_deep_analysis"
).send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.after"
).send(
SceneAnalysisEmission(
agent=self,
template_vars=template_vars,
response=response,
agent=self,
template_vars=template_vars,
response=response,
analysis_type=typ,
dynamic_instructions=emission.dynamic_instructions
dynamic_instructions=emission.dynamic_instructions,
)
)
self.set_context_states(scene_analysis=response)
self.set_scene_states(scene_analysis=response)
return response
return response

View File

@@ -22,14 +22,12 @@ if TYPE_CHECKING:
log = structlog.get_logger()
class ContextInvestigationMixin:
"""
Summarizer agent mixin that provides functionality for context investigation
through the layered history of the scene.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["context_investigation"] = AgentAction(
@@ -52,7 +50,7 @@ class ContextInvestigationMixin:
{"label": "Short (256)", "value": "256"},
{"label": "Medium (512)", "value": "512"},
{"label": "Long (1024)", "value": "1024"},
                    ],
),
"update_method": AgentActionConfig(
type="text",
@@ -62,57 +60,56 @@ class ContextInvestigationMixin:
choices=[
{"label": "Replace", "value": "replace"},
{"label": "Smart Merge", "value": "merge"},
                    ],
                ),
            },
)
# config property helpers
@property
def context_investigation_enabled(self):
return self.actions["context_investigation"].enabled
@property
def context_investigation_available(self):
        return self.context_investigation_enabled and self.layered_history_available
@property
def context_investigation_answer_length(self) -> int:
return int(self.actions["context_investigation"].config["answer_length"].value)
@property
def context_investigation_update_method(self) -> str:
return self.actions["context_investigation"].config["update_method"].value
# signal connect
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.conversation.inject_instructions").connect(
self.on_inject_context_investigation
)
talemate.emit.async_signals.get(
"agent.conversation.inject_instructions"
).connect(self.on_inject_context_investigation)
talemate.emit.async_signals.get("agent.narrator.inject_instructions").connect(
self.on_inject_context_investigation
)
talemate.emit.async_signals.get("agent.director.guide.inject_instructions").connect(
self.on_inject_context_investigation
)
talemate.emit.async_signals.get("agent.summarization.scene_analysis.before_deep_analysis").connect(
self.on_summarization_scene_analysis_before_deep_analysis
)
async def on_summarization_scene_analysis_before_deep_analysis(self, emission:SceneAnalysisDeepAnalysisEmission):
talemate.emit.async_signals.get(
"agent.director.guide.inject_instructions"
).connect(self.on_inject_context_investigation)
talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.before_deep_analysis"
).connect(self.on_summarization_scene_analysis_before_deep_analysis)
async def on_summarization_scene_analysis_before_deep_analysis(
self, emission: SceneAnalysisDeepAnalysisEmission
):
"""
Handles context investigation for deep scene analysis.
"""
if not self.context_investigation_enabled:
return
suggested_investigations = await self.suggest_context_investigations(
emission.analysis,
emission.analysis_type,
@@ -120,67 +117,72 @@ class ContextInvestigationMixin:
max_calls=emission.max_content_investigations,
character=emission.character,
)
response = emission.analysis
        ci_calls: list[focal.Call] = await self.request_context_investigations(
            suggested_investigations, max_calls=emission.max_content_investigations
        )
log.debug("analyze_scene_for_next_action", ci_calls=ci_calls)
# append call queries and answers to the response
ci_text = []
for ci_call in ci_calls:
try:
ci_text.append(f"{ci_call.arguments['query']}\n{ci_call.result}")
            except KeyError:
                log.error(
                    "analyze_scene_for_next_action",
                    error="Missing key in call",
                    ci_call=ci_call,
                )

        context_investigation = "\n\n".join(ci_text if ci_text else [])
current_context_investigation = self.get_scene_state("context_investigation")
if current_context_investigation and context_investigation:
if self.context_investigation_update_method == "merge":
context_investigation = await self.update_context_investigation(
current_context_investigation, context_investigation, response
)
self.set_scene_states(context_investigation=context_investigation)
self.set_context_states(context_investigation=context_investigation)
    async def on_inject_context_investigation(
        self, emission: ConversationAgentEmission | NarratorAgentEmission
    ):
"""
Injects context investigation into the conversation.
"""
if not self.context_investigation_enabled:
return
context_investigation = self.get_scene_state("context_investigation")
log.debug("summarizer.on_inject_context_investigation", context_investigation=context_investigation, emission=emission)
log.debug(
"summarizer.on_inject_context_investigation",
context_investigation=context_investigation,
emission=emission,
)
if context_investigation:
emission.dynamic_instructions.append(
DynamicInstruction(
title="Context Investigation",
content=context_investigation
title="Context Investigation", content=context_investigation
)
)
# methods
@set_processing
async def suggest_context_investigations(
self,
        analysis: str,
        analysis_type: str,
        analysis_sub_type: str = "",
        max_calls: int = 3,
        character: "Character" = None,
) -> str:
template_vars = {
"max_tokens": self.client.max_token_length,
"scene": self.scene,
@@ -192,111 +194,119 @@ class ContextInvestigationMixin:
"analysis_type": analysis_type,
"analysis_sub_type": analysis_sub_type,
}
if not analysis_sub_type:
template = f"summarizer.suggest-context-investigations-for-{analysis_type}"
else:
template = f"summarizer.suggest-context-investigations-for-{analysis_type}-{analysis_sub_type}"
log.debug("summarizer.suggest_context_investigations", template=template, template_vars=template_vars)
log.debug(
"summarizer.suggest_context_investigations",
template=template,
template_vars=template_vars,
)
response = await Prompt.request(
template,
self.client,
"investigate_512",
vars=template_vars,
)
return response.strip()
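    # Illustration (hypothetical values, not taken from this diff): with
    # analysis_type="narration" and analysis_sub_type="progress", the template
    # name built above would resolve to
    # "summarizer.suggest-context-investigations-for-narration-progress".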
@set_processing
async def investigate_context(
        self,
        layer: int,
        index: int,
        query: str,
        analysis: str = "",
        max_calls: int = 3,
        pad_entries: int = 5,
) -> str:
"""
Processes a context investigation.
Arguments:
- layer: The layer to investigate
- index: The index in the layer to investigate
- query: The query to investigate
- analysis: Scene analysis text
- pad_entries: if > 0 will pad the entries with the given number of entries before and after the start and end index
"""
log.debug("summarizer.investigate_context", layer=layer, index=index, query=query)
log.debug(
"summarizer.investigate_context", layer=layer, index=index, query=query
)
entry = self.scene.layered_history[layer][index]
layer_to_investigate = layer - 1
start = max(entry["start"] - pad_entries, 0)
end = entry["end"] + pad_entries + 1
if layer_to_investigate == -1:
entries = self.scene.archived_history[start:end]
else:
entries = self.scene.layered_history[layer_to_investigate][start:end]
        async def answer(query: str, instructions: str) -> str:
            log.debug(
                "Answering context investigation",
                query=query,
                instructions=instructions,
            )
world_state = get_agent("world_state")
return await world_state.analyze_history_and_follow_instructions(
entries,
f"{query}\n{instructions}",
analysis=analysis,
                response_length=self.context_investigation_answer_length,
            )

        async def investigate_context(chapter_number: str, query: str) -> str:
# look for \d.\d in the chapter number, extract as layer and index
match = re.match(r"(\d+)\.(\d+)", chapter_number)
if not match:
log.error("summarizer.investigate_context", error="Invalid chapter number", chapter_number=chapter_number)
log.error(
"summarizer.investigate_context",
error="Invalid chapter number",
chapter_number=chapter_number,
)
return ""
layer = int(match.group(1))
index = int(match.group(2))
return await self.investigate_context(layer-1, index-1, query, analysis=analysis, max_calls=max_calls)
return await self.investigate_context(
layer - 1, index - 1, query, analysis=analysis, max_calls=max_calls
)
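        # Illustration (hypothetical input): a chapter number such as "2.3"
        # matches the regex above as layer=2, index=3, which the call maps to
        # the zero-based coordinates layer - 1 == 1 and index - 1 == 2.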
async def abort():
log.debug("Aborting context investigation")
focal_handler: focal.Focal = focal.Focal(
self.client,
callbacks=[
focal.Callback(
name="investigate_context",
                    arguments=[
                        focal.Argument(name="chapter_number", type="str"),
                        focal.Argument(name="query", type="str"),
                    ],
                    fn=investigate_context,
                ),
                focal.Callback(
                    name="answer",
                    arguments=[
                        focal.Argument(name="instructions", type="str"),
                        focal.Argument(name="query", type="str"),
                    ],
                    fn=answer,
                ),
                focal.Callback(name="abort", fn=abort),
],
max_calls=max_calls,
scene=self.scene,
@@ -307,84 +317,86 @@ class ContextInvestigationMixin:
entries=entries,
analysis=analysis,
)
await focal_handler.request(
"summarizer.investigate-context",
)
log.debug("summarizer.investigate_context", calls=focal_handler.state.calls)
        return focal_handler.state.calls

    @set_processing
    async def request_context_investigations(
        self,
        analysis: str,
        max_calls: int = 3,
    ) -> list[focal.Call]:
"""
Requests context investigations for the given analysis.
"""
async def abort():
log.debug("Aborting context investigations")
        async def investigate_context(chapter_number: str, query: str) -> str:
# look for \d.\d in the chapter number, extract as layer and index
match = re.match(r"(\d+)\.(\d+)", chapter_number)
if not match:
log.error("summarizer.request_context_investigations.investigate_context", error="Invalid chapter number", chapter_number=chapter_number)
log.error(
"summarizer.request_context_investigations.investigate_context",
error="Invalid chapter number",
chapter_number=chapter_number,
)
return ""
layer = int(match.group(1))
index = int(match.group(2))
num_layers = len(self.scene.layered_history)
return await self.investigate_context(num_layers - layer, index-1, query, analysis, max_calls=max_calls)
return await self.investigate_context(
num_layers - layer, index - 1, query, analysis, max_calls=max_calls
)
focal_handler: focal.Focal = focal.Focal(
self.client,
callbacks=[
focal.Callback(
name="investigate_context",
                    arguments=[
                        focal.Argument(name="chapter_number", type="str"),
                        focal.Argument(name="query", type="str"),
                    ],
                    fn=investigate_context,
                ),
                focal.Callback(name="abort", fn=abort),
            ],
            max_calls=max_calls,
            scene=self.scene,
            text=analysis,
)
await focal_handler.request(
"summarizer.request-context-investigation",
)
log.debug("summarizer.request_context_investigations", calls=focal_handler.state.calls)
return focal.collect_calls(
focal_handler.state.calls,
nested=True,
filter=lambda c: c.name == "answer"
log.debug(
"summarizer.request_context_investigations", calls=focal_handler.state.calls
)
# return focal_handler.state.calls
return focal.collect_calls(
focal_handler.state.calls, nested=True, filter=lambda c: c.name == "answer"
)
# return focal_handler.state.calls
@set_processing
async def update_context_investigation(
self,
        current_context_investigation: str,
        new_context_investigation: str,
        analysis: str,
):
response = await Prompt.request(
"summarizer.update-context-investigation",
@@ -398,5 +410,5 @@ class ContextInvestigationMixin:
"max_tokens": self.client.max_token_length,
},
)
        return response.strip()

View File

@@ -1,5 +1,5 @@
import structlog
from typing import TYPE_CHECKING
from talemate.agents.base import (
set_processing,
AgentAction,
@@ -24,11 +24,12 @@ talemate.emit.async_signals.register(
"agent.summarization.layered_history.finalize",
)
@dataclasses.dataclass
class LayeredHistoryFinalizeEmission(AgentEmission):
entry: LayeredArchiveEntry | None = None
summarization_history: list[str] = dataclasses.field(default_factory=lambda: [])
@property
def response(self) -> str | None:
return self.entry.text if self.entry else None
@@ -38,22 +39,23 @@ class LayeredHistoryFinalizeEmission(AgentEmission):
if self.entry:
self.entry.text = value
class SummaryLongerThanOriginalError(ValueError):
    def __init__(self, original_length: int, summarized_length: int):
        self.original_length = original_length
        self.summarized_length = summarized_length
        super().__init__(
            f"Summarized text is longer than original text: {summarized_length} > {original_length}"
        )
class LayeredHistoryMixin:
"""
Summarizer agent mixin that provides functionality for maintaining a layered history.
"""
@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["layered_history"] = AgentAction(
enabled=True,
container=True,
@@ -80,7 +82,7 @@ class LayeredHistoryMixin:
max=5,
step=1,
value=3,
                ),
"max_process_tokens": AgentActionConfig(
type="number",
label="Maximum tokens to process",
@@ -116,69 +118,71 @@ class LayeredHistoryMixin:
{"label": "Medium (512)", "value": "512"},
{"label": "Long (1024)", "value": "1024"},
{"label": "Exhaustive (2048)", "value": "2048"},
                    ],
),
},
)
# config property helpers
@property
def layered_history_enabled(self):
return self.actions["layered_history"].enabled
@property
def layered_history_threshold(self):
return self.actions["layered_history"].config["threshold"].value
@property
def layered_history_max_process_tokens(self):
return self.actions["layered_history"].config["max_process_tokens"].value
@property
def layered_history_max_layers(self):
return self.actions["layered_history"].config["max_layers"].value
@property
def layered_history_chunk_size(self) -> int:
return self.actions["layered_history"].config["chunk_size"].value
@property
def layered_history_analyze_chunks(self) -> bool:
return self.actions["layered_history"].config["analyze_chunks"].value
@property
def layered_history_response_length(self) -> int:
return int(self.actions["layered_history"].config["response_length"].value)
@property
def layered_history_available(self):
        return (
            self.layered_history_enabled
            and self.scene.layered_history
            and self.scene.layered_history[0]
        )
# signals
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.summarization.after_build_archive").connect(
self.on_after_build_archive
)
async def on_after_build_archive(self, emission:"BuildArchiveEmission"):
talemate.emit.async_signals.get(
"agent.summarization.after_build_archive"
).connect(self.on_after_build_archive)
async def on_after_build_archive(self, emission: "BuildArchiveEmission"):
"""
After the archive has been built, we will update the layered history.
"""
if self.layered_history_enabled:
await self.summarize_to_layered_history(
generation_options=emission.generation_options
)
# helpers
async def _lh_split_and_summarize_chunks(
        self,
chunks: list[dict],
extra_context: str,
generation_options: GenerationOptions | None = None,
@@ -189,21 +193,29 @@ class LayeredHistoryMixin:
"""
summaries = []
current_chunk = chunks.copy()
while current_chunk:
partial_chunk = []
max_process_tokens = self.layered_history_max_process_tokens
# Build partial chunk up to max_process_tokens
            while (
                current_chunk
                and util.count_tokens(
                    "\n\n".join(chunk["text"] for chunk in partial_chunk)
                )
                < max_process_tokens
            ):
                partial_chunk.append(current_chunk.pop(0))

            text_to_summarize = "\n\n".join(chunk["text"] for chunk in partial_chunk)

            log.debug(
                "_split_and_summarize_chunks",
                tokens_in_chunk=util.count_tokens(text_to_summarize),
                max_process_tokens=max_process_tokens,
            )
summary_text = await self.summarize_events(
text_to_summarize,
extra_context=extra_context + "\n\n".join(summaries),
@@ -213,9 +225,9 @@ class LayeredHistoryMixin:
chunk_size=self.layered_history_chunk_size,
)
summaries.append(summary_text)
return summaries
def _lh_validate_summary_length(self, summaries: list[str], original_length: int):
"""
Validates that the summarized text is not longer than the original.
@@ -224,17 +236,19 @@ class LayeredHistoryMixin:
summarized_length = util.count_tokens(summaries)
if summarized_length > original_length:
raise SummaryLongerThanOriginalError(original_length, summarized_length)
log.debug("_validate_summary_length",
original_length=original_length,
summarized_length=summarized_length)
log.debug(
"_validate_summary_length",
original_length=original_length,
summarized_length=summarized_length,
)
def _lh_build_extra_context(self, layer_index: int) -> str:
"""
Builds extra context from compiled layered history for the given layer.
"""
return "\n\n".join(self.compile_layered_history(layer_index))
def _lh_extract_timestamps(self, chunk: list[dict]) -> tuple[str, str, str]:
"""
Extracts timestamps from a chunk of entries.
@@ -242,144 +256,156 @@ class LayeredHistoryMixin:
"""
if not chunk:
return "PT1S", "PT1S", "PT1S"
        ts = chunk[0].get("ts", "PT1S")
        ts_start = chunk[0].get("ts_start", ts)
        ts_end = chunk[-1].get("ts_end", chunk[-1].get("ts", ts))
return ts, ts_start, ts_end
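    # Sketch of the contract (illustrative values): a chunk like
    # [{"ts": "PT5M", "ts_start": "PT5M"}, {"ts": "PT15M"}] yields
    # ("PT5M", "PT5M", "PT15M"); an empty chunk falls back to the
    # "PT1S" defaults.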
async def _lh_finalize_archive_entry(
        self,
entry: LayeredArchiveEntry,
summarization_history: list[str] | None = None,
) -> LayeredArchiveEntry:
"""
Finalizes an archive entry by summarizing it and adding it to the layered history.
"""
emission = LayeredHistoryFinalizeEmission(
agent=self,
entry=entry,
summarization_history=summarization_history,
)
await talemate.emit.async_signals.get("agent.summarization.layered_history.finalize").send(emission)
await talemate.emit.async_signals.get(
"agent.summarization.layered_history.finalize"
).send(emission)
return emission.entry
# methods
def compile_layered_history(
        self,
        for_layer_index: int = None,
        as_objects: bool = False,
        include_base_layer: bool = False,
        max: int = None,
base_layer_end_id: str | None = None,
) -> list[str]:
"""
Starts at the last layer and compiles the layered history into a single
list of events.
We are iterating backwards, so the last layer will be the most granular.
        Each preceding layer starts from the end of the next layer.
"""
layered_history = self.scene.layered_history
compiled = []
next_layer_start = None
len_layered_history = len(layered_history)
for i in range(len_layered_history - 1, -1, -1):
if for_layer_index is not None:
if i < for_layer_index:
break
log.debug("compilelayered history", i=i, next_layer_start=next_layer_start)
if not layered_history[i]:
continue
entry_num = 1
            for layered_history_entry in layered_history[i][
                next_layer_start if next_layer_start is not None else 0 :
            ]:
                if base_layer_end_id:
                    contained = entry_contained(
                        self.scene,
                        base_layer_end_id,
                        HistoryEntry(index=0, layer=i + 1, **layered_history_entry),
                    )
                    if contained:
                        log.debug(
                            "compile_layered_history",
                            contained=True,
                            base_layer_end_id=base_layer_end_id,
                        )
                        break

                text = f"{layered_history_entry['text']}"

                if (
                    for_layer_index == i
                    and max is not None
                    and max <= layered_history_entry["end"]
                ):
                    break
if as_objects:
                    compiled.append(
                        {
                            "text": text,
                            "start": layered_history_entry["start"],
                            "end": layered_history_entry["end"],
                            "layer": i,
                            "layer_r": len_layered_history - i,
                            "ts_start": layered_history_entry["ts_start"],
                            "index": entry_num,
                        }
                    )
entry_num += 1
else:
compiled.append(text)
next_layer_start = layered_history_entry["end"] + 1
if i == 0 and include_base_layer:
                # we are at layered history layer zero and inclusion of the base layer (archived history) is requested
# so we append the base layer to the compiled list, starting from
# index `next_layer_start`
entry_num = 1
                for ah in self.scene.archived_history[next_layer_start or 0 :]:
if base_layer_end_id and ah["id"] == base_layer_end_id:
break
text = f"{ah['text']}"
if as_objects:
                        compiled.append(
                            {
                                "text": text,
                                "start": ah["start"],
                                "end": ah["end"],
                                "layer": -1,
                                "layer_r": 1,
                                "ts": ah["ts"],
                                "index": entry_num,
                            }
                        )
entry_num += 1
else:
compiled.append(text)
return compiled
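    # Shape note (illustrative values): with as_objects=True each compiled
    # item is a dict such as
    # {"text": "...", "start": 0, "end": 11, "layer": 1, "layer_r": 2,
    #  "ts_start": "PT5M", "index": 1};
    # with as_objects=False only the text strings are returned.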
@set_processing
    async def summarize_to_layered_history(
        self, generation_options: GenerationOptions | None = None
    ):
"""
The layered history is a summarized archive with dynamic layers that
will get less and less granular as the scene progresses.
The most granular is still self.scene.archived_history, which holds
all the base layer summarizations.
self.scene.layered_history = [
# first layer after archived_history
[
@@ -391,7 +417,7 @@ class LayeredHistoryMixin:
},
...
],
# second layer
[
{
@@ -402,29 +428,29 @@ class LayeredHistoryMixin:
},
...
],
# additional layers
...
]
The same token threshold as for the base layer will be used for the
layers.
The same summarization function will be used for the layers.
The next level layer will be generated automatically when the token
threshold is reached.
"""
if not self.scene.archived_history:
return # No base layer summaries to work with
token_threshold = self.layered_history_threshold
max_layers = self.layered_history_max_layers
        if not hasattr(self.scene, "layered_history"):
self.scene.layered_history = []
layered_history = self.scene.layered_history
async def summarize_layer(source_layer, next_layer_index, start_from) -> bool:
@@ -432,147 +458,192 @@ class LayeredHistoryMixin:
current_tokens = 0
start_index = start_from
noop = True
            total_tokens_in_previous_layer = util.count_tokens(
                [entry["text"] for entry in source_layer]
            )
estimated_entries = total_tokens_in_previous_layer // token_threshold
for i in range(start_from, len(source_layer)):
entry = source_layer[i]
                entry_tokens = util.count_tokens(entry["text"])
                log.debug(
                    "summarize_to_layered_history",
                    entry=entry["text"][:100] + "...",
                    tokens=entry_tokens,
                    current_layer=next_layer_index - 1,
                )
if current_tokens + entry_tokens > token_threshold:
if current_chunk:
try:
# check if the next layer exists
next_layer = layered_history[next_layer_index]
except IndexError:
# create the next layer
layered_history.append([])
log.debug("summarize_to_layered_history", created_layer=next_layer_index)
log.debug(
"summarize_to_layered_history",
created_layer=next_layer_index,
)
next_layer = layered_history[next_layer_index]
ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)
ts, ts_start, ts_end = self._lh_extract_timestamps(
current_chunk
)
extra_context = self._lh_build_extra_context(next_layer_index)
text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
text_length = util.count_tokens(
"\n\n".join(chunk["text"] for chunk in current_chunk)
)
num_entries_in_layer = len(layered_history[next_layer_index])
emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}", data={"cancellable": True})
emit(
"status",
status="busy",
message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}",
data={"cancellable": True},
)
summaries = await self._lh_split_and_summarize_chunks(
current_chunk,
extra_context,
generation_options=generation_options,
)
noop = False
# validate summary length
self._lh_validate_summary_length(summaries, text_length)
                        next_layer.append(
                            LayeredArchiveEntry(
                                **{
                                    "start": start_index,
                                    "end": i,
                                    "ts": ts,
                                    "ts_start": ts_start,
                                    "ts_end": ts_end,
                                    "text": "\n\n".join(summaries),
                                }
                            ).model_dump(exclude_none=True)
                        )

                        emit(
                            "status",
                            status="busy",
                            message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer + 1} / {estimated_entries}",
                        )
current_chunk = []
current_tokens = 0
start_index = i
current_chunk.append(entry)
current_tokens += entry_tokens
log.debug("summarize_to_layered_history", tokens=current_tokens, threshold=token_threshold, next_layer=next_layer_index)
log.debug(
"summarize_to_layered_history",
tokens=current_tokens,
threshold=token_threshold,
next_layer=next_layer_index,
)
return not noop
# First layer (always the base layer)
has_been_updated = False
try:
if not layered_history:
layered_history.append([])
log.debug("summarize_to_layered_history", layer="base", new_layer=True)
                has_been_updated = await summarize_layer(
                    self.scene.archived_history, 0, 0
                )
elif layered_history[0]:
# determine starting point by checking for `end` in the last entry
last_entry = layered_history[0][-1]
end = last_entry["end"]
log.debug("summarize_to_layered_history", layer="base", start=end)
                has_been_updated = await summarize_layer(
                    self.scene.archived_history, 0, end
                )
else:
log.debug("summarize_to_layered_history", layer="base", empty=True)
                has_been_updated = await summarize_layer(
                    self.scene.archived_history, 0, 0
                )
except SummaryLongerThanOriginalError as exc:
log.error("summarize_to_layered_history", error=exc, layer="base")
emit("status", status="error", message="Layered history update failed.")
return
except GenerationCancelled as e:
log.info("Generation cancelled, stopping rebuild of historical layered history")
emit("status", message="Rebuilding of layered history cancelled", status="info")
log.info(
"Generation cancelled, stopping rebuild of historical layered history"
)
emit(
"status",
message="Rebuilding of layered history cancelled",
status="info",
)
handle_generation_cancelled(e)
return
# process layers
async def update_layers() -> bool:
noop = True
for index in range(0, len(layered_history)):
# check against max layers
if index + 1 > max_layers:
return False
try:
# check if the next layer exists
next_layer = layered_history[index + 1]
except IndexError:
next_layer = None
end = next_layer[-1]["end"] if next_layer else 0
log.debug("summarize_to_layered_history", layer=index, start=end)
                summarized = await summarize_layer(
                    layered_history[index], index + 1, end if end else 0
                )
if summarized:
noop = False
return not noop
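        # Flow sketch (hypothetical sizes): once layer 0 accumulates entries
        # past the token threshold, update_layers() summarizes them into
        # layer 1; the while-loop below keeps calling update_layers() until a
        # full pass produces no new entries in any layer.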
try:
while await update_layers():
has_been_updated = True
if has_been_updated:
emit("status", status="success", message="Layered history updated.")
except SummaryLongerThanOriginalError as exc:
log.error("summarize_to_layered_history", error=exc, layer="subsequent")
emit("status", status="error", message="Layered history update failed.")
return
except GenerationCancelled as e:
log.info("Generation cancelled, stopping rebuild of historical layered history")
emit("status", message="Rebuilding of layered history cancelled", status="info")
log.info(
"Generation cancelled, stopping rebuild of historical layered history"
)
emit(
"status",
message="Rebuilding of layered history cancelled",
status="info",
)
handle_generation_cancelled(e)
return
async def summarize_entries_to_layered_history(
        self,
        entries: list[dict],
next_layer_index: int,
start_index: int,
end_index: int,
@@ -580,11 +651,11 @@ class LayeredHistoryMixin:
) -> list[LayeredArchiveEntry]:
"""
Summarizes a list of entries into layered history entries.
This method is used for regenerating specific history entries by processing
their source entries. It chunks the entries based on the token threshold and
summarizes each chunk into a LayeredArchiveEntry.
Args:
entries: List of dictionaries containing the text entries to summarize.
Each entry should have at least a 'text' field and optionally
@@ -597,12 +668,12 @@ class LayeredHistoryMixin:
correspond to.
generation_options: Optional generation options to pass to the summarization
process.
Returns:
List of LayeredArchiveEntry objects containing the summarized text along
with timestamp and index information. Currently returns a list with a
single entry, but the structure supports multiple entries if needed.
Notes:
- The method respects the layered_history_threshold for chunking
- Uses helper methods for timestamp extraction, context building, and
@@ -611,63 +682,73 @@ class LayeredHistoryMixin:
- The last entry is always included in the final chunk if it doesn't
exceed the token threshold
"""
token_threshold = self.layered_history_threshold
archive_entries = []
summaries = []
current_chunk = []
current_tokens = 0
ts = "PT1S"
ts_start = "PT1S"
ts_end = "PT1S"
for entry_index, entry in enumerate(entries):
is_last_entry = entry_index == len(entries) - 1
            entry_tokens = util.count_tokens(entry["text"])
            log.debug(
                "summarize_entries_to_layered_history",
                entry=entry["text"][:100] + "...",
                entry_tokens=entry_tokens,
                current_layer=next_layer_index - 1,
                current_tokens=current_tokens,
            )
if current_tokens + entry_tokens > token_threshold or is_last_entry:
if is_last_entry and current_tokens + entry_tokens <= token_threshold:
# if we are here because this is the last entry and adding it to
# the current chunk would not exceed the token threshold, we will
# add it to the current chunk
current_chunk.append(entry)
if current_chunk:
ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)
extra_context = self._lh_build_extra_context(next_layer_index)
text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
text_length = util.count_tokens(
"\n\n".join(chunk["text"] for chunk in current_chunk)
)
summaries = await self._lh_split_and_summarize_chunks(
current_chunk,
extra_context,
generation_options=generation_options,
)
# validate summary length
self._lh_validate_summary_length(summaries, text_length)
                archive_entry = LayeredArchiveEntry(
                    **{
                        "start": start_index,
                        "end": end_index,
                        "ts": ts,
                        "ts_start": ts_start,
                        "ts_end": ts_end,
                        "text": "\n\n".join(summaries),
                    }
                )

                archive_entry = await self._lh_finalize_archive_entry(
                    archive_entry, extra_context.split("\n\n")
                )
archive_entries.append(archive_entry)
current_chunk.append(entry)
current_tokens += entry_tokens
return archive_entries

View File

@@ -51,11 +51,10 @@ if not TTS:
def parse_chunks(text: str) -> list[str]:
"""
Takes a string and splits it into chunks based on punctuation.
    In case of an error it will return the original text as a single chunk and
the error will be logged.
"""
@@ -278,7 +277,6 @@ class TTSAgent(Agent):
@property
def agent_details(self):
details = {
"api": AgentDetail(
icon="mdi-server-outline",
@@ -645,7 +643,6 @@ class TTSAgent(Agent):
# OPENAI
async def _generate_openai(self, text: str, chunk_size: int = 1024):
client = AsyncOpenAI(api_key=self.openai_api_key)
model = self.actions["openai"].config["model"].value

View File

@@ -1,15 +1,13 @@
import asyncio
import traceback
import structlog
import talemate.agents.visual.automatic1111  # noqa: F401
import talemate.agents.visual.comfyui  # noqa: F401
import talemate.agents.visual.openai_image  # noqa: F401
from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConditional,
AgentActionConfig,
AgentDetail,
set_processing,
@@ -25,11 +23,11 @@ from talemate.prompts.base import Prompt
from .commands import * # noqa
from .context import VIS_TYPES, VisualContext, VisualContextState, visual_context
from .handlers import HANDLERS
from .schema import RESOLUTION_MAP, RenderSettings
from .style import MAJOR_STYLES, STYLE_MAP, Style, combine_styles
from .schema import RESOLUTION_MAP
from .style import MAJOR_STYLES, STYLE_MAP, Style
from .websocket_handler import VisualWebsocketHandler
import talemate.agents.visual.nodes  # noqa: F401
__all__ = [
"VisualAgent",
@@ -114,7 +112,6 @@ class VisualBase(Agent):
label="Process in Background",
description="Process renders in the background",
),
"_prompts": AgentAction(
enabled=True,
container=True,
@@ -280,7 +277,6 @@ class VisualBase(Agent):
await super().ready_check(task)
async def setup_check(self):
if not self.actions["automatic_setup"].enabled:
return
@@ -289,7 +285,6 @@ class VisualBase(Agent):
await getattr(self.client, f"visual_{backend.lower()}_setup")(self)
async def apply_config(self, *args, **kwargs):
try:
backend = kwargs["actions"]["_config"]["config"]["backend"]["value"]
except (KeyError, TypeError):
@@ -312,7 +307,6 @@ class VisualBase(Agent):
backend_fn = getattr(self, f"{self.backend.lower()}_apply_config", None)
if backend_fn:
if not backend_changed and was_disabled and self.enabled:
# If the backend has not changed, but the agent was previously disabled
# and is now enabled, we need to trigger the backend apply_config function
@@ -336,7 +330,6 @@ class VisualBase(Agent):
)
def prepare_prompt(self, prompt: str, styles: list[Style] = None) -> Style:
prompt_style = Style()
prompt_style.load(prompt)
@@ -432,9 +425,8 @@ class VisualBase(Agent):
async def generate(
self, format: str = "portrait", prompt: str = None, automatic: bool = False
):
        context: VisualContextState = visual_context.get()
log.debug("visual generate", context=context)
if automatic and not self.allow_automatic_generation:
@@ -466,7 +458,7 @@ class VisualBase(Agent):
thematic_style = self.default_style
vis_type_styles = self.vis_type_styles(context.vis_type)
        prompt: Style = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])
if context.vis_type == VIS_TYPES.CHARACTER:
prompt.keywords.append("character portrait")
@@ -488,12 +480,16 @@ class VisualBase(Agent):
format = "portrait"
context.format = format
can_generate_image = self.enabled and self.backend_ready
if not context.prompt_only and not can_generate_image:
emit("status", "Visual agent is not ready for image generation, will output prompt instead.", status="warning")
emit(
"status",
"Visual agent is not ready for image generation, will output prompt instead.",
status="warning",
)
# if prompt_only, we don't need to generate an image
# instead we emit a system message with the prompt
if context.prompt_only or not can_generate_image:
@@ -509,20 +505,24 @@ class VisualBase(Agent):
"title": f"Visual Prompt - {context.title}",
"display": "tonal",
"as_markdown": True,
                },
)
return
if not can_generate_image:
return
# Call the backend specific generate function
backend = self.backend
fn = f"{backend.lower()}_generate"
log.info(
"visual generate", backend=backend, prompt=prompt, format=format, context=context
"visual generate",
backend=backend,
prompt=prompt,
format=format,
context=context,
)
if not hasattr(self, fn):
@@ -538,7 +538,6 @@ class VisualBase(Agent):
@set_processing
async def generate_environment_prompt(self, instructions: str = None):
with RevisionDisabled():
response = await Prompt.request(
"visual.generate-environment-prompt",
@@ -556,7 +555,6 @@ class VisualBase(Agent):
async def generate_character_prompt(
self, character_name: str, instructions: str = None
):
character = self.scene.get_character(character_name)
with RevisionDisabled():

View File

@@ -1,22 +1,15 @@
import base64
import io
import httpx
import structlog
from PIL import Image
from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConditional,
AgentActionConfig,
AgentDetail,
set_processing,
)
from .handlers import register
from .schema import RenderSettings
from .style import Style
log = structlog.get_logger("talemate.agents.visual.automatic1111")
@@ -58,9 +51,9 @@ SAMPLING_SCHEDULES = [
SAMPLING_SCHEDULES = sorted(SAMPLING_SCHEDULES, key=lambda x: x["label"])
@register(backend_name="automatic1111", label="AUTOMATIC1111")
class Automatic1111Mixin:
automatic1111_default_render_settings = RenderSettings()
EXTEND_ACTIONS = {
@@ -139,14 +132,14 @@ class Automatic1111Mixin:
@property
def automatic1111_sampling_method(self):
return self.actions["automatic1111"].config["sampling_method"].value
@property
def automatic1111_schedule_type(self):
return self.actions["automatic1111"].config["schedule_type"].value
@property
def automatic1111_cfg(self):
return self.actions["automatic1111"].config["cfg"].value
return self.actions["automatic1111"].config["cfg"].value
async def automatic1111_generate(self, prompt: Style, format: str):
url = self.api_url
@@ -162,14 +155,16 @@ class Automatic1111Mixin:
"height": resolution.height,
"cfg_scale": self.automatic1111_cfg,
"sampler_name": self.automatic1111_sampling_method,
"scheduler": self.automatic1111_schedule_type
"scheduler": self.automatic1111_schedule_type,
}
log.info("automatic1111_generate", payload=payload, url=url)
async with httpx.AsyncClient() as client:
response = await client.post(
url=f"{url}/sdapi/v1/txt2img", json=payload, timeout=self.generate_timeout
url=f"{url}/sdapi/v1/txt2img",
json=payload,
timeout=self.generate_timeout,
)
r = response.json()

View File

@@ -1,6 +1,5 @@
import asyncio
import base64
import io
import json
import os
import random
@@ -10,13 +9,12 @@ import urllib.parse
import httpx
import pydantic
import structlog
from PIL import Image
from talemate.agents.base import AgentAction, AgentActionConditional, AgentActionConfig
from .handlers import register
from .schema import RenderSettings, Resolution
from .style import Style
log = structlog.get_logger("talemate.agents.visual.comfyui")
@@ -25,7 +23,6 @@ class Workflow(pydantic.BaseModel):
nodes: dict
def set_resolution(self, resolution: Resolution):
# will collect all latent image nodes
# if there is multiple will look for the one with the
# title "Talemate Resolution"
@@ -55,7 +52,6 @@ class Workflow(pydantic.BaseModel):
latent_image_node["inputs"]["height"] = resolution.height
def set_prompt(self, prompt: str, negative_prompt: str = None):
# will collect all CLIPTextEncode nodes
# if there is multiple will look for the one with the
@@ -79,7 +75,6 @@ class Workflow(pydantic.BaseModel):
negative_prompt_node = None
for node_id, node in self.nodes.items():
if node["class_type"] == "CLIPTextEncode":
if not positive_prompt_node:
positive_prompt_node = node
@@ -102,7 +97,6 @@ class Workflow(pydantic.BaseModel):
negative_prompt_node["inputs"]["text"] = negative_prompt
def set_checkpoint(self, checkpoint: str):
# will collect all CheckpointLoaderSimple nodes
# if there is multiple will look for the one with the
# title "Talemate Load Checkpoint"
@@ -139,7 +133,6 @@ class Workflow(pydantic.BaseModel):
@register(backend_name="comfyui", label="ComfyUI")
class ComfyUIMixin:
comfyui_default_render_settings = RenderSettings()
EXTEND_ACTIONS = {
@@ -287,7 +280,9 @@ class ComfyUIMixin:
log.info("comfyui_generate", payload=payload, url=url)
async with httpx.AsyncClient() as client:
            response = await client.post(
                url=f"{url}/prompt", json=payload, timeout=self.generate_timeout
            )
log.info("comfyui_generate", response=response.text)

View File

@@ -30,7 +30,7 @@ class VisualContextState(pydantic.BaseModel):
format: Union[str, None] = None
replace: bool = False
prompt_only: bool = False
@property
def title(self) -> str:
if self.vis_type == VIS_TYPES.ENVIRONMENT:

View File

@@ -7,7 +7,6 @@ HANDLERS = {}
class register:
def __init__(self, backend_name: str, label: str):
self.backend_name = backend_name
self.label = label

View File

@@ -6,23 +6,27 @@ from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
log = structlog.get_logger("talemate.game.engine.nodes.agents.visual")
@register("agents/visual/Settings")
class VisualSettings(AgentSettingsNode):
"""
Base node to render visual agent settings.
"""
_agent_name:ClassVar[str] = "visual"
_agent_name: ClassVar[str] = "visual"
def __init__(self, title="Visual Settings", **kwargs):
super().__init__(title=title, **kwargs)
@register("agents/visual/GenerateCharacterPortrait")
class GenerateCharacterPortrait(AgentNode):
"""
Generates a portrait for a character
"""
_agent_name:ClassVar[str] = "visual"
_agent_name: ClassVar[str] = "visual"
class Fields:
instructions = PropertyField(
name="instructions",
@@ -30,33 +34,34 @@ class GenerateCharacterPortrait(AgentNode):
description="instructions for the portrait",
default=UNRESOLVED,
)
def __init__(self, title="Generate Character Portrait", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_input("instructions", socket_type="str", optional=True)
self.set_property("instructions", UNRESOLVED)
self.add_output("state")
self.add_output("character", socket_type="character")
self.add_output("portrait", socket_type="image")
async def run(self, state: GraphState):
character = self.get_input_value("character")
instructions = self.normalized_input_value("instructions")
portrait = await self.agent.generate_character_portrait(
character_name=character.name,
instructions=instructions,
)
        self.set_output_values(
            {
                "state": self.get_input_value("state"),
                "character": character,
                "portrait": portrait,
            }
)

View File

@@ -1,31 +1,21 @@
import base64
import io
from urllib.parse import parse_qs, unquote, urlparse
import httpx
import structlog
from openai import AsyncOpenAI
from PIL import Image
from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConditional,
AgentActionConfig,
AgentDetail,
set_processing,
)
from .handlers import register
from .schema import RenderSettings, Resolution
from .style import Style
log = structlog.get_logger("talemate.agents.visual.openai_image")
@register(backend_name="openai_image", label="OpenAI")
class OpenAIImageMixin:
openai_image_default_render_settings = RenderSettings()
EXTEND_ACTIONS = {

View File

@@ -83,7 +83,9 @@ class Style(pydantic.BaseModel):
# Almost taken straight from some of the fooocus style presets, credit goes to the original author
STYLE_MAP["digital_art"] = Style(
keywords="in the style of a digital artwork, masterpiece, best quality, high detail".split(", "),
keywords="in the style of a digital artwork, masterpiece, best quality, high detail".split(
", "
),
negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
)
@@ -95,12 +97,16 @@ STYLE_MAP["concept_art"] = Style(
)
STYLE_MAP["ink_illustration"] = Style(
keywords="in the style of ink illustration, painting, masterpiece, best quality".split(", "),
keywords="in the style of ink illustration, painting, masterpiece, best quality".split(
", "
),
negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
)
STYLE_MAP["anime"] = Style(
keywords="in the style of anime, masterpiece, best quality, illustration".split(", "),
keywords="in the style of anime, masterpiece, best quality, illustration".split(
", "
),
negative_keywords="text, watermark, low quality, blurry, photo, 3d".split(", "),
)
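# Note on the pattern above: the split(", ") calls turn each comma-separated
# keyword string into a list, e.g. (illustrative)
# "masterpiece, best quality".split(", ") == ["masterpiece", "best quality"].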

View File

@@ -56,7 +56,6 @@ class VisualWebsocketHandler(Plugin):
scene = self.scene
if context and context.character_name:
character = scene.get_character(context.character_name)
if not character:

View File

@@ -21,7 +21,13 @@ from talemate.scene_message import (
from talemate.util.response import extract_list
from talemate.agents.base import (
    Agent,
    AgentAction,
    AgentActionConfig,
    AgentEmission,
    set_processing,
)
from talemate.agents.registry import register
@@ -57,10 +63,7 @@ class TimePassageEmission(WorldStateAgentEmission):
@register()
class WorldStateAgent(CharacterProgressionMixin, Agent):
"""
An agent that handles world state related tasks.
"""
@@ -91,7 +94,7 @@ class WorldStateAgent(
min=1,
max=100,
step=1,
                ),
},
),
"update_reinforcements": AgentAction(
@@ -140,7 +143,7 @@ class WorldStateAgent(
@property
def experimental(self):
return True
@property
def initial_update(self):
return self.actions["update_world_state"].config["initial"].value
@@ -148,7 +151,9 @@ class WorldStateAgent(
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
talemate.emit.async_signals.get("scene_loop_init_after").connect(self.on_scene_loop_init_after)
talemate.emit.async_signals.get("scene_loop_init_after").connect(
self.on_scene_loop_init_after
)
async def advance_time(self, duration: str, narrative: str = None):
"""
@@ -183,13 +188,13 @@ class WorldStateAgent(
if not self.initial_update:
return
if self.get_scene_state("inital_update_done"):
return
await self.scene.world_state.request_update()
self.set_scene_states(inital_update_done=True)
async def on_game_loop(self, emission: GameLoopEvent):
"""
Called when a conversation is generated
@@ -339,13 +344,13 @@ class WorldStateAgent(
context = await memory_agent.multi_query(queries, iterate=3)
        # log.debug(
# "analyze_text_and_extract_context_via_queries",
# goal=goal,
# text=text,
# queries=queries,
# context=context,
        # )
return context
@@ -356,7 +361,6 @@ class WorldStateAgent(
instruction: str,
short: bool = False,
):
kind = "analyze_freeform_short" if short else "analyze_freeform"
response = await Prompt.request(
@@ -408,21 +412,20 @@ class WorldStateAgent(
)
return response
@set_processing
async def analyze_history_and_follow_instructions(
self,
entries: list[dict],
instructions: str,
analysis: str = "",
        response_length: int = 512,
) -> str:
"""
Takes a list of archived_history or layered_history entries
and follows the instructions to generate a response.
"""
response = await Prompt.request(
"world_state.analyze-history-and-follow-instructions",
self.client,
@@ -436,7 +439,7 @@ class WorldStateAgent(
"response_length": response_length,
},
)
return response.strip()
@set_processing
@@ -480,7 +483,7 @@ class WorldStateAgent(
for line in response.split("\n"):
if not line.strip():
continue
if not ":" in line:
if ":" not in line:
break
name, value = line.split(":", 1)
data[name.strip()] = value.strip()
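            # Illustration (hypothetical response line): "Mood: tense" parses
            # to data["Mood"] == "tense"; the loop breaks at the first
            # non-empty line without a colon.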
@@ -542,24 +545,33 @@ class WorldStateAgent(
"""
        Queries a single reinforcement
"""
if isinstance(character, self.scene.Character):
character = character.name
message = None
idx, reinforcement = await self.scene.world_state.find_reinforcement(
question, character
)
if not reinforcement:
log.warning(f"Reinforcement not found", question=question, character=character)
log.warning(
"Reinforcement not found", question=question, character=character
)
return
message = ReinforcementMessage(message="")
message.set_source("world_state", "update_reinforcement", question=question, character=character)
message.set_source(
"world_state",
"update_reinforcement",
question=question,
character=character,
)
if reset and reinforcement.insert == "sequential":
self.scene.pop_history(typ="reinforcement", meta_hash=message.meta_hash, all=True)
self.scene.pop_history(
typ="reinforcement", meta_hash=message.meta_hash, all=True
)
if reinforcement.insert == "sequential":
kind = "analyze_freeform_medium_short"
@@ -594,7 +606,7 @@ class WorldStateAgent(
reinforcement.answer = answer
reinforcement.due = reinforcement.interval
# remove any recent previous reinforcement message with same question
# to avoid overloading the near history with reinforcement messages
if not reset:
@@ -818,4 +830,4 @@ class WorldStateAgent(
kwargs=kwargs,
error=e,
)
            raise

View File

@@ -1,16 +1,9 @@
from typing import TYPE_CHECKING
import structlog
import re
from talemate.agents.base import set_processing, AgentAction, AgentActionConfig
from talemate.prompts import Prompt
from talemate.instance import get_agent
from talemate.events import GameLoopEvent
from talemate.status import set_loading
from talemate.emit import emit
import talemate.emit.async_signals
import talemate.game.focal as focal
@@ -23,8 +16,8 @@ if TYPE_CHECKING:
log = structlog.get_logger()
class CharacterProgressionMixin:
"""
World-state manager agent mixin that handles tracking of character progression
and proposal of updates to character profiles.
@@ -55,13 +48,13 @@ class CharacterProgressionMixin:
type="bool",
label="Propose as suggestions",
description="Propose changes as suggestions that need to be manually accepted.",
                    value=True,
),
"player_character": AgentActionConfig(
type="bool",
label="Player character",
description="Track the player character's progression.",
                    value=True,
),
"max_changes": AgentActionConfig(
type="number",
@@ -70,16 +63,16 @@ class CharacterProgressionMixin:
value=1,
min=1,
max=5,
                ),
            },
)
    # config property helpers
@property
def character_progression_enabled(self) -> bool:
return self.actions["character_progression"].enabled
@property
def character_progression_frequency(self) -> int:
return self.actions["character_progression"].config["frequency"].value
@@ -91,7 +84,7 @@ class CharacterProgressionMixin:
@property
def character_progression_max_changes(self) -> int:
return self.actions["character_progression"].config["max_changes"].value
@property
def character_progression_as_suggestions(self) -> bool:
return self.actions["character_progression"].config["as_suggestions"].value
@@ -100,8 +93,9 @@ class CharacterProgressionMixin:
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop_track_character_progression)
talemate.emit.async_signals.get("game_loop").connect(
self.on_game_loop_track_character_progression
)
async def on_game_loop_track_character_progression(self, emission: GameLoopEvent):
"""
@@ -110,58 +104,69 @@ class CharacterProgressionMixin:
if not self.enabled or not self.character_progression_enabled:
return
log.debug("on_game_loop_track_character_progression", scene=self.scene)
        rounds_since_last_check = self.get_scene_state(
            "rounds_since_last_character_progression_check", 0
        )

        if rounds_since_last_check < self.character_progression_frequency:
            rounds_since_last_check += 1
            self.set_scene_states(
                rounds_since_last_character_progression_check=rounds_since_last_check
            )
            return
self.set_scene_states(rounds_since_last_character_progression_check=0)
for character in self.scene.characters:
if character.is_player and not self.character_progression_player_character:
continue
            calls: list[focal.Call] = await self.determine_character_development(
                character
            )

            await self.character_progression_process_calls(
                character=character,
                calls=calls,
                as_suggestions=self.character_progression_as_suggestions,
            )
# methods
@set_processing
    async def character_progression_process_calls(
        self,
        character: "Character",
        calls: list[focal.Call],
        as_suggestions: bool = True,
    ):
        world_state_manager: WorldStateManager = self.scene.world_state_manager
if as_suggestions:
await world_state_manager.add_suggestion(
Suggestion(
name=character.name,
type="character",
id=f"character-{character.name}",
                    proposals=calls,
)
)
else:
for call in calls:
# changes will be applied directly to the character
if call.name in ["add_attribute", "update_attribute"]:
await character.set_base_attribute(call.arguments["name"], call.result)
await character.set_base_attribute(
call.arguments["name"], call.result
)
elif call.name == "remove_attribute":
await character.set_base_attribute(call.arguments["name"], None)
elif call.name == "update_description":
await character.set_description(call.result)
@set_processing
async def determine_character_development(
        self,
character: "Character",
generation_options: world_state_templates.GenerationOptions | None = None,
instructions: str = None,
@@ -169,95 +174,96 @@ class CharacterProgressionMixin:
"""
Determine character development
"""
log.debug("determine_character_development", character=character, generation_options=generation_options)
log.debug(
"determine_character_development",
character=character,
generation_options=generation_options,
)
creator = get_agent("creator")
@set_loading("Generating character attribute", cancellable=True)
async def add_attribute(name: str, instructions: str) -> str:
return await creator.generate_character_attribute(
character,
                attribute_name=name,
                instructions=instructions,
                generation_options=generation_options,
)
@set_loading("Generating character attribute", cancellable=True)
async def update_attribute(name: str, instructions: str) -> str:
return await creator.generate_character_attribute(
character,
attribute_name = name,
instructions = instructions,
original = character.base_attributes.get(name),
generation_options = generation_options,
attribute_name=name,
instructions=instructions,
original=character.base_attributes.get(name),
generation_options=generation_options,
)
async def remove_attribute(name: str, reason:str) -> str:
async def remove_attribute(name: str, reason: str) -> str:
return None
@set_loading("Generating character description", cancellable=True)
async def update_description(instructions: str) -> str:
return await creator.generate_character_detail(
character,
detail_name = "description",
instructions = instructions,
original = character.description,
detail_name="description",
instructions=instructions,
original=character.description,
length=1024,
generation_options = generation_options,
generation_options=generation_options,
)
focal_handler = focal.Focal(
self.client,
# callbacks
callbacks = [
callbacks=[
focal.Callback(
name = "add_attribute",
arguments = [
name="add_attribute",
arguments=[
focal.Argument(name="name", type="str"),
focal.Argument(name="instructions", type="str"),
],
fn = add_attribute
fn=add_attribute,
),
focal.Callback(
name = "update_attribute",
arguments = [
name="update_attribute",
arguments=[
focal.Argument(name="name", type="str"),
focal.Argument(name="instructions", type="str"),
],
fn = update_attribute
fn=update_attribute,
),
focal.Callback(
name = "remove_attribute",
arguments = [
name="remove_attribute",
arguments=[
focal.Argument(name="name", type="str"),
focal.Argument(name="reason", type="str"),
],
fn = remove_attribute
fn=remove_attribute,
),
focal.Callback(
name = "update_description",
arguments = [
name="update_description",
arguments=[
focal.Argument(name="instructions", type="str"),
],
fn = update_description,
multiple=False
fn=update_description,
multiple=False,
),
],
max_calls = self.character_progression_max_changes,
max_calls=self.character_progression_max_changes,
# context
character = character,
scene = self.scene,
instructions = instructions,
character=character,
scene=self.scene,
instructions=instructions,
)
await focal_handler.request(
"world_state.determine-character-development",
)
log.debug("determine_character_development", calls=focal_handler.state.calls)
return focal_handler.state.calls
return focal_handler.state.calls
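
For orientation, a condensed sketch of the focal callback pattern used in this method: declare named callbacks with typed arguments, cap the total number of calls with max_calls, then read the results back from `state.calls`. The callback below is hypothetical, and `client`, `scene`, and the `focal` import are assumed to be available as in the module above:

async def example(client, scene):
    # Hypothetical callback; its return value becomes the call's .result.
    async def summarize(instructions: str) -> str:
        return f"summary per: {instructions}"

    focal_handler = focal.Focal(
        client,
        callbacks=[
            focal.Callback(
                name="summarize",
                arguments=[focal.Argument(name="instructions", type="str")],
                fn=summarize,
            ),
        ],
        max_calls=1,
        scene=scene,
    )
    await focal_handler.request("world_state.determine-character-development")
    # Each focal.Call carries .name, .arguments, and .result.
    return focal_handler.state.calls
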

View File

@@ -1,7 +1,12 @@
import structlog
from typing import ClassVar, TYPE_CHECKING
from talemate.context import active_scene
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
from talemate.game.engine.nodes.core import (
GraphState,
PropertyField,
UNRESOLVED,
TYPE_CHOICES,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
from talemate.world_state import InsertionMode
@@ -10,140 +15,140 @@ from talemate.world_state.manager import WorldStateManager
if TYPE_CHECKING:
from talemate.tale_mate import Scene, Character
TYPE_CHOICES.extend([
"world_state/reinforcement",
])
TYPE_CHOICES.extend(
[
"world_state/reinforcement",
]
)
log = structlog.get_logger("talemate.game.engine.nodes.agents.world_state")
@register("agents/world_state/Settings")
class WorldstateSettings(AgentSettingsNode):
"""
Base node to render world_state agent settings.
"""
_agent_name:ClassVar[str] = "world_state"
_agent_name: ClassVar[str] = "world_state"
def __init__(self, title="Worldstate Settings", **kwargs):
super().__init__(title=title, **kwargs)
@register("agents/world_state/ExtractCharacterSheet")
class ExtractCharacterSheet(AgentNode):
"""
Attempts to extract an attribute-based character sheet
from a given context for a specific character.
Additionally, alteration instructions can be given to
modify the character's existing sheet.
Inputs:
- state: The current state of the graph
- character_name: The name of the character to extract the sheet for
- context: The context to extract the sheet from
- alteration_instructions: Instructions to alter the character's sheet
Outputs:
- character_sheet: The extracted character sheet (dict)
"""
_agent_name:ClassVar[str] = "world_state"
_agent_name: ClassVar[str] = "world_state"
def __init__(self, title="Extract Character Sheet", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character_name", socket_type="str")
self.add_input("context", socket_type="str")
self.add_input("alteration_instructions", socket_type="str", optional=True)
self.add_output("character_sheet", socket_type="dict")
async def run(self, state: GraphState):
context = self.require_input("context")
character_name = self.require_input("character_name")
alteration_instructions = self.get_input_value("alteration_instructions")
sheet = await self.agent.extract_character_sheet(
name = character_name,
text = context,
alteration_instructions = alteration_instructions
name=character_name,
text=context,
alteration_instructions=alteration_instructions,
)
self.set_output_values({
"character_sheet": sheet
})
self.set_output_values({"character_sheet": sheet})
@register("agents/world_state/StateReinforcement")
class StateReinforcement(AgentNode):
"""
Reinforces a tracked state of a character or the world in general.
Inputs:
- state: The current state of the graph
- query_or_detail: The query or instruction to reinforce
- character: The character to reinforce the state for (optional)
Properties:
- reset: If the state should be reset
Outputs:
- state: graph state
- message: state reinforcement message
"""
_agent_name:ClassVar[str] = "world_state"
class Fields:
_agent_name: ClassVar[str] = "world_state"
class Fields:
query_or_detail = PropertyField(
name="query_or_detail",
description="Query or detail to reinforce",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
instructions = PropertyField(
name="instructions",
description="Instructions for the reinforcement",
type="text",
default=""
default="",
)
interval = PropertyField(
name="interval",
description="Interval for reinforcement",
type="int",
default=10,
min=1,
step=1
step=1,
)
insert_method = PropertyField(
name="insert_method",
description="Method to insert reinforcement",
type="str",
default="sequential",
choices=[
mode.value for mode in InsertionMode
]
choices=[mode.value for mode in InsertionMode],
)
reset = PropertyField(
name="reset",
description="If the state should be reset",
type="bool",
default=False
default=False,
)
def __init__(self, title="State Reinforcement", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("query_or_detail", socket_type="str")
@@ -155,20 +160,20 @@ class StateReinforcement(AgentNode):
self.set_property("interval", 10)
self.set_property("insert_method", "sequential")
self.set_property("reset", False)
self.add_output("state")
self.add_output("message", socket_type="message")
self.add_output("reinforcement", socket_type="world_state/reinforcement")
async def run(self, state: GraphState):
scene:"Scene" = active_scene.get()
scene: "Scene" = active_scene.get()
query_or_detail = self.require_input("query_or_detail")
character = self.normalized_input_value("character")
reset = self.get_property("reset")
interval = self.require_number_input("interval")
instructions = self.normalized_input_value("instructions")
insert_method = self.get_property("insert_method")
await scene.world_state.add_reinforcement(
question=query_or_detail,
character=character.name if character else None,
@@ -176,142 +181,128 @@ class StateReinforcement(AgentNode):
interval=interval,
insert=insert_method,
)
message = await self.agent.update_reinforcement(
question=query_or_detail,
character=character,
reset=reset
question=query_or_detail, character=character, reset=reset
)
self.set_output_values({
"state": state,
"message": message
})
self.set_output_values({"state": state, "message": message})
@register("agents/world_state/DeactivateCharacter")
class DeactivateCharacter(AgentNode):
"""
Deactivates a character from the world state.
Inputs:
- state: The current state of the graph
- character: The character to deactivate
Outputs:
- state: The updated state
"""
_agent_name:ClassVar[str] = "world_state"
_agent_name: ClassVar[str] = "world_state"
def __init__(self, title="Deactivate Character", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_output("state")
async def run(self, state: GraphState):
character:"Character" = self.require_input("character")
scene:"Scene" = active_scene.get()
manager:WorldStateManager = scene.world_state_manager
character: "Character" = self.require_input("character")
scene: "Scene" = active_scene.get()
manager: WorldStateManager = scene.world_state_manager
manager.deactivate_character(character.name)
self.set_output_values({
"state": state
})
self.set_output_values({"state": state})
@register("agents/world_state/EvaluateQuery")
class EvaluateQuery(AgentNode):
"""
Evaluates a query on the world state.
Inputs:
- state: The current state of the graph
- query: The query to evaluate
- context: The context to evaluate the query in
Outputs:
- state: The current state
- result: The result of the query
"""
_agent_name:ClassVar[str] = "world_state"
_agent_name: ClassVar[str] = "world_state"
class Fields:
query = PropertyField(
name="query",
description="The query to evaluate",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
def __init__(self, title="Evaluate Query", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_input("query", socket_type="str")
self.add_input("context", socket_type="str")
self.set_property("query", UNRESOLVED)
self.add_output("state")
self.add_output("result", socket_type="bool")
async def run(self, state: GraphState):
query = self.require_input("query")
context = self.require_input("context")
result = await self.agent.answer_query_true_or_false(
query=query,
text=context
)
self.set_output_values({
"state": state,
"result": result
})
result = await self.agent.answer_query_true_or_false(query=query, text=context)
self.set_output_values({"state": state, "result": result})
@register("agents/world_state/RequestWorldState")
class RequestWorldState(AgentNode):
"""
Requests the current world state.
Inputs:
- state: The current state of the graph
Outputs:
- state: The current state
- world_state: The current world state
"""
_agent_name:ClassVar[str] = "world_state"
_agent_name: ClassVar[str] = "world_state"
def __init__(self, title="Request World State", **kwargs):
super().__init__(title=title, **kwargs)
def setup(self):
self.add_input("state")
self.add_output("state")
self.add_output("world_state", socket_type="dict")
async def run(self, state: GraphState):
scene:"Scene" = active_scene.get()
scene: "Scene" = active_scene.get()
world_state = await scene.world_state.request_update()
self.set_output_values({
"state": state,
"world_state": world_state
})
self.set_output_values({"state": state, "world_state": world_state})

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Union
from talemate.instance import get_agent
if TYPE_CHECKING:
from talemate.tale_mate import Actor, Character, Scene
from talemate.tale_mate import Character, Scene
__all__ = [
@@ -46,7 +46,6 @@ async def activate_character(scene: "Scene", character: Union[str, "Character"])
if isinstance(character, str):
character = scene.get_character(character)
if character.name not in scene.inactive_characters:
# already activated
return False

View File

@@ -1,19 +1,17 @@
import os
import talemate.client.runpod
from talemate.client.anthropic import AnthropicClient
from talemate.client.base import ClientBase, ClientDisabledError
from talemate.client.cohere import CohereClient
from talemate.client.deepseek import DeepSeekClient
from talemate.client.google import GoogleClient
from talemate.client.groq import GroqClient
from talemate.client.koboldcpp import KoboldCppClient
from talemate.client.lmstudio import LMStudioClient
from talemate.client.mistral import MistralAIClient
from talemate.client.ollama import OllamaClient
from talemate.client.openai import OpenAIClient
from talemate.client.openrouter import OpenRouterClient
from talemate.client.openai_compat import OpenAICompatibleClient
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
from talemate.client.tabbyapi import TabbyAPIClient
from talemate.client.textgenwebui import TextGeneratorWebuiClient
import talemate.client.runpod # noqa: F401
from talemate.client.anthropic import AnthropicClient # noqa: F401
from talemate.client.base import ClientBase, ClientDisabledError # noqa: F401
from talemate.client.cohere import CohereClient # noqa: F401
from talemate.client.deepseek import DeepSeekClient # noqa: F401
from talemate.client.google import GoogleClient # noqa: F401
from talemate.client.groq import GroqClient # noqa: F401
from talemate.client.koboldcpp import KoboldCppClient # noqa: F401
from talemate.client.lmstudio import LMStudioClient # noqa: F401
from talemate.client.mistral import MistralAIClient # noqa: F401
from talemate.client.ollama import OllamaClient # noqa: F401
from talemate.client.openai import OpenAIClient # noqa: F401
from talemate.client.openrouter import OpenRouterClient # noqa: F401
from talemate.client.openai_compat import OpenAICompatibleClient # noqa: F401
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register # noqa: F401
from talemate.client.tabbyapi import TabbyAPIClient # noqa: F401
from talemate.client.textgenwebui import TextGeneratorWebuiClient # noqa: F401

View File

@@ -39,6 +39,7 @@ class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
model: str = "claude-3-5-sonnet-latest"
double_coercion: str = None
class ClientConfig(EndpointOverride, BaseClientConfig):
pass
@@ -118,13 +119,13 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
self.current_status = status
data={
data = {
"error_action": error_action.model_dump() if error_action else None,
"double_coercion": self.double_coercion,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
}
data.update(self._common_status_data())
data.update(self._common_status_data())
emit(
"client_status",
message=self.client_type,
@@ -135,7 +136,10 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
)
def set_client(self, max_token_length: int = None):
if not self.anthropic_api_key and not self.endpoint_override_base_url_configured:
if (
not self.anthropic_api_key
and not self.endpoint_override_base_url_configured
):
self.client = AsyncAnthropic(api_key="sk-1111")
log.error("No anthropic API key set")
if self.api_key_status:
@@ -175,10 +179,10 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
@@ -208,17 +212,18 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
Generates text from the given prompt and parameters.
"""
if not self.anthropic_api_key and not self.endpoint_override_base_url_configured:
if (
not self.anthropic_api_key
and not self.endpoint_override_base_url_configured
):
raise Exception("No anthropic API key set")
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
system_message = self.get_system_message(kind)
messages = [
{"role": "user", "content": prompt.strip()}
]
messages = [{"role": "user", "content": prompt.strip()}]
if coercion_prompt:
messages.append({"role": "assistant", "content": coercion_prompt.strip()})
@@ -228,7 +233,7 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
parameters=parameters,
system_message=system_message,
)
completion_tokens = 0
prompt_tokens = 0
@@ -240,22 +245,20 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
stream=True,
**parameters,
)
response = ""
async for event in stream:
if event.type == "content_block_delta":
content = event.delta.text
response += content
self.update_request_tokens(self.count_tokens(content))
elif event.type == "message_start":
prompt_tokens = event.message.usage.input_tokens
elif event.type == "message_delta":
completion_tokens += event.usage.output_tokens
self._returned_prompt_tokens = prompt_tokens
self._returned_response_tokens = completion_tokens
@@ -267,5 +270,5 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
self.log.error("generate error", e=e)
emit("status", message="anthropic API: Permission Denied", status="error")
return ""
except Exception as e:
except Exception:
raise

View File

@@ -42,6 +42,7 @@ STOPPING_STRINGS = ["<|im_end|>", "</s>"]
# disable smart quotes until text rendering is refactored
REPLACE_SMART_QUOTES = True
class ClientDisabledError(OSError):
def __init__(self, client: "ClientBase"):
self.client = client
@@ -76,6 +77,7 @@ class CommonDefaults(pydantic.BaseModel):
data_format: Literal["yaml", "json"] | None = None
preset_group: str | None = None
class Defaults(CommonDefaults, pydantic.BaseModel):
api_url: str = "http://localhost:5000"
max_token_length: int = 8192
@@ -88,6 +90,7 @@ class FieldGroup(pydantic.BaseModel):
description: str
icon: str = "mdi-cog"
class ExtraField(pydantic.BaseModel):
name: str
type: str
@@ -97,6 +100,7 @@ class ExtraField(pydantic.BaseModel):
group: FieldGroup | None = None
note: ux_schema.Note | None = None
class ParameterReroute(pydantic.BaseModel):
talemate_parameter: str
client_parameter: str
@@ -117,23 +121,23 @@ class RequestInformation(pydantic.BaseModel):
start_time: float = pydantic.Field(default_factory=time.time)
end_time: float | None = None
tokens: int = 0
@pydantic.computed_field(description="Duration")
@property
def duration(self) -> float:
end_time = self.end_time or time.time()
return end_time - self.start_time
@pydantic.computed_field(description="Tokens per second")
@property
def rate(self) -> float:
try:
end_time = self.end_time or time.time()
return self.tokens / (end_time - self.start_time)
except:
except Exception:
pass
return 0
@pydantic.computed_field(description="Status")
@property
def status(self) -> str:
@@ -145,7 +149,7 @@ class RequestInformation(pydantic.BaseModel):
return "in progress"
else:
return "pending"
@pydantic.computed_field(description="Age")
@property
def age(self) -> float:
@@ -159,10 +163,12 @@ class ClientEmbeddingsStatus:
client: "ClientBase | None" = None
embedding_name: str | None = None
async_signals.register(
"client.embeddings_available",
)
class ClientBase:
api_url: str
model_name: str
@@ -183,10 +189,10 @@ class ClientBase:
rate_limit: int | None = None
client_type = "base"
request_information: RequestInformation | None = None
status_request_timeout:int = 2
system_prompts:SystemPrompts = SystemPrompts()
status_request_timeout: int = 2
system_prompts: SystemPrompts = SystemPrompts()
preset_group: str | None = ""
rate_limit_counter: CounterRateLimiter = None
@@ -216,7 +222,7 @@ class ClientBase:
self.max_token_length = (
int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192
)
self.set_client(max_token_length=self.max_token_length)
def __str__(self):
@@ -252,23 +258,23 @@ class ClientBase:
"temperature",
"max_tokens",
]
@property
def supports_embeddings(self) -> bool:
return False
@property
def embeddings_function(self):
return None
@property
def embeddings_status(self) -> bool:
return getattr(self, "_embeddings_status", False)
@property
def embeddings_model_name(self) -> str | None:
return getattr(self, "_embeddings_model_name", None)
@property
def embeddings_url(self) -> str:
return None
@@ -277,16 +283,16 @@ class ClientBase:
def embeddings_identifier(self) -> str:
return f"client-api/{self.name}/{self.embeddings_model_name}"
async def destroy(self, config:dict):
async def destroy(self, config: dict):
"""
This is called before the client is removed from talemate.instance.clients
Use this to perform any cleanup that is necessary.
If a subclass overrides this method, it should call super().destroy(config) at the
end of the method.
"""
if self.supports_embeddings:
self.remove_embeddings(config)
@@ -296,25 +302,28 @@ class ClientBase:
def set_client(self, **kwargs):
self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")
def set_embeddings(self):
log.debug("setting embeddings", client=self.name, supports_embeddings=self.supports_embeddings, embeddings_status=self.embeddings_status)
log.debug(
"setting embeddings",
client=self.name,
supports_embeddings=self.supports_embeddings,
embeddings_status=self.embeddings_status,
)
if not self.supports_embeddings or not self.embeddings_status:
return
config = load_config(as_model=True)
key = self.embeddings_identifier
if key in config.presets.embeddings:
log.debug("embeddings already set", client=self.name, key=key)
return config.presets.embeddings[key]
log.debug("setting embeddings", client=self.name, key=key)
config.presets.embeddings[key] = EmbeddingFunctionPreset(
embeddings="client-api",
client=self.name,
@@ -324,10 +333,10 @@ class ClientBase:
local=False,
custom=True,
)
save_config(config)
def remove_embeddings(self, config:dict | None = None):
def remove_embeddings(self, config: dict | None = None):
# remove all embeddings for this client
for key, value in list(config["presets"]["embeddings"].items()):
if value["client"] == self.name and value["embeddings"] == "client-api":
@@ -338,7 +347,9 @@ class ClientBase:
if isinstance(system_prompts, dict):
self.system_prompts = SystemPrompts(**system_prompts)
elif not isinstance(system_prompts, SystemPrompts):
raise ValueError("system_prompts must be a `dict` or `SystemPrompts` instance")
raise ValueError(
"system_prompts must be a `dict` or `SystemPrompts` instance"
)
else:
self.system_prompts = system_prompts
@@ -376,10 +387,10 @@ class ClientBase:
"""
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if self.double_coercion:
right = f"{self.double_coercion}\n\n{right}"
return prompt, right
return prompt, None
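
A standalone restatement of the split behavior with a usage example (the prompt strings are made up):

def split_prompt_for_coercion(prompt: str, double_coercion: str | None):
    # Mirrors the method above: keep the full prompt, and return the text
    # after <|BOT|> (prefixed with double_coercion, when set) separately.
    if "<|BOT|>" in prompt:
        _, right = prompt.split("<|BOT|>", 1)
        if double_coercion:
            right = f"{double_coercion}\n\n{right}"
        return prompt, right
    return prompt, None

assert split_prompt_for_coercion("Hi<|BOT|>Sure,", "Stay terse.") == (
    "Hi<|BOT|>Sure,",
    "Stay terse.\n\nSure,",
)
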
@@ -407,7 +418,7 @@ class ClientBase:
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
self._reconfigure_common_parameters(**kwargs)
def _reconfigure_common_parameters(self, **kwargs):
@@ -415,12 +426,14 @@ class ClientBase:
self.rate_limit = kwargs["rate_limit"]
if self.rate_limit:
if not self.rate_limit_counter:
self.rate_limit_counter = CounterRateLimiter(rate_per_minute=self.rate_limit)
self.rate_limit_counter = CounterRateLimiter(
rate_per_minute=self.rate_limit
)
else:
self.rate_limit_counter.update_rate_limit(self.rate_limit)
else:
self.rate_limit_counter = None
if "data_format" in kwargs:
self.data_format = kwargs["data_format"]
@@ -478,19 +491,21 @@ class ClientBase:
- kind: the kind of generation
"""
app_config_system_prompts = client_context_attribute("app_config_system_prompts")
app_config_system_prompts = client_context_attribute(
"app_config_system_prompts"
)
if app_config_system_prompts:
self.system_prompts.parent = SystemPrompts(**app_config_system_prompts)
return self.system_prompts.get(kind, self.decensor_enabled)
def emit_status(self, processing: bool = None):
"""
Sets and emits the client status.
"""
if processing is not None:
self.processing = processing
@@ -516,7 +531,6 @@ class ClientBase:
)
if not has_prompt_template and self.auto_determine_prompt_template:
# only attempt to determine the prompt template once per model and
# only if the model does not already have a prompt template
@@ -581,15 +595,17 @@ class ClientBase:
"supports_embeddings": self.supports_embeddings,
"embeddings_status": self.embeddings_status,
"embeddings_model_name": self.embeddings_model_name,
"request_information": self.request_information.model_dump() if self.request_information else None,
"request_information": self.request_information.model_dump()
if self.request_information
else None,
}
extra_fields = getattr(self.Meta(), "extra_fields", {})
for field_name in extra_fields.keys():
common_data[field_name] = getattr(self, field_name, None)
return common_data
def populate_extra_fields(self, data: dict):
"""
Updates data with the extra fields from the client's Meta
@@ -665,7 +681,7 @@ class ClientBase:
agent_context.agent.inject_prompt_paramters(
parameters, kind, agent_context.action
)
if client_context_attribute(
"nuke_repetition"
) > 0.0 and self.jiggle_enabled_for(kind):
@@ -707,7 +723,6 @@ class ClientBase:
del parameters[key]
def finalize(self, parameters: dict, prompt: str):
prompt = util.replace_special_tokens(prompt)
for finalizer in self.finalizers:
@@ -740,22 +755,20 @@ class ClientBase:
"status", message="Error during generation (check logs)", status="error"
)
return ""
def _generate_task(self, prompt: str, parameters: dict, kind: str):
"""
Creates an asyncio task to generate text from the given prompt and parameters.
"""
return asyncio.create_task(self.generate(prompt, parameters, kind))
def _poll_interrupt(self):
"""
Creates a task that continuously checks active_scene.cancel_requested and
completes the task if cancellation is requested.
"""
async def poll():
while True:
scene = active_scene.get()
@@ -763,58 +776,58 @@ class ClientBase:
break
await asyncio.sleep(0.3)
return GenerationCancelled("Generation cancelled")
return asyncio.create_task(poll())
async def _cancelable_generate(self, prompt: str, parameters: dict, kind: str) -> str | GenerationCancelled:
async def _cancelable_generate(
self, prompt: str, parameters: dict, kind: str
) -> str | GenerationCancelled:
"""
Queues the generation task and the poll task to be run concurrently.
If the poll task completes before the generation task, the generation task
will be cancelled.
If the generation task completes before the poll task, the poll task will
be cancelled.
"""
task_poll = self._poll_interrupt()
task_generate = self._generate_task(prompt, parameters, kind)
done, pending = await asyncio.wait(
[task_poll, task_generate],
return_when=asyncio.FIRST_COMPLETED
[task_poll, task_generate], return_when=asyncio.FIRST_COMPLETED
)
# cancel the remaining task
for task in pending:
task.cancel()
# return the result of the completed task
return done.pop().result()
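
The poll/generate race above is plain asyncio.wait with FIRST_COMPLETED. A self-contained sketch of the same pattern, with sleep-based stand-ins for the poll and generation tasks:

import asyncio

async def poll_cancel(flag: dict) -> str:
    # Stand-in for _poll_interrupt: watch a cancel flag.
    while not flag["cancelled"]:
        await asyncio.sleep(0.1)
    return "cancelled"

async def generate() -> str:
    # Stand-in for the actual generation request.
    await asyncio.sleep(1.0)
    return "generated text"

async def cancelable_generate(flag: dict) -> str:
    task_poll = asyncio.create_task(poll_cancel(flag))
    task_generate = asyncio.create_task(generate())
    done, pending = await asyncio.wait(
        [task_poll, task_generate], return_when=asyncio.FIRST_COMPLETED
    )
    for task in pending:
        task.cancel()  # whichever task lost the race is cancelled
    return done.pop().result()

print(asyncio.run(cancelable_generate({"cancelled": False})))  # -> generated text
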
async def abort_generation(self):
"""
This function can be overridden to trigger an abort on the other
side of the client.
When a generation is cancelled here, this can be used to cancel the
corresponding generation on the other side of the client.
"""
pass
def new_request(self):
"""
Creates a new request information object.
"""
self.request_information = RequestInformation()
def end_request(self):
"""
Ends the request information object.
"""
self.request_information.end_time = time.time()
def update_request_tokens(self, tokens: int, replace: bool = False):
"""
Updates the request information object with the number of tokens received.
@@ -824,7 +837,7 @@ class ClientBase:
self.request_information.tokens = tokens
else:
self.request_information.tokens += tokens
async def send_prompt(
self,
prompt: str,
@@ -837,13 +850,13 @@ class ClientBase:
:param prompt: The text prompt to send.
:return: The AI's response text.
"""
try:
return await self._send_prompt(prompt, kind, finalize, retries)
except GenerationCancelled:
await self.abort_generation()
raise
async def _send_prompt(
self,
prompt: str,
@@ -856,47 +869,49 @@ class ClientBase:
:param prompt: The text prompt to send.
:return: The AI's response text.
"""
try:
if self.rate_limit_counter:
aborted:bool = False
aborted: bool = False
while not self.rate_limit_counter.increment():
log.warn("Rate limit exceeded", client=self.name)
emit(
"rate_limited",
message="Rate limit exceeded",
status="error",
message="Rate limit exceeded",
status="error",
websocket_passthrough=True,
data={
"client": self.name,
"rate_limit": self.rate_limit,
"reset_time": self.rate_limit_counter.reset_time(),
}
},
)
scene = active_scene.get()
if not scene or not scene.active or scene.cancel_requested:
log.info("Rate limit exceeded, generation cancelled", client=self.name)
log.info(
"Rate limit exceeded, generation cancelled",
client=self.name,
)
aborted = True
break
await asyncio.sleep(1)
emit(
"rate_limit_reset",
message="Rate limit reset",
status="info",
websocket_passthrough=True,
data={"client": self.name}
data={"client": self.name},
)
if aborted:
raise GenerationCancelled("Generation cancelled")
except GenerationCancelled:
raise
except Exception as e:
except Exception:
log.error("Error during rate limit check", e=traceback.format_exc())
if not active_scene.get():
log.error("SceneInactiveError", scene=active_scene.get())
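
The loop above relies on three methods of CounterRateLimiter: increment() returning False while the per-minute budget is exhausted, update_rate_limit(), and reset_time(). A plausible minimal shape for that interface — the internals here are an assumption, not the real class:

import time

class CounterRateLimiterSketch:
    # Assumed internals; only the interface matches the code above.
    def __init__(self, rate_per_minute: int):
        self.rate_per_minute = rate_per_minute
        self.window_start = time.time()
        self.count = 0

    def increment(self) -> bool:
        now = time.time()
        if now - self.window_start >= 60:
            self.window_start = now
            self.count = 0
        if self.count >= self.rate_per_minute:
            return False  # over budget; the caller sleeps and retries
        self.count += 1
        return True

    def update_rate_limit(self, rate_per_minute: int) -> None:
        self.rate_per_minute = rate_per_minute

    def reset_time(self) -> float:
        # Seconds until the current window rolls over.
        return max(0.0, 60 - (time.time() - self.window_start))
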
@@ -940,25 +955,25 @@ class ClientBase:
parameters=prompt_param,
)
prompt_sent = self.repetition_adjustment(finalized_prompt)
self.new_request()
response = await self._cancelable_generate(prompt_sent, prompt_param, kind)
self.end_request()
if isinstance(response, GenerationCancelled):
# generation was cancelled
raise response
#response = await self.generate(prompt_sent, prompt_param, kind)
# response = await self.generate(prompt_sent, prompt_param, kind)
response, finalized_prompt = await self.auto_break_repetition(
finalized_prompt, prompt_param, response, kind, retries
)
if REPLACE_SMART_QUOTES:
response = response.replace('“', '"').replace('”', '"')
response = response.replace("“", '"').replace("”", '"')
time_end = time.time()
@@ -992,9 +1007,9 @@ class ClientBase:
)
return response
except GenerationCancelled as e:
except GenerationCancelled:
raise
except Exception as e:
except Exception:
self.log.error("send_prompt error", e=traceback.format_exc())
emit(
"status", message="Error during generation (check logs)", status="error"
@@ -1004,7 +1019,7 @@ class ClientBase:
self.emit_status(processing=False)
self._returned_prompt_tokens = None
self._returned_response_tokens = None
if self.rate_limit_counter:
self.rate_limit_counter.increment()

View File

@@ -2,7 +2,13 @@ import pydantic
import structlog
from cohere import AsyncClientV2
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults, ExtraField
from talemate.client.base import (
ClientBase,
ErrorAction,
ParameterReroute,
CommonDefaults,
ExtraField,
)
from talemate.client.registry import register
from talemate.client.remote import (
EndpointOverride,
@@ -51,7 +57,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
auto_break_repetition_enabled = False
decensor_enabled = True
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "Cohere"
title: str = "Cohere"
@@ -115,12 +121,12 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
self.current_status = status
data={
data = {
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
}
data.update(self._common_status_data())
data.update(self._common_status_data())
emit(
"client_status",
message=self.client_type,
@@ -171,7 +177,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
@@ -200,7 +206,6 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
return prompt
def clean_prompt_parameters(self, parameters: dict):
super().clean_prompt_parameters(parameters)
# if temperature is set, it needs to be clamped between 0 and 1.0
@@ -240,7 +245,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
parameters=parameters,
system_message=system_message,
)
messages = [
{
"role": "system",
@@ -249,7 +254,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
{
"role": "user",
"content": human_message,
}
},
]
try:
@@ -290,5 +295,5 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
# self.log.error("generate error", e=e)
# emit("status", message="cohere API: Permission Denied", status="error")
# return ""
except Exception as e:
except Exception:
raise

View File

@@ -23,7 +23,10 @@ def model_to_dict_without_defaults(model_instance):
if field.default == model_dict.get(field_name):
del model_dict[field_name]
# special case for conversation context, don't copy if talking_character is None
if field_name == "conversation" and model_dict.get(field_name).get("talking_character") is None:
if (
field_name == "conversation"
and model_dict.get(field_name).get("talking_character") is None
):
del model_dict[field_name]
return model_dict
@@ -102,7 +105,7 @@ class ClientContext:
# Update the context data
self.token = context_data.set(data)
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Reset the context variable `context_data` to its previous values when exiting the context.

View File

@@ -1,8 +1,5 @@
import json
import pydantic
import structlog
import tiktoken
from openai import AsyncOpenAI, PermissionDeniedError
from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
@@ -103,12 +100,12 @@ class DeepSeekClient(ClientBase):
self.current_status = status
data={
data = {
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
}
data.update(self._common_status_data())
data.update(self._common_status_data())
emit(
"client_status",
message=self.client_type,
@@ -273,5 +270,5 @@ class DeepSeekClient(ClientBase):
self.log.error("generate error", e=e)
emit("status", message="DeepSeek API: Permission Denied", status="error")
return ""
except Exception as e:
except Exception:
raise

View File

@@ -7,7 +7,13 @@ from google import genai
import google.genai.types as genai_types
from google.genai.errors import APIError
from talemate.client.base import ClientBase, ErrorAction, ExtraField, ParameterReroute, CommonDefaults
from talemate.client.base import (
ClientBase,
ErrorAction,
ExtraField,
ParameterReroute,
CommonDefaults,
)
from talemate.client.registry import register
from talemate.client.remote import (
RemoteServiceMixin,
@@ -41,12 +47,14 @@ SUPPORTED_MODELS = [
"gemini-2.5-pro-preview-06-05",
]
class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = "gemini-2.0-flash"
disable_safety_settings: bool = False
double_coercion: str = None
class ClientConfig(EndpointOverride, BaseClientConfig):
disable_safety_settings: bool = False
@@ -80,7 +88,6 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
),
}
extra_fields.update(endpoint_override_extra_fields())
def __init__(self, model="gemini-2.0-flash", **kwargs):
self.model_name = model
@@ -114,24 +121,28 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
@property
def google_location(self):
return self.config.get("google").get("gcloud_location")
@property
def google_api_key(self):
return self.config.get("google").get("api_key")
@property
def vertexai_ready(self) -> bool:
return all([
self.google_credentials_path,
self.google_location,
])
return all(
[
self.google_credentials_path,
self.google_location,
]
)
@property
def developer_api_ready(self) -> bool:
return all([
self.google_api_key,
])
return all(
[
self.google_api_key,
]
)
@property
def using(self) -> str:
if self.developer_api_ready:
@@ -143,7 +154,11 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
@property
def ready(self):
# all google settings must be set
return self.vertexai_ready or self.developer_api_ready or self.endpoint_override_base_url_configured
return (
self.vertexai_ready
or self.developer_api_ready
or self.endpoint_override_base_url_configured
)
@property
def safety_settings(self):
@@ -179,10 +194,8 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
def http_options(self) -> genai_types.HttpOptions | None:
if not self.endpoint_override_base_url_configured:
return None
return genai_types.HttpOptions(
base_url=self.base_url
)
return genai_types.HttpOptions(base_url=self.base_url)
@property
def supported_parameters(self):
@@ -230,7 +243,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
}
data.update(self._common_status_data())
data.update(self._common_status_data())
self.populate_extra_fields(data)
if self.using == "VertexAI":
@@ -252,7 +265,9 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
try:
self.client.http_options.base_url = base_url
except Exception as e:
log.error("Error setting client base URL", error=e, client=self.client_type)
log.error(
"Error setting client base URL", error=e, client=self.client_type
)
def set_client(self, max_token_length: int = None, **kwargs):
if not self.ready:
@@ -283,7 +298,9 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
location=self.google_location,
)
else:
self.client = genai.Client(api_key=self.api_key or None, http_options=self.http_options)
self.client = genai.Client(
api_key=self.api_key or None, http_options=self.http_options
)
log.info(
"google set client",
@@ -292,7 +309,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
model=model,
)
def response_tokens(self, response:str):
def response_tokens(self, response: str):
"""Return token count for a response which may be a string or SDK object."""
return count_tokens(response)
@@ -309,10 +326,10 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
self._reconfigure_common_parameters(**kwargs)
def clean_prompt_parameters(self, parameters: dict):
@@ -329,7 +346,6 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
"""
return prompt
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
@@ -342,7 +358,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
human_message = prompt.strip()
system_message = self.get_system_message(kind)
contents = [
genai_types.Content(
role="user",
@@ -350,10 +366,10 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
genai_types.Part.from_text(
text=human_message,
)
]
],
)
]
if coercion_prompt:
contents.append(
genai_types.Content(
@@ -362,7 +378,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
genai_types.Part.from_text(
text=coercion_prompt,
)
]
],
)
)
@@ -378,24 +394,23 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
try:
# Use streaming so we can update_request_tokens incrementally
#stream = await chat.send_message_async(
# stream = await chat.send_message_async(
# human_message,
# safety_settings=self.safety_settings,
# generation_config=parameters,
# stream=True
#)
# )
stream = await self.client.aio.models.generate_content_stream(
model=self.model_name,
contents=contents,
config=genai_types.GenerateContentConfig(
safety_settings=self.safety_settings,
http_options=self.http_options,
**parameters
**parameters,
),
)
response = ""
async for chunk in stream:
@@ -425,5 +440,5 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
self.log.error("generate error", e=e)
emit("status", message="google API: API Error", status="error")
return ""
except Exception as e:
except Exception:
raise

View File

@@ -260,5 +260,5 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
self.log.error("generate error", e=e)
emit("status", message="OpenAI API: Permission Denied", status="error")
return ""
except Exception as e:
except Exception:
raise

View File

@@ -17,7 +17,7 @@ from talemate.client.base import (
ClientBase,
Defaults,
ParameterReroute,
ClientEmbeddingsStatus
ClientEmbeddingsStatus,
)
from talemate.client.registry import register
import talemate.emit.async_signals as async_signals
@@ -46,9 +46,13 @@ class KoboldEmbeddingFunction(EmbeddingFunction):
"""
Embed a list of input texts using the KoboldCPP embeddings endpoint.
"""
log.debug("KoboldCppEmbeddingFunction", api_url=self.api_url, model_name=self.model_name)
log.debug(
"KoboldCppEmbeddingFunction",
api_url=self.api_url,
model_name=self.model_name,
)
# Prepare the request payload for KoboldCPP. Include model name if required.
payload = {"input": texts}
if self.model_name is not None:
@@ -65,6 +69,8 @@ class KoboldEmbeddingFunction(EmbeddingFunction):
embeddings = [item["embedding"] for item in embedding_results]
return embeddings
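
The embedding call reduces to a single POST with an {"input": texts} payload. A sketch of hitting the endpoint directly — the URL is a placeholder, and an OpenAI-style "data" envelope in the response is assumed:

import httpx

def embed(texts: list[str]) -> list[list[float]]:
    # Placeholder URL; the client above derives it from its api_url
    # (api/extra/embeddings, or embeddings in OpenAI-compatible mode).
    api_url = "http://localhost:5001/api/extra/embeddings"
    response = httpx.post(api_url, json={"input": texts}, timeout=10)
    response.raise_for_status()
    results = response.json()["data"]  # assumed OpenAI-style envelope
    return [item["embedding"] for item in results]
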
@register()
class KoboldCppClient(ClientBase):
auto_determine_prompt_template: bool = True
@@ -147,7 +153,6 @@ class KoboldCppClient(ClientBase):
talemate_parameter="stopping_strings",
client_parameter="stop_sequence",
),
"xtc_threshold",
"xtc_probability",
"dry_multiplier",
@@ -155,7 +160,6 @@ class KoboldCppClient(ClientBase):
"dry_allowed_length",
"dry_sequence_breakers",
"smoothing_factor",
"temperature",
]
@@ -172,18 +176,18 @@ class KoboldCppClient(ClientBase):
@property
def supports_embeddings(self) -> bool:
return True
@property
def embeddings_url(self) -> str:
if self.is_openai:
return urljoin(self.api_url, "embeddings")
else:
return urljoin(self.api_url, "api/extra/embeddings")
@property
def embeddings_function(self):
return KoboldEmbeddingFunction(self.embeddings_url, self.embeddings_model_name)
def api_endpoint_specified(self, url: str) -> bool:
return "/v1" in self.api_url
@@ -208,10 +212,10 @@ class KoboldCppClient(ClientBase):
# if self._embeddings_model_name is set, return it
if self.embeddings_model_name:
return self.embeddings_model_name
# otherwise, get the model name by doing a request to
# the embeddings endpoint with a single character
async with httpx.AsyncClient() as client:
response = await client.post(
self.embeddings_url,
@@ -219,37 +223,40 @@ class KoboldCppClient(ClientBase):
timeout=2,
headers=self.request_headers,
)
response_data = response.json()
self._embeddings_model_name = response_data.get("model")
return self._embeddings_model_name
async def get_embeddings_status(self):
url_version = urljoin(self.api_url, "api/extra/version")
async with httpx.AsyncClient() as client:
response = await client.get(url_version, timeout=2)
response_data = response.json()
self._embeddings_status = response_data.get("embeddings", False)
if not self.embeddings_status or self.embeddings_model_name:
return
await self.get_embeddings_model_name()
log.debug("KoboldCpp embeddings are enabled, suggesting embeddings", model_name=self.embeddings_model_name)
log.debug(
"KoboldCpp embeddings are enabled, suggesting embeddings",
model_name=self.embeddings_model_name,
)
self.set_embeddings()
await async_signals.get("client.embeddings_available").send(
ClientEmbeddingsStatus(
client=self,
embedding_name=self.embeddings_model_name,
)
)
async def get_model_name(self):
self.ensure_api_endpoint_specified()
try:
async with httpx.AsyncClient() as client:
response = await client.get(
@@ -275,7 +282,7 @@ class KoboldCppClient(ClientBase):
# split by "/" and take last
if model_name:
model_name = model_name.split("/")[-1]
await self.get_embeddings_status()
return model_name
@@ -309,7 +316,6 @@ class KoboldCppClient(ClientBase):
tokencount = len(response.json().get("ids", []))
return tokencount
async def abort_generation(self):
"""
Trigger the stop generation endpoint
@@ -317,7 +323,7 @@ class KoboldCppClient(ClientBase):
if self.is_openai:
# openai api endpoint doesn't support abort
return
parts = urlparse(self.api_url)
url_abort = f"{parts.scheme}://{parts.netloc}/api/extra/abort"
async with httpx.AsyncClient() as client:
@@ -325,7 +331,7 @@ class KoboldCppClient(ClientBase):
url_abort,
headers=self.request_headers,
)
async def generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
@@ -334,14 +340,16 @@ class KoboldCppClient(ClientBase):
return await self._generate_openai(prompt, parameters, kind)
else:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._generate_kcpp_stream, prompt, parameters, kind)
return await loop.run_in_executor(
None, self._generate_kcpp_stream, prompt, parameters, kind
)
def _generate_kcpp_stream(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
parameters["prompt"] = prompt.strip(" ")
response = ""
parameters["stream"] = True
stream_response = requests.post(
@@ -352,15 +360,15 @@ class KoboldCppClient(ClientBase):
stream=True,
)
stream_response.raise_for_status()
sse = sseclient.SSEClient(stream_response)
for event in sse.events():
payload = json.loads(event.data)
chunk = payload['token']
chunk = payload["token"]
response += chunk
self.update_request_tokens(self.count_tokens(chunk))
return response
async def _generate_openai(self, prompt: str, parameters: dict, kind: str):
@@ -447,7 +455,6 @@ class KoboldCppClient(ClientBase):
sd_models_url = urljoin(self.url, "/sdapi/v1/sd-models")
async with httpx.AsyncClient() as client:
try:
response = await client.get(url=sd_models_url, timeout=2)
except Exception as exc:

View File

@@ -37,13 +37,13 @@ class LMStudioClient(ClientBase):
def reconfigure(self, **kwargs):
super().reconfigure(**kwargs)
if self.client and self.client.base_url != self.api_url:
self.set_client()
async def get_model_name(self):
model_name = await super().get_model_name()
# model name comes back as a file path, so we need to extract the model name
# the path could be Windows or Linux, so it needs to handle both backslash and forward slash

View File

@@ -1,10 +1,14 @@
import pydantic
import structlog
from typing import Literal
from mistralai import Mistral
from mistralai.models.sdkerror import SDKError
from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults, ExtraField
from talemate.client.base import (
ClientBase,
ErrorAction,
CommonDefaults,
ExtraField,
)
from talemate.client.registry import register
from talemate.client.remote import (
EndpointOverride,
@@ -44,9 +48,11 @@ class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = "open-mixtral-8x22b"
class ClientConfig(EndpointOverride, BaseClientConfig):
pass
@register()
class MistralAIClient(EndpointOverrideMixin, ClientBase):
"""
@@ -68,7 +74,7 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="open-mixtral-8x22b", **kwargs):
self.model_name = model
self.api_key_status = None
@@ -115,7 +121,7 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
model_name = "No model loaded"
self.current_status = status
data={
data = {
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
@@ -167,10 +173,10 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
def reconfigure(self, **kwargs):
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
self._reconfigure_common_parameters(**kwargs)
self._reconfigure_endpoint_override(**kwargs)
if kwargs.get("model"):
self.model_name = kwargs["model"]
self.set_client(kwargs.get("max_token_length"))
@@ -248,14 +254,16 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
)
response = ""
completion_tokens = 0
prompt_tokens = 0
async for event in event_stream:
if event.data.choices:
response += event.data.choices[0].delta.content
self.update_request_tokens(self.count_tokens(event.data.choices[0].delta.content))
self.update_request_tokens(
self.count_tokens(event.data.choices[0].delta.content)
)
if event.data.usage:
completion_tokens += event.data.usage.completion_tokens
prompt_tokens += event.data.usage.prompt_tokens
@@ -263,7 +271,7 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
self._returned_prompt_tokens = prompt_tokens
self._returned_response_tokens = completion_tokens
#response = response.choices[0].message.content
# response = response.choices[0].message.content
# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json
@@ -282,12 +290,12 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
return response
except SDKError as e:
self.log.error("generate error", e=e)
if hasattr(e, 'status_code') and e.status_code in [403, 401]:
if hasattr(e, "status_code") and e.status_code in [403, 401]:
emit(
"status",
message="mistral.ai API: Permission Denied",
status="error",
)
return ""
except Exception as e:
except Exception:
raise

View File

@@ -117,7 +117,6 @@ class ModelPrompt:
prompt = f"{prompt}<|BOT|>"
if "<|BOT|>" in prompt:
response_str = f"{double_coercion}{response_str}"
if "\n<|BOT|>" in prompt:
@@ -264,7 +263,11 @@ class Mistralv7TekkenIdentifier(TemplateIdentifier):
template_str = "MistralV7Tekken"
def __call__(self, content: str):
return "[SYSTEM_PROMPT]" in content and "[INST]" in content and "[/INST]" in content
return (
"[SYSTEM_PROMPT]" in content
and "[INST]" in content
and "[/INST]" in content
)
@register_template_identifier

View File

@@ -1,11 +1,16 @@
import asyncio
import structlog
import httpx
import ollama
import time
from typing import Union
from talemate.client.base import STOPPING_STRINGS, ClientBase, CommonDefaults, ErrorAction, ParameterReroute, ExtraField
from talemate.client.base import (
STOPPING_STRINGS,
ClientBase,
CommonDefaults,
ErrorAction,
ParameterReroute,
ExtraField,
)
from talemate.client.registry import register
from talemate.config import Client as BaseClientConfig
@@ -14,27 +19,30 @@ log = structlog.get_logger("talemate.client.ollama")
FETCH_MODELS_INTERVAL = 15
class OllamaClientDefaults(CommonDefaults):
api_url: str = "http://localhost:11434" # Default Ollama URL
model: str = "" # Allow empty default, will fetch from Ollama
api_handles_prompt_template: bool = False
allow_thinking: bool = False
class ClientConfig(BaseClientConfig):
api_handles_prompt_template: bool = False
allow_thinking: bool = False
@register()
class OllamaClient(ClientBase):
"""
Ollama client for generating text using locally hosted models.
"""
auto_determine_prompt_template: bool = True
client_type = "ollama"
conversation_retries = 0
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "Ollama"
title: str = "Ollama"
@@ -56,27 +64,26 @@ class OllamaClient(ClientBase):
label="Allow thinking",
required=False,
description="Allow the model to think before responding. Talemate does not have a good way to deal with this yet, so it's recommended to leave this off.",
)
),
}
@property
def supported_parameters(self):
# Parameters supported by Ollama's generate endpoint
# Based on the API documentation
return [
"temperature",
"top_p",
"top_p",
"top_k",
"min_p",
"frequency_penalty",
"presence_penalty",
ParameterReroute(
talemate_parameter="repetition_penalty",
client_parameter="repeat_penalty"
client_parameter="repeat_penalty",
),
ParameterReroute(
talemate_parameter="max_tokens",
client_parameter="num_predict"
talemate_parameter="max_tokens", client_parameter="num_predict"
),
"stopping_strings",
# internal parameters that will be removed before sending
@@ -98,7 +105,13 @@ class OllamaClient(ClientBase):
"""
return self.allow_thinking
def __init__(self, model=None, api_handles_prompt_template=False, allow_thinking=False, **kwargs):
def __init__(
self,
model=None,
api_handles_prompt_template=False,
allow_thinking=False,
**kwargs,
):
self.model_name = model
self.api_handles_prompt_template = api_handles_prompt_template
self.allow_thinking = allow_thinking
@@ -114,16 +127,14 @@ class OllamaClient(ClientBase):
# Update model if provided
if kwargs.get("model"):
self.model_name = kwargs["model"]
# Create async client with the configured API URL
# Ollama's AsyncClient expects just the base URL without any path
self.client = ollama.AsyncClient(host=self.api_url)
self.api_handles_prompt_template = kwargs.get(
"api_handles_prompt_template", self.api_handles_prompt_template
)
self.allow_thinking = kwargs.get(
"allow_thinking", self.allow_thinking
)
self.allow_thinking = kwargs.get("allow_thinking", self.allow_thinking)
async def status(self):
"""
@@ -131,7 +142,7 @@ class OllamaClient(ClientBase):
Raises an error if no model name is returned.
:return: None
"""
if self.processing:
self.emit_status()
return
@@ -140,7 +151,7 @@ class OllamaClient(ClientBase):
self.connected = False
self.emit_status()
return
try:
# instead of using the client (which apparently cannot set a timeout per endpoint)
# we use httpx to check {api_url}/api/version to see if the server is running
@@ -148,7 +159,7 @@ class OllamaClient(ClientBase):
async with httpx.AsyncClient() as client:
response = await client.get(f"{self.api_url}/api/version", timeout=2)
response.raise_for_status()
# if the server is running, fetch the available models
await self.fetch_available_models()
except Exception as e:
@@ -156,7 +167,7 @@ class OllamaClient(ClientBase):
self.connected = False
self.emit_status()
return
await super().status()
async def fetch_available_models(self):
@@ -165,7 +176,7 @@ class OllamaClient(ClientBase):
"""
if time.time() - self._models_last_fetched < FETCH_MODELS_INTERVAL:
return self._available_models
response = await self.client.list()
models = response.get("models", [])
model_names = [model.model for model in models]
@@ -182,7 +193,7 @@ class OllamaClient(ClientBase):
async def get_model_name(self):
return self.model_name
def prompt_template(self, system_message: str, prompt: str):
if not self.api_handles_prompt_template:
return super().prompt_template(system_message, prompt)
@@ -201,10 +212,12 @@ class OllamaClient(ClientBase):
Tune parameters for Ollama's generate endpoint.
"""
super().tune_prompt_parameters(parameters, kind)
# Build stopping strings list
parameters["stop"] = STOPPING_STRINGS + parameters.get("extra_stopping_strings", [])
parameters["stop"] = STOPPING_STRINGS + parameters.get(
"extra_stopping_strings", []
)
# Ollama uses num_predict instead of max_tokens
if "max_tokens" in parameters:
parameters["num_predict"] = parameters["max_tokens"]
@@ -215,7 +228,7 @@ class OllamaClient(ClientBase):
"""
# First let parent class handle parameter reroutes and cleanup
super().clean_prompt_parameters(parameters)
# Remove our internal parameters
if "extra_stopping_strings" in parameters:
del parameters["extra_stopping_strings"]
@@ -223,7 +236,7 @@ class OllamaClient(ClientBase):
del parameters["stopping_strings"]
if "stream" in parameters:
del parameters["stream"]
# Remove max_tokens as we've already converted it to num_predict
if "max_tokens" in parameters:
del parameters["max_tokens"]
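
The ParameterReroute entries above simply rename keys before the request goes out; an illustrative restatement of that translation:

def apply_reroutes(parameters: dict) -> dict:
    # Mirrors the reroutes declared above:
    # repetition_penalty -> repeat_penalty, max_tokens -> num_predict.
    reroutes = {
        "repetition_penalty": "repeat_penalty",
        "max_tokens": "num_predict",
    }
    out = dict(parameters)
    for talemate_name, client_name in reroutes.items():
        if talemate_name in out:
            out[client_name] = out.pop(talemate_name)
    return out

assert apply_reroutes({"max_tokens": 256}) == {"num_predict": 256}
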
@@ -237,12 +250,12 @@ class OllamaClient(ClientBase):
await self.get_model_name()
if not self.model_name:
raise Exception("No model specified or available in Ollama")
# Prepare options for Ollama
options = parameters
options["num_ctx"] = self.max_token_length
try:
# Use generate endpoint for completion
stream = await self.client.generate(
@@ -253,22 +266,21 @@ class OllamaClient(ClientBase):
think=self.can_think,
stream=True,
)
response = ""
async for part in stream:
content = part.response
response += content
self.update_request_tokens(self.count_tokens(content))
# Extract the response text
return response
except Exception as e:
log.error("Ollama generation error", error=str(e), model=self.model_name)
raise ErrorAction(
message=f"Ollama generation failed: {str(e)}",
title="Generation Error"
message=f"Ollama generation failed: {str(e)}", title="Generation Error"
)
async def abort_generation(self):
@@ -284,7 +296,7 @@ class OllamaClient(ClientBase):
Adjusts temperature and repetition_penalty by random values.
"""
import random
temp = prompt_config["temperature"]
rep_pen = prompt_config.get("repetition_penalty", 1.0)
@@ -302,12 +314,12 @@ class OllamaClient(ClientBase):
# Handle model update
if kwargs.get("model"):
self.model_name = kwargs["model"]
super().reconfigure(**kwargs)
# Re-initialize client if API URL changed or model changed
if "api_url" in kwargs or "model" in kwargs:
self.set_client(**kwargs)
if "api_handles_prompt_template" in kwargs:
self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]

View File

@@ -108,9 +108,11 @@ class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = "gpt-4o"
class ClientConfig(EndpointOverride, BaseClientConfig):
pass
@register()
class OpenAIClient(EndpointOverrideMixin, ClientBase):
"""
@@ -123,7 +125,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
# TODO: make this configurable?
decensor_enabled = False
config_cls = ClientConfig
class Meta(ClientBase.Meta):
name_prefix: str = "OpenAI"
title: str = "OpenAI"
@@ -132,6 +134,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()
def __init__(self, model="gpt-4o", **kwargs):
self.model_name = model
self.api_key_status = None
@@ -181,12 +184,12 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
self.current_status = status
data={
data = {
"error_action": error_action.model_dump() if error_action else None,
"meta": self.Meta().model_dump(),
"enabled": self.enabled,
}
data.update(self._common_status_data())
emit(
"client_status",
@@ -305,32 +308,34 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
human_message = {"role": "user", "content": prompt.strip()}
system_message = {"role": "system", "content": self.get_system_message(kind)}
# o1 and o3 models don't support system_message
if "o1" in self.model_name or "o3" in self.model_name:
messages=[human_message]
messages = [human_message]
# parameters need to be munged
# `max_tokens` becomes `max_completion_tokens`
if "max_tokens" in parameters:
parameters["max_completion_tokens"] = parameters.pop("max_tokens")
# temperature forced to 1
if "temperature" in parameters:
log.debug(f"{self.model_name} does not support temperature, forcing to 1")
log.debug(
f"{self.model_name} does not support temperature, forcing to 1"
)
parameters["temperature"] = 1
unsupported_params = [
"presence_penalty",
"top_p",
]
for param in unsupported_params:
if param in parameters:
log.debug(f"{self.model_name} does not support {param}, removing")
parameters.pop(param)
else:
messages=[system_message, human_message]
messages = [system_message, human_message]
self.log.debug(
"generate",
@@ -346,7 +351,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
stream=True,
**parameters,
)
response = ""
# Iterate over streamed chunks
@@ -359,9 +364,9 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
response += content_piece
# Incrementally track token usage
self.update_request_tokens(self.count_tokens(content_piece))
#self._returned_prompt_tokens = self.prompt_tokens(prompt)
#self._returned_response_tokens = self.response_tokens(response)
# self._returned_prompt_tokens = self.prompt_tokens(prompt)
# self._returned_response_tokens = self.response_tokens(response)
# older models don't support json_object response coercion
# and often like to return the response wrapped in ```json
@@ -382,5 +387,5 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
self.log.error("generate error", e=e)
emit("status", message="OpenAI API: Permission Denied", status="error")
return ""
except Exception as e:
except Exception:
raise

View File

@@ -1,9 +1,8 @@
import random
import urllib
import pydantic
import structlog
from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError
from openai import AsyncOpenAI, PermissionDeniedError
from talemate.client.base import ClientBase, ExtraField
from talemate.client.registry import register
@@ -93,7 +92,6 @@ class OpenAICompatibleClient(ClientBase):
)
def prompt_template(self, system_message: str, prompt: str):
log.debug(
"IS API HANDLING PROMPT TEMPLATE",
api_handles_prompt_template=self.api_handles_prompt_template,
@@ -130,7 +128,10 @@ class OpenAICompatibleClient(ClientBase):
)
human_message = {"role": "user", "content": prompt.strip()}
response = await self.client.chat.completions.create(
model=self.model_name, messages=[human_message], stream=False, **parameters
model=self.model_name,
messages=[human_message],
stream=False,
**parameters,
)
response = response.choices[0].message.content
return self.process_response_for_indirect_coercion(prompt, response)
@@ -177,7 +178,7 @@ class OpenAICompatibleClient(ClientBase):
if "double_coercion" in kwargs:
self.double_coercion = kwargs["double_coercion"]
if "rate_limit" in kwargs:
self.rate_limit = kwargs["rate_limit"]

View File

@@ -21,25 +21,25 @@ AVAILABLE_MODELS = []
DEFAULT_MODEL = ""
MODELS_FETCHED = False
async def fetch_available_models(api_key: str = None):
"""Fetch available models from OpenRouter API"""
global AVAILABLE_MODELS, DEFAULT_MODEL, MODELS_FETCHED
if not api_key:
return []
if MODELS_FETCHED:
return AVAILABLE_MODELS
# Only fetch if we haven't already or if explicitly requested
if AVAILABLE_MODELS and not api_key:
return AVAILABLE_MODELS
try:
async with httpx.AsyncClient() as client:
response = await client.get(
"https://openrouter.ai/api/v1/models",
timeout=10.0
"https://openrouter.ai/api/v1/models", timeout=10.0
)
if response.status_code == 200:
data = response.json()
@@ -51,23 +51,26 @@ async def fetch_available_models(api_key: str = None):
AVAILABLE_MODELS = sorted(models)
log.debug(f"Fetched {len(AVAILABLE_MODELS)} models from OpenRouter")
else:
log.warning(f"Failed to fetch models from OpenRouter: {response.status_code}")
log.warning(
f"Failed to fetch models from OpenRouter: {response.status_code}"
)
except Exception as e:
log.error(f"Error fetching models from OpenRouter: {e}")
MODELS_FETCHED = True
return AVAILABLE_MODELS
def fetch_models_sync(event):
api_key = event.data.get("openrouter", {}).get("api_key")
loop = asyncio.get_event_loop()
loop.run_until_complete(fetch_available_models(api_key))
handlers["config_saved"].connect(fetch_models_sync)
handlers["talemate_started"].connect(fetch_models_sync)
class Defaults(CommonDefaults, pydantic.BaseModel):
max_token_length: int = 16384
model: str = DEFAULT_MODEL
@@ -89,7 +92,9 @@ class OpenRouterClient(ClientBase):
name_prefix: str = "OpenRouter"
title: str = "OpenRouter"
manual_model: bool = True
manual_model_choices: list[str] = pydantic.Field(default_factory=lambda: AVAILABLE_MODELS)
manual_model_choices: list[str] = pydantic.Field(
default_factory=lambda: AVAILABLE_MODELS
)
requires_prompt_template: bool = False
defaults: Defaults = Defaults()
@@ -169,7 +174,7 @@ class OpenRouterClient(ClientBase):
def set_client(self, max_token_length: int = None):
# Unlike other clients, we don't need to set up a client instance
# We'll use httpx directly in the generate method
if not self.openrouter_api_key:
log.error("No OpenRouter API key set")
if self.api_key_status:
@@ -183,7 +188,7 @@ class OpenRouterClient(ClientBase):
if max_token_length and not isinstance(max_token_length, int):
max_token_length = int(max_token_length)
# Set max token length (default to 16k if not specified)
self.max_token_length = max_token_length or 16384
@@ -221,7 +226,7 @@ class OpenRouterClient(ClientBase):
self._models_fetched = True
# Update the Meta class with new model choices
self.Meta.manual_model_choices = AVAILABLE_MODELS
self.emit_status()
def prompt_template(self, system_message: str, prompt: str):
@@ -240,13 +245,13 @@ class OpenRouterClient(ClientBase):
raise Exception("No OpenRouter API key set")
prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)
# Prepare messages for chat completion
messages = [
{"role": "system", "content": self.get_system_message(kind)},
{"role": "user", "content": prompt.strip()}
{"role": "user", "content": prompt.strip()},
]
if coercion_prompt:
messages.append({"role": "assistant", "content": coercion_prompt.strip()})
@@ -255,7 +260,7 @@ class OpenRouterClient(ClientBase):
"model": self.model_name,
"messages": messages,
"stream": True,
**parameters
**parameters,
}
self.log.debug(
@@ -264,7 +269,7 @@ class OpenRouterClient(ClientBase):
parameters=parameters,
model=self.model_name,
)
response_text = ""
buffer = ""
completion_tokens = 0
@@ -279,46 +284,52 @@ class OpenRouterClient(ClientBase):
"Content-Type": "application/json",
},
json=payload,
timeout=120.0 # 2 minute timeout for generation
timeout=120.0, # 2 minute timeout for generation
) as response:
async for chunk in response.aiter_text():
buffer += chunk
while True:
# Find the next complete SSE line
line_end = buffer.find('\n')
line_end = buffer.find("\n")
if line_end == -1:
break
line = buffer[:line_end].strip()
buffer = buffer[line_end + 1:]
if line.startswith('data: '):
buffer = buffer[line_end + 1 :]
if line.startswith("data: "):
data = line[6:]
if data == '[DONE]':
if data == "[DONE]":
break
try:
data_obj = json.loads(data)
content = data_obj["choices"][0]["delta"].get("content")
content = data_obj["choices"][0]["delta"].get(
"content"
)
usage = data_obj.get("usage", {})
completion_tokens += usage.get("completion_tokens", 0)
completion_tokens += usage.get(
"completion_tokens", 0
)
prompt_tokens += usage.get("prompt_tokens", 0)
if content:
response_text += content
# Update tokens as content streams in
self.update_request_tokens(self.count_tokens(content))
self.update_request_tokens(
self.count_tokens(content)
)
except json.JSONDecodeError:
pass
# Extract the response content
response_content = response_text
self._returned_prompt_tokens = prompt_tokens
self._returned_response_tokens = completion_tokens
return response_content
except httpx.ConnectTimeout:
self.log.error("OpenRouter API timeout")
emit("status", message="OpenRouter API: Request timed out", status="error")
@@ -326,4 +337,4 @@ class OpenRouterClient(ClientBase):
except Exception as e:
self.log.error("generate error", e=e)
emit("status", message=f"OpenRouter API Error: {str(e)}", status="error")
raise

View File

@@ -16,11 +16,6 @@ __all__ = [
"preset_for_kind",
"make_kind",
"max_tokens_for_kind",
"PRESET_TALEMATE_CONVERSATION",
"PRESET_TALEMATE_CREATOR",
"PRESET_LLAMA_PRECISE",
"PRESET_DIVINE_INTELLECT",
"PRESET_SIMPLE_1",
]
log = structlog.get_logger("talemate.client.presets")
@@ -42,27 +37,31 @@ def sync_config(event):
)
CONFIG["inference_groups"] = {
group: InferencePresetGroup(**data)
for group, data in event.data.get("presets", {}).get("inference_groups", {}).items()
for group, data in event.data.get("presets", {})
.get("inference_groups", {})
.items()
}
handlers["config_saved"].connect(sync_config)
def get_inference_parameters(preset_name: str, group:str|None = None) -> dict:
def get_inference_parameters(preset_name: str, group: str | None = None) -> dict:
"""
Returns the inference parameters for the given preset name.
"""
presets = CONFIG["inference"].model_dump()
if group:
try:
group_presets = CONFIG["inference_groups"].get(group).model_dump()
presets.update(group_presets["presets"])
except AttributeError:
log.warning(f"Invalid preset group referenced: {group}. Falling back to defaults.")
log.warning(
f"Invalid preset group referenced: {group}. Falling back to defaults."
)
if preset_name in presets:
return presets[preset_name]
@@ -145,7 +144,7 @@ def preset_for_kind(kind: str, client: "ClientBase") -> dict:
presets=CONFIG["inference"],
)
preset_name = "scene_direction"
set_client_context_attribute("inference_preset", preset_name)
return get_inference_parameters(preset_name, client.preset_group)
@@ -197,29 +196,26 @@ def max_tokens_for_kind(kind: str, total_budget: int) -> int:
return value
if token_value is not None:
return token_value
# finally, check if the last _-separated piece of the kind is a number, and
# if so, just return that number
kind_split = kind.split("_")[-1]
if kind_split.isdigit():
return int(kind_split)
return 150 # Default value if none of the kinds match
def make_kind(action_type: str, length: int, expect_json:bool=False) -> str:
def make_kind(action_type: str, length: int, expect_json: bool = False) -> str:
"""
Creates a kind string based on the action type and length.
"""
if action_type == "analyze" and not expect_json:
kind = f"investigate"
kind = "investigate"
else:
kind = action_type
kind = f"{kind}_{length}"
return kind
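Concretely, the two helpers compose and decode kind strings like this (illustrative calls; the max_tokens_for_kind results assume no explicit preset matches the kind):
make_kind("analyze", 512)                    # -> "investigate_512"
make_kind("analyze", 512, expect_json=True)  # -> "analyze_512"
max_tokens_for_kind("investigate_512", total_budget=8192)  # -> 512, parsed from the suffix
max_tokens_for_kind("mystery_kind", total_budget=8192)     # -> 150, the fallback default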

View File

@@ -1,30 +1,31 @@
from limits.strategies import MovingWindowRateLimiter
from limits.storage import MemoryStorage
from limits import parse, RateLimitItemPerMinute
from limits import RateLimitItemPerMinute
import time
__all__ = ["CounterRateLimiter"]
class CounterRateLimiter:
def __init__(self, rate_per_minute:int=99999, identifier:str="ratelimit"):
def __init__(self, rate_per_minute: int = 99999, identifier: str = "ratelimit"):
self.storage = MemoryStorage()
self.limiter = MovingWindowRateLimiter(self.storage)
self.rate = RateLimitItemPerMinute(rate_per_minute, 1)
self.identifier = identifier
def update_rate_limit(self, rate_per_minute:int):
def update_rate_limit(self, rate_per_minute: int):
"""Update the rate limit with a new value"""
self.rate = RateLimitItemPerMinute(rate_per_minute, 1)
def increment(self) -> bool:
limiter = self.limiter
rate = self.rate
return limiter.hit(rate, self.identifier)
def reset_time(self) -> float:
"""
Returns the time in seconds until the rate limit is reset
"""
window = self.limiter.get_window_stats(self.rate, self.identifier)
return window.reset_time - time.time()
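Usage of the limiter is a check-then-wait pattern, e.g. (hypothetical rate and identifier):
rl = CounterRateLimiter(rate_per_minute=30, identifier="openrouter")
if rl.increment():  # True while the moving window still has budget
    ...  # proceed with the request
else:
    delay = rl.reset_time()  # seconds until the window frees a slot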

View File

@@ -18,56 +18,68 @@ __all__ = [
"EndpointOverrideAPIKeyField",
]
def endpoint_override_extra_fields():
return {
"override_base_url": EndpointOverrideBaseURLField(),
"override_api_key": EndpointOverrideAPIKeyField(),
}
class EndpointOverride(pydantic.BaseModel):
override_base_url: str | None = None
override_api_key: str | None = None
class EndpointOverrideGroup(FieldGroup):
name: str = "endpoint_override"
label: str = "Endpoint Override"
description: str = ("Override the default base URL used by this client to access the {client_type} service API.\n\n"
"IMPORTANT: Provide an override only if you fully trust the endpoint. When set, the {client_type} API key defined in the global application settings is deliberately ignored to avoid accidental credential leakage. "
"If the override endpoint requires an API key, enter it below.")
description: str = (
"Override the default base URL used by this client to access the {client_type} service API.\n\n"
"IMPORTANT: Provide an override only if you fully trust the endpoint. When set, the {client_type} API key defined in the global application settings is deliberately ignored to avoid accidental credential leakage. "
"If the override endpoint requires an API key, enter it below."
)
icon: str = "mdi-api"
class EndpointOverrideField(ExtraField):
group: EndpointOverrideGroup = pydantic.Field(default_factory=EndpointOverrideGroup)
class EndpointOverrideBaseURLField(EndpointOverrideField):
name: str = "override_base_url"
type: str = "text"
label: str = "Base URL"
required: bool = False
description: str = "Override the base URL for the remote service"
class EndpointOverrideAPIKeyField(EndpointOverrideField):
name: str = "override_api_key"
type: str = "password"
label: str = "API Key"
required: bool = False
description: str = "Override the API key for the remote service"
note: ux_schema.Note = pydantic.Field(default_factory=lambda: ux_schema.Note(
text="This is NOT the API key for the official {client_type} API. It is only used when overriding the base URL. The official {client_type} API key can be configured in the application settings.",
color="warning",
))
note: ux_schema.Note = pydantic.Field(
default_factory=lambda: ux_schema.Note(
text="This is NOT the API key for the official {client_type} API. It is only used when overriding the base URL. The official {client_type} API key can be configured in the application settings.",
color="warning",
)
)
class EndpointOverrideMixin:
override_base_url: str | None = None
override_api_key: str | None = None
def set_client_api_key(self, api_key: str | None):
if getattr(self, "client", None):
try:
self.client.api_key = api_key
except Exception as e:
log.error("Error setting client API key", error=e, client=self.client_type)
log.error(
"Error setting client API key", error=e, client=self.client_type
)
@property
def api_key(self) -> str | None:
@@ -84,14 +96,17 @@ class EndpointOverrideMixin:
@property
def endpoint_override_base_url_configured(self) -> bool:
return self.override_base_url and self.override_base_url.strip()
@property
def endpoint_override_api_key_configured(self) -> bool:
return self.override_api_key and self.override_api_key.strip()
@property
def endpoint_override_fully_configured(self) -> bool:
return self.endpoint_override_base_url_configured and self.endpoint_override_api_key_configured
return (
self.endpoint_override_base_url_configured
and self.endpoint_override_api_key_configured
)
def _reconfigure_endpoint_override(self, **kwargs):
if "override_base_url" in kwargs:
@@ -100,13 +115,13 @@ class EndpointOverrideMixin:
if getattr(self, "client", None) and orig != self.override_base_url:
log.info("Reconfiguring client base URL", new=self.override_base_url)
self.set_client(kwargs.get("max_token_length"))
if "override_api_key" in kwargs:
self.override_api_key = kwargs["override_api_key"]
self.set_client_api_key(self.override_api_key)
class RemoteServiceMixin:
def prompt_template(self, system_message: str, prompt: str):
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)

View File

@@ -4,8 +4,6 @@ connection for the pod. This is a simple wrapper around the runpod module.
"""
import asyncio
import json
import os
import dotenv
import runpod

View File

@@ -27,7 +27,6 @@ PROMPT_TEMPLATE_MAP = {
"world_state": "world_state.system-analyst-no-decensor",
"summarize": "summarizer.system-no-decensor",
"visualize": "visual.system-no-decensor",
# contains some minor attempts at keeping the LLM from generating
# refusals to generate certain types of content
"roleplay_decensor": "conversation.system",
@@ -42,32 +41,39 @@ PROMPT_TEMPLATE_MAP = {
"visualize_decensor": "visual.system",
}
def cache_all() -> dict:
for key in PROMPT_TEMPLATE_MAP:
render_prompt(key)
return RENDER_CACHE.copy()
def render_prompt(kind:str, decensor:bool=False):
def render_prompt(kind: str, decensor: bool = False):
# work around circular import issue
# TODO: refactor to avoid circular import
from talemate.prompts import Prompt
if kind not in PROMPT_TEMPLATE_MAP:
log.warning(f"Invalid prompt system prompt identifier: {kind} - decensor: {decensor}")
log.warning(
f"Invalid prompt system prompt identifier: {kind} - decensor: {decensor}"
)
return ""
if decensor:
key = f"{kind}_decensor"
else:
key = kind
if key not in PROMPT_TEMPLATE_MAP:
log.warning(f"Invalid prompt system prompt identifier: {kind} - decensor: {decensor}", key=key)
log.warning(
f"Invalid prompt system prompt identifier: {kind} - decensor: {decensor}",
key=key,
)
return ""
if key in RENDER_CACHE:
return RENDER_CACHE[key]
prompt = str(Prompt.get(PROMPT_TEMPLATE_MAP[key]))
RENDER_CACHE[key] = prompt
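Because of RENDER_CACHE, each system prompt template is rendered at most once per process (illustrative; the "roleplay" entry is assumed from the full PROMPT_TEMPLATE_MAP):
render_prompt("roleplay")                 # renders the template and caches it
render_prompt("roleplay")                 # cache hit, no re-render
render_prompt("roleplay", decensor=True)  # distinct cache key: "roleplay_decensor"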
@@ -77,17 +83,17 @@ def render_prompt(kind:str, decensor:bool=False):
class SystemPrompts(pydantic.BaseModel):
"""
System prompts and a normalized way to access them.
Allows specification of a parent "SystemPrompts" instance that will be
used as a fallback, and if not so specified, will default to the
system prompts in the globals via lambda functions that render
the templates.
The globals that exist now will be deprecated in favor of this later.
"""
parent: "SystemPrompts | None" = pydantic.Field(default=None, exclude=True)
roleplay: str | None = None
narrator: str | None = None
creator: str | None = None
@@ -98,7 +104,7 @@ class SystemPrompts(pydantic.BaseModel):
world_state: str | None = None
summarize: str | None = None
visualize: str | None = None
roleplay_decensor: str | None = None
narrator_decensor: str | None = None
creator_decensor: str | None = None
@@ -109,64 +115,61 @@ class SystemPrompts(pydantic.BaseModel):
world_state_decensor: str | None = None
summarize_decensor: str | None = None
visualize_decensor: str | None = None
class Config:
exclude_none = True
exclude_unset = True
@property
def defaults(self) -> dict:
return RENDER_CACHE.copy()
def alias(self, alias:str) -> str:
def alias(self, alias: str) -> str:
if alias in PROMPT_TEMPLATE_MAP:
return alias
if "narrate" in alias:
return "narrator"
if "direction" in alias or "director" in alias:
return "director"
if "create" in alias or "creative" in alias:
return "creator"
if "conversation" in alias or "roleplay" in alias:
return "roleplay"
if "basic" in alias:
return "basic"
if "edit" in alias:
return "editor"
if "world_state" in alias:
return "world_state"
if "analyze_freeform" in alias or "investigate" in alias:
return "analyst_freeform"
if "analyze" in alias or "analyst" in alias or "analytical" in alias:
return "analyst"
if "summarize" in alias or "summarization" in alias:
return "summarize"
if "visual" in alias:
return "visualize"
return alias
def get(self, kind:str, decensor:bool=False) -> str:
def get(self, kind: str, decensor: bool = False) -> str:
kind = self.alias(kind)
key = f"{kind}_decensor" if decensor else kind
if getattr(self, key):
return getattr(self, key)
if self.parent is not None:
return self.parent.get(kind, decensor)
return render_prompt(kind, decensor)
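Resolution order in get() is: explicit value on this instance, then the parent chain, then the rendered template default, with alias() normalizing loose kind names first. For example (hypothetical values):
sp = SystemPrompts(parent=SystemPrompts(roleplay="parent roleplay prompt"))
sp.get("conversation")   # alias -> "roleplay"; unset here, so the parent supplies it
sp.get("narrate_scene")  # alias -> "narrator"; unset everywhere, falls back to render_prompt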

View File

@@ -1,5 +1,4 @@
import random
from typing import Literal
import json
import httpx
import pydantic
@@ -103,7 +102,6 @@ class TabbyAPIClient(ClientBase):
)
def prompt_template(self, system_message: str, prompt: str):
log.debug(
"IS API HANDLING PROMPT TEMPLATE",
api_handles_prompt_template=self.api_handles_prompt_template,
@@ -196,22 +194,18 @@ class TabbyAPIClient(ClientBase):
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
url,
headers=headers,
json=payload,
timeout=120.0
"POST", url, headers=headers, json=payload, timeout=120.0
) as response:
async for chunk in response.aiter_text():
buffer += chunk
while True:
line_end = buffer.find('\n')
line_end = buffer.find("\n")
if line_end == -1:
break
line = buffer[:line_end].strip()
buffer = buffer[line_end + 1:]
buffer = buffer[line_end + 1 :]
if not line:
continue
@@ -226,7 +220,7 @@ class TabbyAPIClient(ClientBase):
choice = data_obj.get("choices", [{}])[0]
# Chat completions use delta -> content.
delta = choice.get("delta", {})
content = (
delta.get("content")
@@ -235,12 +229,16 @@ class TabbyAPIClient(ClientBase):
)
usage = data_obj.get("usage", {})
completion_tokens = usage.get("completion_tokens", 0)
completion_tokens = usage.get(
"completion_tokens", 0
)
prompt_tokens = usage.get("prompt_tokens", 0)
if content:
response_text += content
self.update_request_tokens(self.count_tokens(content))
self.update_request_tokens(
self.count_tokens(content)
)
except json.JSONDecodeError:
# ignore malformed json chunks
pass
@@ -251,7 +249,9 @@ class TabbyAPIClient(ClientBase):
if is_chat:
# Process indirect coercion
response_text = self.process_response_for_indirect_coercion(prompt, response_text)
response_text = self.process_response_for_indirect_coercion(
prompt, response_text
)
return response_text
@@ -265,7 +265,9 @@ class TabbyAPIClient(ClientBase):
return ""
except Exception as e:
self.log.error("generate error", e=e)
emit("status", message="Error during generation (check logs)", status="error")
emit(
"status", message="Error during generation (check logs)", status="error"
)
return ""
def reconfigure(self, **kwargs):
@@ -287,7 +289,7 @@ class TabbyAPIClient(ClientBase):
self.double_coercion = kwargs["double_coercion"]
self._reconfigure_common_parameters(**kwargs)
self.set_client(**kwargs)
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:

View File

@@ -8,7 +8,7 @@ import httpx
import structlog
from openai import AsyncOpenAI
from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults, ExtraField
from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults
from talemate.client.registry import register
log = structlog.get_logger("talemate.client.textgenwebui")
@@ -103,7 +103,6 @@ class TextGeneratorWebuiClient(ClientBase):
self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")
def finalize_llama3(self, parameters: dict, prompt: str) -> tuple[str, bool]:
if "<|eot_id|>" not in prompt:
return prompt, False
@@ -122,7 +121,6 @@ class TextGeneratorWebuiClient(ClientBase):
return prompt, True
def finalize_YI(self, parameters: dict, prompt: str) -> tuple[str, bool]:
if not self.model_name:
return prompt, False
@@ -141,7 +139,6 @@ class TextGeneratorWebuiClient(ClientBase):
return prompt, True
async def get_model_name(self):
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.api_url}/v1/internal/model/info",
@@ -170,14 +167,16 @@ class TextGeneratorWebuiClient(ClientBase):
async def generate(self, prompt: str, parameters: dict, kind: str):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._generate, prompt, parameters, kind)
return await loop.run_in_executor(
None, self._generate, prompt, parameters, kind
)
def _generate(self, prompt: str, parameters: dict, kind: str):
"""
Generates text from the given prompt and parameters.
"""
parameters["prompt"] = prompt.strip(" ")
response = ""
parameters["stream"] = True
stream_response = requests.post(
@@ -188,17 +187,16 @@ class TextGeneratorWebuiClient(ClientBase):
stream=True,
)
stream_response.raise_for_status()
sse = sseclient.SSEClient(stream_response)
for event in sse.events():
payload = json.loads(event.data)
chunk = payload['choices'][0]['text']
chunk = payload["choices"][0]["text"]
response += chunk
self.update_request_tokens(self.count_tokens(chunk))
return response
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
"""

View File

@@ -1,5 +1,3 @@
import copy
import random
from urllib.parse import urljoin as _urljoin
__all__ = ["urljoin"]

View File

@@ -1,14 +1,27 @@
from .base import TalemateCommand
from .cmd_characters import *
from .cmd_debug_tools import *
from .cmd_rebuild_archive import CmdRebuildArchive
from .cmd_rename import CmdRename
from .cmd_regenerate import *
from .cmd_reset import CmdReset
from .cmd_save import CmdSave
from .cmd_save_as import CmdSaveAs
from .cmd_setenv import CmdSetEnvironmentToCreative, CmdSetEnvironmentToScene
from .cmd_time_util import *
from .cmd_tts import *
from .cmd_world_state import *
from .manager import Manager
from .base import TalemateCommand # noqa: F401
from .cmd_characters import CmdActivateCharacter, CmdDeactivateCharacter # noqa: F401
from .cmd_debug_tools import (
CmdPromptChangeSectioning, # noqa: F401
CmdSummarizerUpdateLayeredHistory, # noqa: F401
CmdSummarizerResetLayeredHistory, # noqa: F401
CmdSummarizerContextInvestigation, # noqa: F401
)
from .cmd_rebuild_archive import CmdRebuildArchive # noqa: F401
from .cmd_rename import CmdRename # noqa: F401
from .cmd_regenerate import CmdRegenerate # noqa: F401
from .cmd_reset import CmdReset # noqa: F401
from .cmd_save import CmdSave # noqa: F401
from .cmd_save_as import CmdSaveAs # noqa: F401
from .cmd_setenv import CmdSetEnvironmentToCreative, CmdSetEnvironmentToScene # noqa: F401
from .cmd_time_util import CmdAdvanceTime # noqa: F401
from .cmd_tts import CmdTestTTS # noqa: F401
from .cmd_world_state import (
CmdAddReinforcement, # noqa: F401
CmdApplyWorldStateTemplate, # noqa: F401
CmdCheckPinConditions, # noqa: F401
CmdDetermineCharacterDevelopment, # noqa: F401
CmdRemoveReinforcement, # noqa: F401
CmdSummarizeAndPin, # noqa: F401
CmdUpdateReinforcements, # noqa: F401
)
from .manager import Manager # noqa: F401

View File

@@ -1,6 +1,4 @@
import asyncio
import json
import logging
import structlog
@@ -19,7 +17,6 @@ __all__ = [
log = structlog.get_logger("talemate.commands.cmd_debug_tools")
@register
class CmdPromptChangeSectioning(TalemateCommand):
"""
@@ -95,7 +92,8 @@ class CmdSummarizerUpdateLayeredHistory(TalemateCommand):
summarizer = get_agent("summarizer")
await summarizer.summarize_to_layered_history()
@register
class CmdSummarizerResetLayeredHistory(TalemateCommand):
"""
@@ -108,16 +106,17 @@ class CmdSummarizerResetLayeredHistory(TalemateCommand):
async def run(self):
summarizer = get_agent("summarizer")
# if arg is provided remove the last n layers
if self.args:
n = int(self.args[0])
self.scene.layered_history = self.scene.layered_history[:-n]
else:
self.scene.layered_history = []
await summarizer.summarize_to_layered_history()
@register
class CmdSummarizerContextInvestigation(TalemateCommand):
"""
@@ -135,11 +134,10 @@ class CmdSummarizerContextInvestigation(TalemateCommand):
if not self.args:
self.emit("system", "You must specify a query")
return
await summarizer.request_context_investigations(self.args[0], max_calls=1)
@register
class CmdMemoryCompareStrings(TalemateCommand):
"""
@@ -156,13 +154,14 @@ class CmdMemoryCompareStrings(TalemateCommand):
if not self.args:
self.emit("system", "You must specify two strings to compare")
return
string1 = self.args[0]
string2 = self.args[1]
result = await memory.compare_strings(string1, string2)
self.emit("system", f"The strings are {result['cosine_similarity']} similar")
@register
class CmdRunEditorRevision(TalemateCommand):
"""
@@ -172,14 +171,13 @@ class CmdRunEditorRevision(TalemateCommand):
name = "run_editor_revision"
description = "Run the editor revision"
aliases = ["run_revision"]
async def run(self):
editor = get_agent("editor")
scene = self.scene
last_message = scene.history[-1]
result = await editor.revision_detect_bad_prose(str(last_message))
self.emit("system", f"Result: {result}")

View File

@@ -1,5 +1,3 @@
import asyncio
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import emit
@@ -29,7 +27,7 @@ class CmdRebuildArchive(TalemateCommand):
]
self.scene.ts = "PT0S"
memory.delete({"typ": "history"})
entries = 0
@@ -42,9 +40,9 @@ class CmdRebuildArchive(TalemateCommand):
)
more = await summarizer.agent.build_archive(self.scene)
self.scene.sync_time()
entries += 1
if not more:
break

View File

@@ -1,6 +1,6 @@
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import emit, wait_for_input, wait_for_input_yesno
from talemate.emit import wait_for_input_yesno
from talemate.exceptions import ResetScene

View File

@@ -1,5 +1,3 @@
import asyncio
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register

View File

@@ -22,10 +22,13 @@ class CmdSetEnvironmentToScene(TalemateCommand):
player_character = self.scene.get_player_character()
if not player_character:
self.system_message("No characters found - cannot switch to gameplay mode.", meta={
"icon": "mdi-alert",
"color": "warning",
})
self.system_message(
"No characters found - cannot switch to gameplay mode.",
meta={
"icon": "mdi-alert",
"color": "warning",
},
)
return True
self.scene.set_environment("scene")

View File

@@ -2,11 +2,6 @@
Commands to manage scene timescale
"""
import asyncio
import logging
import isodate
import talemate.instance as instance
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register

View File

@@ -1,10 +1,6 @@
import asyncio
import logging
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.instance import get_agent
from talemate.prompts.base import set_default_sectioning_handler
__all__ = [
"CmdTestTTS",

View File

@@ -109,8 +109,6 @@ class CmdUpdateReinforcements(TalemateCommand):
aliases = ["ws_ur"]
async def run(self):
scene = self.scene
world_state = get_agent("world_state")
await world_state.update_reinforcements(force=True)
@@ -234,18 +232,16 @@ class CmdDetermineCharacterDevelopment(TalemateCommand):
scene = self.scene
world_state = get_agent("world_state")
creator = get_agent("creator")
if not len(self.args):
raise ValueError("No character name provided.")
character_name = self.args[0]
character = scene.get_character(character_name)
if not character:
raise ValueError(f"Character {character_name} not found.")
instructions = await world_state.determine_character_development(character)
# updates = await creator.update_character_sheet(character, instructions)
await world_state.determine_character_development(character)

View File

@@ -36,7 +36,7 @@ class Manager(Emitter):
aliases[alias] = name.replace("cmd_", "")
return aliases
async def execute(self, cmd, emit_on_unknown:bool = True, state:dict = None):
async def execute(self, cmd, emit_on_unknown: bool = True, state: dict = None):
# commands start with ! and are followed by a command name
cmd = cmd.strip()
cmd_args = ""
@@ -56,7 +56,6 @@ class Manager(Emitter):
for command_cls in self.command_classes:
if command_cls.is_command(cmd_name):
if command_cls.argument_cls:
cmd_kwargs = json.loads(cmd_args_unsplit)
cmd_args = []

View File

@@ -1,4 +1,3 @@
import copy
import datetime
import os
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, TypeVar, Union, Literal
@@ -6,8 +5,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, TypeVar, Union,
import pydantic
import structlog
import yaml
from enum import Enum
from pydantic import BaseModel, Field
from pydantic import BaseModel
from typing_extensions import Annotated
from talemate.agents.registry import get_agent_class
@@ -43,7 +41,7 @@ class Client(BaseModel):
rate_limit: Union[int, None] = None
data_format: Literal["json", "yaml"] | None = None
enabled: bool = True
system_prompts: SystemPrompts = SystemPrompts()
preset_group: str | None = None
@@ -165,6 +163,7 @@ class DeepSeekConfig(BaseModel):
class OpenRouterConfig(BaseModel):
api_key: Union[str, None] = None
class RunPodConfig(BaseModel):
api_key: Union[str, None] = None
@@ -202,6 +201,7 @@ class RecentScene(BaseModel):
date: str
cover_image: Union[Asset, None] = None
class EmbeddingFunctionPreset(BaseModel):
embeddings: str = "sentence-transformer"
model: str = "all-MiniLM-L6-v2"
@@ -217,12 +217,11 @@ class EmbeddingFunctionPreset(BaseModel):
client: str | None = None
def generate_chromadb_presets() -> dict[str, EmbeddingFunctionPreset]:
"""
Returns a dict of default embedding presets
"""
return {
"default": EmbeddingFunctionPreset(),
"Alibaba-NLP/gte-base-en-v1.5": EmbeddingFunctionPreset(
@@ -274,25 +273,24 @@ class InferenceParameters(BaseModel):
frequency_penalty: float | None = 0.05
repetition_penalty: float | None = 1.0
repetition_penalty_range: int | None = 1024
xtc_threshold: float | None = 0.1
xtc_probability: float | None = 0.0
dry_multiplier: float | None = 0.0
dry_base: float | None = 1.75
dry_allowed_length: int | None = 2
dry_sequence_breakers: str | None = '"\\n", ":", "\\"", "*"'
smoothing_factor: float | None = 0.0
smoothing_curve: float | None = 1.0
# this determines whether or not it should be persisted
# to the config file
changed: bool = False
class InferencePresets(BaseModel):
analytical: InferenceParameters = InferenceParameters(
temperature=0.7,
presence_penalty=0,
@@ -325,19 +323,26 @@ class InferencePresets(BaseModel):
presence_penalty=0.0,
)
class InferencePresetGroup(BaseModel):
name: str
presets: InferencePresets
class Presets(BaseModel):
inference_defaults: InferencePresets = InferencePresets()
inference: InferencePresets = InferencePresets()
inference_groups: dict[str, InferencePresetGroup] = pydantic.Field(default_factory=dict)
embeddings_defaults: dict[str, EmbeddingFunctionPreset] = pydantic.Field(default_factory=generate_chromadb_presets)
embeddings: dict[str, EmbeddingFunctionPreset] = pydantic.Field(default_factory=generate_chromadb_presets)
inference_groups: dict[str, InferencePresetGroup] = pydantic.Field(
default_factory=dict
)
embeddings_defaults: dict[str, EmbeddingFunctionPreset] = pydantic.Field(
default_factory=generate_chromadb_presets
)
embeddings: dict[str, EmbeddingFunctionPreset] = pydantic.Field(
default_factory=generate_chromadb_presets
)
def gnerate_intro_scenes():
@@ -471,28 +476,28 @@ AnnotatedClient = Annotated[
class HistoryMessageStyle(BaseModel):
italic: bool = False
bold: bool = False
# Leave None for default color
color: str | None = None
class HidableHistoryMessageStyle(HistoryMessageStyle):
# certain messages can be hidden, but all messages are shown by default
show: bool = True
class SceneAppearance(BaseModel):
narrator_messages: HistoryMessageStyle = HistoryMessageStyle(italic=True)
character_messages: HistoryMessageStyle = HistoryMessageStyle()
director_messages: HidableHistoryMessageStyle = HidableHistoryMessageStyle()
time_messages: HistoryMessageStyle = HistoryMessageStyle()
context_investigation_messages: HidableHistoryMessageStyle = HidableHistoryMessageStyle()
context_investigation_messages: HidableHistoryMessageStyle = (
HidableHistoryMessageStyle()
)
class Appearance(BaseModel):
scene: SceneAppearance = SceneAppearance()
class Config(BaseModel):
@@ -505,7 +510,7 @@ class Config(BaseModel):
creator: CreatorConfig = CreatorConfig()
openai: OpenAIConfig = OpenAIConfig()
deepseek: DeepSeekConfig = DeepSeekConfig()
mistralai: MistralAIConfig = MistralAIConfig()
@@ -531,11 +536,11 @@ class Config(BaseModel):
recent_scenes: RecentScenes = RecentScenes()
presets: Presets = Presets()
appearance: Appearance = Appearance()
system_prompts: SystemPrompts = SystemPrompts()
class Config:
extra = "ignore"
@@ -595,18 +600,18 @@ def save_config(config, file_path: str = "./config.yaml"):
# we don't want to persist the following, so we drop them:
# - presets.inference_defaults
# - presets.embeddings_defaults
if "inference_defaults" in config["presets"]:
config["presets"].pop("inference_defaults")
if "embeddings_defaults" in config["presets"]:
config["presets"].pop("embeddings_defaults")
# for normal presets we only want to persist if they have changed
for preset_name, preset in list(config["presets"]["inference"].items()):
if not preset.get("changed"):
config["presets"]["inference"].pop(preset_name)
# in inference groups also only keep if changed
for group_name, group in list(config["presets"]["inference_groups"].items()):
for preset_name, preset in list(group["presets"].items()):
@@ -616,7 +621,7 @@ def save_config(config, file_path: str = "./config.yaml"):
# if presets is empty, remove it
if not config["presets"]["inference"]:
config["presets"].pop("inference")
# if system_prompts is empty, remove it
if not config["system_prompts"]:
config.pop("system_prompts")
@@ -624,12 +629,13 @@ def save_config(config, file_path: str = "./config.yaml"):
# set any client preset_group to "" if it references an
# entry that no longer exists in inference_groups
for client in config["clients"].values():
if not client.get("preset_group"):
continue
if client["preset_group"] not in config["presets"].get("inference_groups", {}):
log.warning(f"Client {client['name']} references non-existent preset group {client['preset_group']}, setting to default")
log.warning(
f"Client {client['name']} references non-existent preset group {client['preset_group']}, setting to default"
)
client["preset_group"] = ""
with open(file_path, "w") as file:
@@ -639,7 +645,6 @@ def save_config(config, file_path: str = "./config.yaml"):
def cleanup() -> Config:
log.info("cleaning up config")
config = load_config(as_model=True)

View File

@@ -35,11 +35,12 @@ regeneration_context = ContextVar("regeneration_context", default=None)
active_scene = ContextVar("active_scene", default=None)
interaction = ContextVar("interaction", default=InteractionState())
def handle_generation_cancelled(exc: GenerationCancelled):
# set cancel_requested to False on the active_scene
scene = active_scene.get()
if scene:
scene.cancel_requested = False
@@ -102,6 +103,6 @@ class Interaction:
def assert_active_scene(scene: object):
if not active_scene.get():
raise SceneInactiveError("Scene is not active")
if active_scene.get() != scene:
raise SceneInactiveError("Scene has changed")

View File

@@ -1,12 +1,12 @@
import talemate.emit.signals as signals
import talemate.emit.signals as signals # noqa: F401
from .base import (
AbortCommand,
Emission,
Emitter,
Receiver,
abort_wait_for_input,
emit,
wait_for_input,
wait_for_input_yesno,
)
AbortCommand, # noqa: F401
Emission, # noqa: F401
Emitter, # noqa: F401
Receiver, # noqa: F401
abort_wait_for_input, # noqa: F401
emit, # noqa: F401
wait_for_input, # noqa: F401
wait_for_input_yesno, # noqa: F401
)

View File

@@ -6,6 +6,7 @@ __all__ = [
handlers = {}
class AsyncSignal:
def __init__(self, name):
self.receivers = []
@@ -21,11 +22,11 @@ class AsyncSignal:
self.receivers.remove(handler)
except ValueError:
pass
async def send(self, emission):
for receiver in self.receivers:
await receiver(emission)
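A minimal usage sketch of AsyncSignal (receiver name is made up; connect() is the registration counterpart used with the handlers registry throughout the codebase):
signal = AsyncSignal("example_event")

async def on_example(emission):
    print("received", emission)

signal.connect(on_example)
# inside a coroutine: send() awaits each connected receiver in registration order
await signal.send({"payload": 1})
signal.disconnect(on_example)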
def _register(name: str):
"""

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import asyncio
import dataclasses
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Callable
import structlog
@@ -123,16 +123,16 @@ async def wait_for_input(
while input_received["message"] is None:
await asyncio.sleep(sleep_time)
interaction_state = interaction.get()
if abort_condition and (await abort_condition()):
raise AbortWaitForInput()
if interaction_state.reset_requested:
interaction_state.reset_requested = False
raise RestartSceneLoop()
if interaction_state.input:
input_received["message"] = interaction_state.input
input_received["interaction"] = interaction_state
@@ -140,7 +140,7 @@ async def wait_for_input(
interaction_state.input = None
interaction_state.from_choice = None
break
handlers["receive_input"].disconnect(input_receiver)
if input_received["message"] == "!abort":
@@ -194,6 +194,6 @@ class Emitter:
def player_message(self, message: str, character: Character):
self.emit("player", message, character=character)
def context_investigation_message(self, message: str):
self.emit("context_investigation", message)

View File

@@ -31,6 +31,7 @@ class ArchiveEvent(Event):
memory_id: str
ts: str = None
@dataclass
class CharacterStateEvent(Event):
state: str
@@ -62,11 +63,13 @@ class GameLoopActorIterEvent(GameLoopBase):
actor: Actor
game_loop: GameLoopEvent
@dataclass
class GameLoopCharacterIterEvent(GameLoopBase):
character: Character
game_loop: GameLoopEvent
@dataclass
class GameLoopNewMessageEvent(GameLoopBase):
message: SceneMessage
@@ -76,6 +79,7 @@ class GameLoopNewMessageEvent(GameLoopBase):
class PlayerTurnStartEvent(Event):
pass
@dataclass
class RegenerateGeneration(Event):
message: "SceneMessage"
@@ -89,4 +93,4 @@ async_signals.register(
"regenerate.msg.context_investigation",
"game_loop_player_character_iter",
"game_loop_ai_character_iter",
)

Some files were not shown because too many files have changed in this diff.