Mirror of https://github.com/vegu-ai/talemate.git, synced 2025-12-14 18:57:47 +01:00

linting

* precommit
* linting
* add linting to workflow
* ruff.toml added

.github/workflows/test.yml (vendored, 5 changes)
@@ -42,6 +42,11 @@ jobs:
           source .venv/bin/activate
           uv pip install -e ".[dev]"

+      - name: Run linting
+        run: |
+          source .venv/bin/activate
+          uv run pre-commit run --all-files
+
       - name: Setup configuration file
         run: |
           cp config.example.yaml config.yaml
.pre-commit-config.yaml (new file, 16 lines)

@@ -0,0 +1,16 @@
+fail_fast: false
+exclude: |
+  (?x)^(
+      tests/data/.*
+      |install-utils/.*
+  )$
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.12.1
+    hooks:
+      # Run the linter.
+      - id: ruff
+        args: [ --fix ]
+      # Run the formatter.
+      - id: ruff-format
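With this config in place, the same checks can be run locally. A minimal sketch using the standard pre-commit commands (the second line is exactly what the new CI lint step executes; `pre-commit` is added to the dev extras later in this commit):

```sh
uv run pre-commit install          # optional: run the hooks on every git commit
uv run pre-commit run --all-files  # what the CI "Run linting" step runs
```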
@@ -1,60 +1,63 @@
 import os
 import re
 import subprocess
 from pathlib import Path
 import argparse


 def find_image_references(md_file):
     """Find all image references in a markdown file."""
-    with open(md_file, 'r', encoding='utf-8') as f:
+    with open(md_file, "r", encoding="utf-8") as f:
         content = f.read()

-    pattern = r'!\[.*?\]\((.*?)\)'
+    pattern = r"!\[.*?\]\((.*?)\)"
     matches = re.findall(pattern, content)

     cleaned_paths = []
     for match in matches:
-        path = match.lstrip('/')
-        if 'img/' in path:
-            path = path[path.index('img/') + 4:]
+        path = match.lstrip("/")
+        if "img/" in path:
+            path = path[path.index("img/") + 4 :]
         # Only keep references to versioned images
         parts = os.path.normpath(path).split(os.sep)
-        if len(parts) >= 2 and parts[0].replace('.', '').isdigit():
+        if len(parts) >= 2 and parts[0].replace(".", "").isdigit():
             cleaned_paths.append(path)

     return cleaned_paths


 def scan_markdown_files(docs_dir):
     """Recursively scan all markdown files in the docs directory."""
     md_files = []
     for root, _, files in os.walk(docs_dir):
         for file in files:
-            if file.endswith('.md'):
+            if file.endswith(".md"):
                 md_files.append(os.path.join(root, file))
     return md_files


 def find_all_images(img_dir):
     """Find all image files in version subdirectories."""
     image_files = []
     for root, _, files in os.walk(img_dir):
         # Get the relative path from img_dir to current directory
         rel_dir = os.path.relpath(root, img_dir)

         # Skip if we're in the root img directory
-        if rel_dir == '.':
+        if rel_dir == ".":
             continue

         # Check if the immediate parent directory is a version number
         parent_dir = rel_dir.split(os.sep)[0]
-        if not parent_dir.replace('.', '').isdigit():
+        if not parent_dir.replace(".", "").isdigit():
             continue

         for file in files:
-            if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.svg')):
+            if file.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".svg")):
                 rel_path = os.path.relpath(os.path.join(root, file), img_dir)
                 image_files.append(rel_path)
     return image_files


 def grep_check_image(docs_dir, image_path):
     """
     Check if versioned image is referenced anywhere using grep.
@@ -65,33 +68,46 @@ def grep_check_image(docs_dir, image_path):
     parts = os.path.normpath(image_path).split(os.sep)
     version = parts[0]  # e.g., "0.29.0"
     filename = parts[-1]  # e.g., "world-state-suggestions-2.png"

     # For versioned images, require both version and filename to match
     version_pattern = f"{version}.*{filename}"
     try:
         result = subprocess.run(
-            ['grep', '-r', '-l', version_pattern, docs_dir],
+            ["grep", "-r", "-l", version_pattern, docs_dir],
             capture_output=True,
-            text=True
+            text=True,
         )
         if result.stdout.strip():
-            print(f"Found reference to {image_path} with version pattern: {version_pattern}")
+            print(
+                f"Found reference to {image_path} with version pattern: {version_pattern}"
+            )
             return True
     except subprocess.CalledProcessError:
         pass

     except Exception as e:
         print(f"Error during grep check for {image_path}: {e}")

     return False


 def main():
-    parser = argparse.ArgumentParser(description='Find and optionally delete unused versioned images in MkDocs project')
-    parser.add_argument('--docs-dir', type=str, required=True, help='Path to the docs directory')
-    parser.add_argument('--img-dir', type=str, required=True, help='Path to the images directory')
-    parser.add_argument('--delete', action='store_true', help='Delete unused images')
-    parser.add_argument('--verbose', action='store_true', help='Show all found references and files')
-    parser.add_argument('--skip-grep', action='store_true', help='Skip the additional grep validation')
+    parser = argparse.ArgumentParser(
+        description="Find and optionally delete unused versioned images in MkDocs project"
+    )
+    parser.add_argument(
+        "--docs-dir", type=str, required=True, help="Path to the docs directory"
+    )
+    parser.add_argument(
+        "--img-dir", type=str, required=True, help="Path to the images directory"
+    )
+    parser.add_argument("--delete", action="store_true", help="Delete unused images")
+    parser.add_argument(
+        "--verbose", action="store_true", help="Show all found references and files"
+    )
+    parser.add_argument(
+        "--skip-grep", action="store_true", help="Skip the additional grep validation"
+    )
     args = parser.parse_args()

     # Convert paths to absolute paths
@@ -118,7 +134,7 @@ def main():
     print("\nAll versioned image references found in markdown:")
     for img in sorted(used_images):
         print(f"- {img}")

     print("\nAll versioned images in directory:")
     for img in sorted(all_images):
         print(f"- {img}")
@@ -133,9 +149,11 @@ def main():
     for img in unused_images:
         if not grep_check_image(docs_dir, img):
             actually_unused.add(img)

     if len(actually_unused) != len(unused_images):
-        print(f"\nGrep validation found {len(unused_images) - len(actually_unused)} additional image references!")
+        print(
+            f"\nGrep validation found {len(unused_images) - len(actually_unused)} additional image references!"
+        )
         unused_images = actually_unused

     # Report findings
@@ -148,7 +166,7 @@ def main():
     print("\nUnused versioned images:")
     for img in sorted(unused_images):
         print(f"- {img}")

     if args.delete:
         print("\nDeleting unused versioned images...")
         for img in unused_images:
@@ -162,5 +180,6 @@ def main():
     else:
         print("\nNo unused versioned images found!")


 if __name__ == "__main__":
-    main()
+    main()
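A usage sketch for the cleanup script. The file name `find_unused_images.py` is an assumption (the mirror does not show it); the flags are taken verbatim from the argparse block above:

```sh
# Dry run: report unused versioned images without touching anything
python find_unused_images.py --docs-dir docs --img-dir docs/img --verbose

# Delete them, skipping the extra grep validation pass
python find_unused_images.py --docs-dir docs --img-dir docs/img --delete --skip-grep
```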
@@ -4,12 +4,12 @@ from talemate.events import GameLoopEvent
 import talemate.emit.async_signals
 from talemate.emit import emit


 @register()
 class TestAgent(Agent):
-
     agent_type = "test"
     verbose_name = "Test"

     def __init__(self, client):
         self.client = client
         self.is_enabled = True
@@ -20,7 +20,7 @@ class TestAgent(Agent):
                 description="Test",
             ),
         }

     @property
     def enabled(self):
         return self.is_enabled
@@ -36,7 +36,7 @@ class TestAgent(Agent):
     def connect(self, scene):
         super().connect(scene)
         talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)

     async def on_game_loop(self, emission: GameLoopEvent):
         """
         Called on the beginning of every game loop
@@ -45,4 +45,8 @@ class TestAgent(Agent):
         if not self.enabled:
             return

-        emit("status", status="info", message="Annoying you with a test message every game loop.")
+        emit(
+            "status",
+            status="info",
+            message="Annoying you with a test message every game loop.",
+        )
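`TestAgent` registers itself via the `@register()` decorator. As a hypothetical sketch of how such an agent might then be resolved (`get_agent_class` is imported from `talemate.agents.registry` elsewhere in this commit, but the lookup-by-`agent_type` signature is an assumption, not confirmed by this diff):

```python
from talemate.agents.registry import get_agent_class

AgentCls = get_agent_class("test")  # assumed: registry keyed by agent_type
agent = AgentCls(client=None)       # client is a placeholder here
assert agent.enabled                # TestAgent initializes is_enabled = True
```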
talemate/client/runpod_vllm.py

@@ -19,14 +19,17 @@ from talemate.config import Client as BaseClientConfig

 log = structlog.get_logger("talemate.client.runpod_vllm")


 class Defaults(pydantic.BaseModel):
     max_token_length: int = 4096
     model: str = ""
     runpod_id: str = ""


 class ClientConfig(BaseClientConfig):
     runpod_id: str = ""


 @register()
 class RunPodVLLMClient(ClientBase):
     client_type = "runpod_vllm"
@@ -49,7 +52,6 @@ class RunPodVLLMClient(ClientBase):
         )
     }

     def __init__(self, model=None, runpod_id=None, **kwargs):
         self.model_name = model
         self.runpod_id = runpod_id
@@ -59,12 +61,10 @@ class RunPodVLLMClient(ClientBase):
     def experimental(self):
         return False

     def set_client(self, **kwargs):
         log.debug("set_client", kwargs=kwargs, runpod_id=self.runpod_id)
         self.runpod_id = kwargs.get("runpod_id", self.runpod_id)

     def tune_prompt_parameters(self, parameters: dict, kind: str):
         super().tune_prompt_parameters(parameters, kind)
@@ -88,32 +88,37 @@ class RunPodVLLMClient(ClientBase):
         self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

         try:
             async with aiohttp.ClientSession() as session:
                 endpoint = runpod.AsyncioEndpoint(self.runpod_id, session)

-                run_request = await endpoint.run({
-                    "input": {
-                        "prompt": prompt,
-                    }
-                    #"parameters": parameters
-                })
+                run_request = await endpoint.run(
+                    {
+                        "input": {
+                            "prompt": prompt,
+                        }
+                        # "parameters": parameters
+                    }
+                )

-                while (await run_request.status()) not in ["COMPLETED", "FAILED", "CANCELLED"]:
+                while (await run_request.status()) not in [
+                    "COMPLETED",
+                    "FAILED",
+                    "CANCELLED",
+                ]:
                     status = await run_request.status()
                     log.debug("generate", status=status)
                     await asyncio.sleep(0.1)

                 status = await run_request.status()

                 log.debug("generate", status=status)

                 response = await run_request.output()

                 log.debug("generate", response=response)

                 return response["choices"][0]["tokens"][0]

         except Exception as e:
             self.log.error("generate error", e=e)
             emit(
@@ -9,6 +9,7 @@ class Defaults(pydantic.BaseModel):
     api_url: str = "http://localhost:1234"
     max_token_length: int = 4096


 @register()
 class TestClient(ClientBase):
     client_type = "test"
@@ -22,14 +23,13 @@ class TestClient(ClientBase):
         self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")

     def tune_prompt_parameters(self, parameters: dict, kind: str):
-
         """
         Talemate adds a bunch of parameters to the prompt, but not all of them are valid for all clients.

         This method is called before the prompt is sent to the client, and it allows the client to remove
         any parameters that it doesn't support.
         """
-
         super().tune_prompt_parameters(parameters, kind)

         keys = list(parameters.keys())
@@ -41,11 +41,10 @@ class TestClient(ClientBase):
                 del parameters[key]

     async def get_model_name(self):
-
         """
         This should return the name of the model that is being used.
         """
-
         return "Mock test model"

     async def generate(self, prompt: str, parameters: dict, kind: str):
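The pattern `tune_prompt_parameters` implements is simple: strip any prompt parameter the backend does not understand before dispatch. A minimal standalone illustration (the allowed-key set here is an assumption for demonstration, not Talemate's actual whitelist):

```python
ALLOWED = {"temperature", "max_tokens"}  # assumed whitelist, illustration only

def tune_prompt_parameters(parameters: dict) -> None:
    # iterate over a copy of the keys so deletion is safe mid-loop,
    # the same idiom as `keys = list(parameters.keys())` above
    for key in list(parameters.keys()):
        if key not in ALLOWED:
            del parameters[key]
```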
@@ -1,6 +1,6 @@
 import os
-from fastapi import FastAPI, Request
-from fastapi.responses import HTMLResponse, FileResponse
+from fastapi import FastAPI
+from fastapi.responses import HTMLResponse
 from fastapi.staticfiles import StaticFiles
 from starlette.exceptions import HTTPException

@@ -14,6 +14,7 @@ app = FastAPI()
 # Serve static files, but exclude the root path
 app.mount("/", StaticFiles(directory=dist_dir, html=True), name="static")


 @app.get("/", response_class=HTMLResponse)
 async def serve_root():
     index_path = os.path.join(dist_dir, "index.html")
@@ -24,5 +25,6 @@ async def serve_root():
     else:
         raise HTTPException(status_code=404, detail="index.html not found")


 # This is the ASGI application
-application = app
+application = app
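Since the module exposes `application` as the ASGI entry point, it can be served with any ASGI server. A sketch, assuming the module lives at `frontend_server.py` (the actual path is not shown in this diff):

```sh
uvicorn frontend_server:application --host 0.0.0.0 --port 8080
```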
pyproject.toml

@@ -65,6 +65,7 @@ dev = [
     "mkdocs-material>=9.5.27",
     "mkdocs-awesome-pages-plugin>=2.9.2",
     "mkdocs-glightbox>=0.4.0",
+    "pre-commit>=2.13",
 ]

 [project.scripts]
@@ -103,4 +104,4 @@ include_trailing_comma = true
 force_grid_wrap = 0
 use_parentheses = true
 ensure_newline_before_comments = true
-line_length = 88
+line_length = 88
ruff.toml (new file, 5 lines)

@@ -0,0 +1,5 @@
+[lint]
+# Disable automatic fix for unused imports (`F401`). We check these manually.
+unfixable = ["F401"]
+# Ignore E402
+extend-ignore = ["E402"]
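In effect: `ruff check --fix` will still report unused imports but will never delete them, and module-level imports that do not sit at the top of a file are not flagged at all. An illustrative snippet (not from the repo) of the kind of code this configuration tolerates:

```python
import sys

sys.path.insert(0, "vendor")  # setup that must run before further imports

import json  # E402 (import not at top of file): ignored repo-wide
import os    # F401 if unused: reported by `ruff check`, never auto-removed
```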
@@ -1,111 +1,112 @@
 def game(TM):
     MSG_PROCESSED_INSTRUCTIONS = "Simulation suite processed instructions"

-    MSG_HELP = "Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating \"Computer,\" followed by an instruction. For example ... \"Computer, i want to experience being on a derelict spaceship.\""
+    MSG_HELP = 'Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating "Computer," followed by an instruction. For example ... "Computer, i want to experience being on a derelict spaceship."'

     PROMPT_NARRATE_ROUND = "Narrate the simulation and reveal some new details to the player in one paragraph. YOU MUST NOT ADDRESS THE COMPUTER OR THE SIMULATION."

     PROMPT_STARTUP = "Narrate the computer asking the user to state the nature of their desired simulation in a synthetic and soft sounding voice. Remind the user that this is an old version of the simulation suite and they should check out version two for a more advanced experience."

-    CTX_PIN_UNAWARE = "Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION."
+    CTX_PIN_UNAWARE = (
+        "Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION."
+    )

     AUTO_NARRATE_INTERVAL = 10

-    def parse_sim_call_arguments(call:str) -> str:
+    def parse_sim_call_arguments(call: str) -> str:
         """
         Returns the value between the parentheses of a simulation call

         Example:

         call = 'change_environment("a house")'

         parse_sim_call_arguments(call) -> "a house"
         """

         try:
             return call.split("(", 1)[1].split(")")[0]
         except Exception:
             return ""

     class SimulationSuite:
         def __init__(self):
             """
             This is initialized at the beginning of each round of the simulation suite
             """

             # do we update the world state at the end of the round
             self.update_world_state = False
             self.simulation_reset = False

             # will keep track of any npcs added during the current round
             self.added_npcs = []

             TM.log.debug("SIMULATION SUITE INIT!", scene=TM.scene)
             self.player_message = TM.scene.last_player_message
-            self.last_processed_call = TM.game_state.get_var("instr.lastprocessed_call", -1)
+            self.last_processed_call = TM.game_state.get_var(
+                "instr.lastprocessed_call", -1
+            )

             # determine whether the player / user input is an instruction
             # to the simulation computer
             #
             # we do this by checking if the message starts with "Computer,"
             self.player_message_is_instruction = (
-                self.player_message and
-                self.player_message.raw.lower().startswith("computer") and
-                not self.player_message.hidden and
-                not self.last_processed_call > self.player_message.id
+                self.player_message
+                and self.player_message.raw.lower().startswith("computer")
+                and not self.player_message.hidden
+                and not self.last_processed_call > self.player_message.id
             )

         def run(self):
             """
             Main entry point for the simulation suite
             """

             if not TM.game_state.has_var("instr.simulation_stopped"):
                 # simulation is still running
                 self.simulation()

             self.finalize_round()

         def simulation(self):
             """
             Simulation suite logic
             """

             if not TM.game_state.has_var("instr.simulation_started"):
                 self.startup()
             else:
                 self.simulation_calls()

             if self.update_world_state:
                 self.run_update_world_state(force=True)

         def startup(self):
             """
             Scene startup logic
             """

             # we are at the beginning of the simulation
-            TM.signals.status("busy", "Simulation suite powering up.", as_scene_message=True)
+            TM.signals.status(
+                "busy", "Simulation suite powering up.", as_scene_message=True
+            )
             TM.game_state.set_var("instr.simulation_started", "yes", commit=False)

             # add narration for the introduction
             TM.agents.narrator.action_to_narration(
                 action_name="progress_story",
                 narrative_direction=PROMPT_STARTUP,
-                emit_message=False
+                emit_message=False,
             )

             # add narration for the instructions on how to interact with the simulation
             # this is a passthrough since we don't want the AI to paraphrase this
             TM.agents.narrator.action_to_narration(
-                action_name="passthrough",
-                narration=MSG_HELP
+                action_name="passthrough", narration=MSG_HELP
             )

             # create a world state entry letting the AI know that characters
             # interacting in the simulation are not aware of the computer or the simulation
             TM.agents.world_state.save_world_entry(
@@ -113,37 +114,43 @@ def game(TM):
                 text=CTX_PIN_UNAWARE,
                 meta={},
                 # this should always be pinned
-                pin=True
+                pin=True,
             )

             # set flag that we have started the simulation
             TM.game_state.set_var("instr.simulation_started", "yes", commit=False)

             # signal to the UX that the simulation suite is ready
-            TM.signals.status("success", "Simulation suite ready", as_scene_message=True)
+            TM.signals.status(
+                "success", "Simulation suite ready", as_scene_message=True
+            )

             # we want to update the world state at the end of the round
             self.update_world_state = True

         def simulation_calls(self):
             """
             Calls the simulation suite main prompt to determine the appropriate
             simulation calls
             """

             # we only process instructions that are not hidden and are not the last processed call
-            if not self.player_message_is_instruction or self.player_message.id == self.last_processed_call:
+            if (
+                not self.player_message_is_instruction
+                or self.player_message.id == self.last_processed_call
+            ):
                 return

             # First instruction?
             if not TM.game_state.has_var("instr.has_issued_instructions"):
                 # determine the context of the simulation
-                context_context = TM.agents.creator.determine_content_context_for_description(
-                    description=self.player_message.raw,
+                context_context = (
+                    TM.agents.creator.determine_content_context_for_description(
+                        description=self.player_message.raw,
+                    )
                 )
                 TM.scene.set_content_context(context_context)

             # Render the `computer` template and send it to the LLM for processing
             # The LLM will return a list of calls that the simulation suite will process
             # The calls are pseudo code that the simulation suite will interpret and execute
@@ -153,90 +160,98 @@ def game(TM):
                 player_instruction=self.player_message.raw,
                 scene=TM.scene,
             )

             self.calls = calls = calls.split("\n")

             calls = self.prepare_calls(calls)

             TM.log.debug("SIMULATION SUITE CALLS", callse=calls)

             # calls that are processed
             processed = []

             for call in calls:
                 processed_call = self.process_call(call)
                 if processed_call:
                     processed.append(processed_call)

             if processed:
                 TM.log.debug("SIMULATION SUITE CALLS", calls=processed)
-                TM.game_state.set_var("instr.has_issued_instructions", "yes", commit=False)
-
-                TM.signals.status("busy", "Simulation suite altering environment.", as_scene_message=True)
+                TM.game_state.set_var(
+                    "instr.has_issued_instructions", "yes", commit=False
+                )
+
+                TM.signals.status(
+                    "busy", "Simulation suite altering environment.", as_scene_message=True
+                )
                 compiled = "\n".join(processed)

                 if not self.simulation_reset and compiled:
                     # send the compiled calls to the narrator to generate a narrative based
                     # on them
                     narration = TM.agents.narrator.action_to_narration(
                         action_name="progress_story",
                         narrative_direction=f"The computer calls the following functions:\n\n```\n{compiled}\n```\n\nand the simulation adjusts the environment according to the user's wishes.\n\nWrite the narrative that describes the changes to the player in the context of the simulation starting up. YOU MUST NOT REFERENCE THE COMPUTER OR THE SIMULATION.",
-                        emit_message=True
+                        emit_message=True,
                     )

                     # on the first narration we update the scene description and remove any mention of the computer
                     # or the simulation from the previous narration
-                    is_initial_narration = TM.game_state.get_var("instr.intro_narration", False)
+                    is_initial_narration = TM.game_state.get_var(
+                        "instr.intro_narration", False
+                    )
                     if not is_initial_narration:
                         TM.scene.set_description(narration.raw)
                         TM.scene.set_intro(narration.raw)
-                        TM.log.debug("SIMULATION SUITE: initial narration", intro=narration.raw)
+                        TM.log.debug(
+                            "SIMULATION SUITE: initial narration", intro=narration.raw
+                        )
                         TM.scene.pop_history(typ="narrator", all=True, reverse=True)
                         TM.scene.pop_history(typ="director", all=True, reverse=True)
                         TM.game_state.set_var("instr.intro_narration", True, commit=False)

                 self.update_world_state = True

                 self.set_simulation_title(compiled)

         def set_simulation_title(self, compiled_calls):
             """
             Generates a fitting title for the simulation based on the user's instructions
             """

-            TM.log.debug("SIMULATION SUITE: set simulation title", name=TM.scene.title, compiled_calls=compiled_calls)
+            TM.log.debug(
+                "SIMULATION SUITE: set simulation title",
+                name=TM.scene.title,
+                compiled_calls=compiled_calls,
+            )

             if not compiled_calls:
                 return

             if TM.scene.title != "Simulation Suite":
                 # name already changed, no need to do it again
                 return

             title = TM.agents.creator.contextual_generate_from_args(
                 "scene:simulation title",
                 "Create a fitting title for the simulated scenario that the user has requested. You response MUST be a short but exciting, descriptive title.",
-                length=75
+                length=75,
             )

             title = title.strip('"').strip()

             TM.scene.set_title(title)

         def prepare_calls(self, calls):
             """
             Loops through calls and if a `set_player_name` call and a `set_player_persona` call are both
             found, ensure that the `set_player_name` call is processed first by moving it in front of the
             `set_player_persona` call.
             """

             set_player_name_call_exists = -1
             set_player_persona_call_exists = -1

             i = 0
             for call in calls:
                 if "set_player_name" in call:
@@ -244,351 +259,445 @@ def game(TM):
                 elif "set_player_persona" in call:
                     set_player_persona_call_exists = i
                 i = i + 1

             if set_player_name_call_exists > -1 and set_player_persona_call_exists > -1:
                 if set_player_name_call_exists > set_player_persona_call_exists:
-                    calls.insert(set_player_persona_call_exists, calls.pop(set_player_name_call_exists))
-                    TM.log.debug("SIMULATION SUITE: prepare calls - moved set_player_persona call", calls=calls)
+                    calls.insert(
+                        set_player_persona_call_exists,
+                        calls.pop(set_player_name_call_exists),
+                    )
+                    TM.log.debug(
+                        "SIMULATION SUITE: prepare calls - moved set_player_persona call",
+                        calls=calls,
+                    )

             return calls

-        def process_call(self, call:str) -> str:
+        def process_call(self, call: str) -> str:
             """
             Processes a simulation call

             Simulation calls are pseudo functions that are called by the simulation suite

             We grab the function name by splitting against ( and taking the first element
             if the SimulationSuite has a method with the name call_{function_name} then we call it

             if a function name could be found but we do not have a method to call, we don't do anything,
             but we still return it as processed, as the AI can still interpret it as something later on
             """

             if "(" not in call:
                 return None

             function_name = call.split("(")[0]

             if hasattr(self, f"call_{function_name}"):
-                TM.log.debug("SIMULATION SUITE CALL", call=call, function_name=function_name)
+                TM.log.debug(
+                    "SIMULATION SUITE CALL", call=call, function_name=function_name
+                )

                 inject = f"The computer executes the function `{call}`"

                 return getattr(self, f"call_{function_name}")(call, inject)

             return call

-        def call_set_simulation_goal(self, call:str, inject:str) -> str:
+        def call_set_simulation_goal(self, call: str, inject: str) -> str:
             """
             Sets the simulation goal as a permanent pin
             """
-            TM.signals.status("busy", "Simulation suite setting goal.", as_scene_message=True)
-            TM.agents.world_state.save_world_entry(
-                entry_id="sim.goal",
-                text=self.player_message.raw,
-                meta={},
-                pin=True
+            TM.signals.status(
+                "busy", "Simulation suite setting goal.", as_scene_message=True
+            )
+
+            TM.agents.world_state.save_world_entry(
+                entry_id="sim.goal", text=self.player_message.raw, meta={}, pin=True
             )

             TM.agents.director.log_action(
                 action=parse_sim_call_arguments(call),
                 action_description="The computer sets the goal for the simulation.",
             )

             return call

-        def call_change_environment(self, call:str, inject:str) -> str:
+        def call_change_environment(self, call: str, inject: str) -> str:
             """
             Simulation changes the environment; this is entirely interpreted by the AI
             and we don't need to do any logic on our end, so we just return the call
             """

             TM.agents.director.log_action(
                 action=parse_sim_call_arguments(call),
-                action_description="The computer changes the environment of the simulation."
+                action_description="The computer changes the environment of the simulation.",
             )

             return call

-        def call_answer_question(self, call:str, inject:str) -> str:
+        def call_answer_question(self, call: str, inject: str) -> str:
             """
             The player asked the simulation a query; we need to process this and have
             the AI produce an answer
             """

             TM.agents.narrator.action_to_narration(
                 action_name="progress_story",
                 narrative_direction=f"The computer calls the following function:\n\n{call}\n\nand answers the player's question.",
-                emit_message=True
+                emit_message=True,
             )

-        def call_set_player_persona(self, call:str, inject:str) -> str:
+        def call_set_player_persona(self, call: str, inject: str) -> str:
             """
             The simulation suite is altering the player persona
             """

             player_character = TM.scene.get_player_character()

-            TM.signals.status("busy", "Simulation suite altering user persona.", as_scene_message=True)
-            character_attributes = TM.agents.world_state.extract_character_sheet(
-                name=player_character.name, text=inject, alteration_instructions=self.player_message.raw
-            )
-            TM.scene.set_character_attributes(player_character.name, character_attributes)
-
-            character_description = TM.agents.creator.determine_character_description(player_character.name)
-
-            TM.scene.set_character_description(player_character.name, character_description)
-
-            TM.log.debug("SIMULATION SUITE: transform player", attributes=character_attributes, description=character_description)
+            TM.signals.status(
+                "busy", "Simulation suite altering user persona.", as_scene_message=True
+            )
+
+            character_attributes = TM.agents.world_state.extract_character_sheet(
+                name=player_character.name,
+                text=inject,
+                alteration_instructions=self.player_message.raw,
+            )
+            TM.scene.set_character_attributes(
+                player_character.name, character_attributes
+            )
+
+            character_description = TM.agents.creator.determine_character_description(
+                player_character.name
+            )
+
+            TM.scene.set_character_description(
+                player_character.name, character_description
+            )
+
+            TM.log.debug(
+                "SIMULATION SUITE: transform player",
+                attributes=character_attributes,
+                description=character_description,
+            )

             TM.agents.director.log_action(
                 action=parse_sim_call_arguments(call),
-                action_description="The computer transforms the player persona."
+                action_description="The computer transforms the player persona.",
             )

             return call

-        def call_set_player_name(self, call:str, inject:str) -> str:
+        def call_set_player_name(self, call: str, inject: str) -> str:
             """
             The simulation suite is altering the player name
             """
             player_character = TM.scene.get_player_character()

-            TM.signals.status("busy", "Simulation suite adjusting user identity.", as_scene_message=True)
-            character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits.")
+            TM.signals.status(
+                "busy",
+                "Simulation suite adjusting user identity.",
+                as_scene_message=True,
+            )
+            character_name = TM.agents.creator.determine_character_name(
+                instructions=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits."
+            )
             TM.log.debug("SIMULATION SUITE: player name", character_name=character_name)
             if character_name != player_character.name:
                 TM.scene.set_character_name(player_character.name, character_name)

             TM.agents.director.log_action(
                 action=parse_sim_call_arguments(call),
-                action_description=f"The computer changes the player's identity to {character_name}."
+                action_description=f"The computer changes the player's identity to {character_name}.",
             )

             return call

-        def call_add_ai_character(self, call:str, inject:str) -> str:
+        def call_add_ai_character(self, call: str, inject: str) -> str:
             # sometimes the AI will call this function and pass an inanimate object as the parameter
             # we need to determine if this is the case and just ignore it
-            is_inanimate = TM.agents.world_state.answer_query_true_or_false(f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)", call)
+            is_inanimate = TM.agents.world_state.answer_query_true_or_false(
+                f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)",
+                call,
+            )

             if is_inanimate:
-                TM.log.debug("SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped", call=call)
+                TM.log.debug(
+                    "SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped",
+                    call=call,
+                )
                 return

             # sometimes the AI will ask if the function adds a group of characters, we need to
             # determine if this is the case
-            adds_group = TM.agents.world_state.answer_query_true_or_false(f"does the function `{call}` add MULTIPLE ai characters?", call)
+            adds_group = TM.agents.world_state.answer_query_true_or_false(
+                f"does the function `{call}` add MULTIPLE ai characters?", call
+            )

             TM.log.debug("SIMULATION SUITE: add npc", adds_group=adds_group)

-            TM.signals.status("busy", "Simulation suite adding character.", as_scene_message=True)
+            TM.signals.status(
+                "busy", "Simulation suite adding character.", as_scene_message=True
+            )

             if not adds_group:
-                character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.")
+                character_name = TM.agents.creator.determine_character_name(
+                    instructions=f"{inject} - what is the name of the character to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name."
+                )
             else:
-                character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.", group=True)
+                character_name = TM.agents.creator.determine_character_name(
+                    instructions=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.",
+                    group=True,
+                )

             # sometimes add_ai_character and change_ai_character are called in the same instruction targeting
             # the same character, if this happens we need to combine into a single add_ai_character call
-            has_change_ai_character_call = TM.agents.world_state.answer_query_true_or_false(f"Are there any calls to `change_ai_character` in the instruction for {character_name}?", "\n".join(self.calls))
+            has_change_ai_character_call = TM.agents.world_state.answer_query_true_or_false(
+                f"Are there any calls to `change_ai_character` in the instruction for {character_name}?",
+                "\n".join(self.calls),
+            )

             if has_change_ai_character_call:
-                combined_arg = TM.prompt.request(
-                    "combine-add-and-alter-ai-character",
-                    dedupe_enabled=False,
-                    calls="\n".join(self.calls),
-                    character_name=character_name,
-                    scene=TM.scene,
-                ).replace("COMBINED ARGUMENT:", "").strip()
+                combined_arg = (
+                    TM.prompt.request(
+                        "combine-add-and-alter-ai-character",
+                        dedupe_enabled=False,
+                        calls="\n".join(self.calls),
+                        character_name=character_name,
+                        scene=TM.scene,
+                    )
+                    .replace("COMBINED ARGUMENT:", "")
+                    .strip()
+                )

                 call = f"add_ai_character({combined_arg})"
                 inject = f"The computer executes the function `{call}`"

-            TM.signals.status("busy", f"Simulation suite adding character: {character_name}", as_scene_message=True)
+            TM.signals.status(
+                "busy",
+                f"Simulation suite adding character: {character_name}",
+                as_scene_message=True,
+            )

             TM.log.debug("SIMULATION SUITE: add npc", name=character_name)

-            npc = TM.agents.director.persist_character(character_name=character_name, content=self.player_message.raw+f"\n\n{inject}", determine_name=False)
+            npc = TM.agents.director.persist_character(
+                character_name=character_name,
+                content=self.player_message.raw + f"\n\n{inject}",
+                determine_name=False,
+            )

             self.added_npcs.append(npc.name)

             TM.agents.world_state.add_detail_reinforcement(
                 character_name=npc.name,
                 detail="Goal",
                 instructions=f"Generate a goal for {npc.name}, based on the user's chosen simulation",
                 interval=25,
-                run_immediately=True
+                run_immediately=True,
             )

             TM.log.debug("SIMULATION SUITE: added npc", npc=npc)

             TM.agents.visual.generate_character_portrait(character_name=npc.name)

             TM.agents.director.log_action(
                 action=parse_sim_call_arguments(call),
-                action_description=f"The computer adds {npc.name} to the simulation."
+                action_description=f"The computer adds {npc.name} to the simulation.",
             )

             return call

         ####
-        def call_remove_ai_character(self, call:str, inject:str) -> str:
-            TM.signals.status("busy", "Simulation suite removing character.", as_scene_message=True)
-            character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character being removed?", allowed_names=TM.scene.npc_character_names)
+        def call_remove_ai_character(self, call: str, inject: str) -> str:
+            TM.signals.status(
+                "busy", "Simulation suite removing character.", as_scene_message=True
+            )
+
+            character_name = TM.agents.creator.determine_character_name(
+                instructions=f"{inject} - what is the name of the character being removed?",
+                allowed_names=TM.scene.npc_character_names,
+            )

             npc = TM.scene.get_character(character_name)

             if npc:
                 TM.log.debug("SIMULATION SUITE: remove npc", npc=npc.name)
-                TM.agents.world_state.deactivate_character(action_name="deactivate_character", character_name=npc.name)
+                TM.agents.world_state.deactivate_character(
+                    action_name="deactivate_character", character_name=npc.name
+                )

                 TM.agents.director.log_action(
                     action=parse_sim_call_arguments(call),
-                    action_description=f"The computer removes {npc.name} from the simulation."
+                    action_description=f"The computer removes {npc.name} from the simulation.",
                 )

             return call

-        def call_change_ai_character(self, call:str, inject:str) -> str:
-            TM.signals.status("busy", "Simulation suite altering character.", as_scene_message=True)
-            character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character receiving the changes (before the change)?", allowed_names=TM.scene.npc_character_names)
+        def call_change_ai_character(self, call: str, inject: str) -> str:
+            TM.signals.status(
+                "busy", "Simulation suite altering character.", as_scene_message=True
+            )
+
+            character_name = TM.agents.creator.determine_character_name(
+                instructions=f"{inject} - what is the name of the character receiving the changes (before the change)?",
+                allowed_names=TM.scene.npc_character_names,
+            )

             if character_name in self.added_npcs:
                 # we don't want to change the character if it was just added
                 return

-            character_name_after = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character receiving the changes (after the changes)?")
+            character_name_after = TM.agents.creator.determine_character_name(
+                instructions=f"{inject} - what is the name of the character receiving the changes (after the changes)?"
+            )

             npc = TM.scene.get_character(character_name)

             if npc:
-                TM.signals.status("busy", f"Changing {character_name} -> {character_name_after}", as_scene_message=True)
+                TM.signals.status(
+                    "busy",
+                    f"Changing {character_name} -> {character_name_after}",
+                    as_scene_message=True,
+                )

                 TM.log.debug("SIMULATION SUITE: transform npc", npc=npc)

                 character_attributes = TM.agents.world_state.extract_character_sheet(
                     name=npc.name,
                     text=inject,
-                    alteration_instructions=self.player_message.raw
+                    alteration_instructions=self.player_message.raw,
                 )

                 TM.scene.set_character_attributes(npc.name, character_attributes)
-                character_description = TM.agents.creator.determine_character_description(npc.name)
+
+                character_description = (
+                    TM.agents.creator.determine_character_description(npc.name)
+                )

                 TM.scene.set_character_description(npc.name, character_description)
-                TM.log.debug("SIMULATION SUITE: transform npc", attributes=character_attributes, description=character_description)
+
+                TM.log.debug(
+                    "SIMULATION SUITE: transform npc",
+                    attributes=character_attributes,
+                    description=character_description,
+                )

                 if character_name_after != character_name:
                     TM.scene.set_character_name(npc.name, character_name_after)

                 TM.agents.director.log_action(
                     action=parse_sim_call_arguments(call),
-                    action_description=f"The computer transforms {npc.name}."
+                    action_description=f"The computer transforms {npc.name}.",
                 )

             return call
-        def call_end_simulation(self, call:str, inject:str) -> str:
+        def call_end_simulation(self, call: str, inject: str) -> str:
             player_character = TM.scene.get_player_character()

-            explicit_command = TM.agents.world_state.answer_query_true_or_false("has the player explicitly asked to end the simulation?", self.player_message.raw)
+            explicit_command = TM.agents.world_state.answer_query_true_or_false(
+                "has the player explicitly asked to end the simulation?",
+                self.player_message.raw,
+            )

             if explicit_command:
-                TM.signals.status("busy", "Simulation suite ending current simulation.", as_scene_message=True)
+                TM.signals.status(
+                    "busy",
+                    "Simulation suite ending current simulation.",
+                    as_scene_message=True,
+                )
                 TM.agents.narrator.action_to_narration(
                     action_name="progress_story",
                     narrative_direction=f"Narrate the computer ending the simulation, dissolving the environment and all artificial characters, erasing all memory of it and finally returning the player to the inactive simulation suite. List of artificial characters: {', '.join(TM.scene.npc_character_names)}. The player is also transformed back to their normal, non-descript persona as the form of {player_character.name} ceases to exist.",
-                    emit_message=True
+                    emit_message=True,
                 )
                 TM.scene.restore()
                 self.simulation_reset = True

                 TM.game_state.unset_var("instr.has_issued_instructions")
                 TM.game_state.unset_var("instr.lastprocessed_call")
                 TM.game_state.unset_var("instr.simulation_started")

                 TM.agents.director.log_action(
                     action=parse_sim_call_arguments(call),
-                    action_description="The computer ends the simulation."
+                    action_description="The computer ends the simulation.",
                 )

         def finalize_round(self):
             # track rounds
             rounds = TM.game_state.get_var("instr.rounds", 0)

             # increase rounds
             TM.game_state.set_var("instr.rounds", rounds + 1, commit=False)

-            has_issued_instructions = TM.game_state.has_var("instr.has_issued_instructions")
+            has_issued_instructions = TM.game_state.has_var(
+                "instr.has_issued_instructions"
+            )

             if self.update_world_state:
                 self.run_update_world_state()

             if self.player_message_is_instruction:
                 TM.scene.hide_message(self.player_message.id)
-                TM.game_state.set_var("instr.lastprocessed_call", self.player_message.id, commit=False)
-                TM.signals.status("success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True)
+                TM.game_state.set_var(
+                    "instr.lastprocessed_call", self.player_message.id, commit=False
+                )
+                TM.signals.status(
+                    "success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True
+                )

             elif self.player_message and not has_issued_instructions:
                 # simulation started, player message is NOT an instruction, and player has not given
                 # any instructions
                 self.guide_player()

             elif self.player_message and not TM.scene.npc_character_names:
                 # simulation started, player message is NOT an instruction, but there are no npcs to interact with
                 self.narrate_round()

-            elif rounds % AUTO_NARRATE_INTERVAL == 0 and rounds and TM.scene.npc_character_names and has_issued_instructions:
+            elif (
+                rounds % AUTO_NARRATE_INTERVAL == 0
+                and rounds
+                and TM.scene.npc_character_names
+                and has_issued_instructions
+            ):
                 # every N rounds, narrate the round
                 self.narrate_round()

         def guide_player(self):
             TM.agents.narrator.action_to_narration(
-                action_name="paraphrase",
-                narration=MSG_HELP,
-                emit_message=True
+                action_name="paraphrase", narration=MSG_HELP, emit_message=True
             )

         def narrate_round(self):
             TM.agents.narrator.action_to_narration(
                 action_name="progress_story",
                 narrative_direction=PROMPT_NARRATE_ROUND,
-                emit_message=True
+                emit_message=True,
             )

         def run_update_world_state(self, force=False):
             TM.log.debug("SIMULATION SUITE: update world state", force=force)
-            TM.signals.status("busy", "Simulation suite updating world state.", as_scene_message=True)
+            TM.signals.status(
+                "busy", "Simulation suite updating world state.", as_scene_message=True
+            )
             TM.agents.world_state.update_world_state(force=force)
-            TM.signals.status("success", "Simulation suite updated world state.", as_scene_message=True)
+            TM.signals.status(
+                "success",
+                "Simulation suite updated world state.",
+                as_scene_message=True,
+            )

     SimulationSuite().run()


 def on_generation_cancelled(TM, exc):
     """
     Called when the user presses the cancel button during the simulation suite
     loop.
     """
-    TM.signals.status("success", "Simulation suite instructions cancelled", as_scene_message=True)
+    TM.signals.status(
+        "success", "Simulation suite instructions cancelled", as_scene_message=True
+    )
     rounds = TM.game_state.get_var("instr.rounds", 0)
-    TM.log.debug("SIMULATION SUITE: command cancelled", rounds=rounds)
+    TM.log.debug("SIMULATION SUITE: command cancelled", rounds=rounds)
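The heart of the script is the pseudo-function dispatch. A condensed restatement of that mechanism, using only logic present in `process_call` and `parse_sim_call_arguments` above:

```python
# Condensed restatement of the dispatch logic defined above.
call = 'change_environment("a derelict spaceship")'

function_name = call.split("(")[0]              # -> "change_environment"
argument = call.split("(", 1)[1].split(")")[0]  # -> '"a derelict spaceship"'

# SimulationSuite.process_call looks for a method named call_<function_name>
# and invokes it with (call, inject); unknown function names are passed
# through unchanged so the AI can still interpret them later.
handler_name = f"call_{function_name}"
```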
talemate/__init__.py

@@ -1,5 +1,5 @@
-from .tale_mate import *
+from .tale_mate import *  # noqa: F401, F403

 from .version import VERSION

-__version__ = VERSION
+__version__ = VERSION
talemate/agents/__init__.py

@@ -1,12 +1,12 @@
-from .base import Agent
-from .conversation import ConversationAgent
-from .creator import CreatorAgent
-from .director import DirectorAgent
-from .editor import EditorAgent
-from .memory import ChromaDBMemoryAgent, MemoryAgent
-from .narrator import NarratorAgent
-from .registry import AGENT_CLASSES, get_agent_class, register
-from .summarize import SummarizeAgent
-from .tts import TTSAgent
-from .visual import VisualAgent
-from .world_state import WorldStateAgent
+from .base import Agent  # noqa: F401
+from .conversation import ConversationAgent  # noqa: F401
+from .creator import CreatorAgent  # noqa: F401
+from .director import DirectorAgent  # noqa: F401
+from .editor import EditorAgent  # noqa: F401
+from .memory import ChromaDBMemoryAgent, MemoryAgent  # noqa: F401
+from .narrator import NarratorAgent  # noqa: F401
+from .registry import AGENT_CLASSES, get_agent_class, register  # noqa: F401
+from .summarize import SummarizeAgent  # noqa: F401
+from .tts import TTSAgent  # noqa: F401
+from .visual import VisualAgent  # noqa: F401
+from .world_state import WorldStateAgent  # noqa: F401
talemate/agents/base.py

@@ -6,11 +6,10 @@ from inspect import signature
 import re
 from abc import ABC
 from functools import wraps
-from typing import TYPE_CHECKING, Callable, List, Optional, Union
+from typing import Callable, Union
 import uuid
 import pydantic
 import structlog
 from blinker import signal

 import talemate.emit.async_signals
 import talemate.instance as instance
@@ -39,6 +38,7 @@ __all__ = [

 log = structlog.get_logger("talemate.agents.base")


 class AgentActionConditional(pydantic.BaseModel):
     attribute: str
     value: int | float | str | bool | list[int | float | str | bool] | None = None
@@ -48,6 +48,7 @@ class AgentActionNote(pydantic.BaseModel):
     type: str
     text: str


 class AgentActionConfig(pydantic.BaseModel):
     type: str
     label: str
@@ -65,7 +66,7 @@ class AgentActionConfig(pydantic.BaseModel):
     condition: Union[AgentActionConditional, None] = None
     title: Union[str, None] = None
     value_migration: Union[Callable, None] = pydantic.Field(default=None, exclude=True)

     note_on_value: dict[str, AgentActionNote] = pydantic.Field(default_factory=dict)

     class Config:
@@ -85,37 +86,37 @@ class AgentAction(pydantic.BaseModel):
     quick_toggle: bool = False
     experimental: bool = False


 class AgentDetail(pydantic.BaseModel):
     value: Union[str, None] = None
     description: Union[str, None] = None
     icon: Union[str, None] = None
     color: str = "grey"


 class DynamicInstruction(pydantic.BaseModel):
     title: str
     content: str

     def __str__(self) -> str:
         return "\n".join(
-            [
-                f"<|SECTION:{self.title}|>",
-                self.content,
-                "<|CLOSE_SECTION|>"
-            ]
+            [f"<|SECTION:{self.title}|>", self.content, "<|CLOSE_SECTION|>"]
         )


-def args_and_kwargs_to_dict(fn, args: list, kwargs: dict, filter:list[str] = None) -> dict:
+def args_and_kwargs_to_dict(
+    fn, args: list, kwargs: dict, filter: list[str] = None
+) -> dict:
     """
     Takes a list of arguments and a dict of keyword arguments and returns
     a dict mapping parameter names to their values.

     Args:
         fn: The function whose parameters we want to map
         args: List of positional arguments
         kwargs: Dictionary of keyword arguments
         filter: List of parameter names to include in the result, if None all parameters are included

     Returns:
         Dict mapping parameter names to their values
     """
@@ -124,34 +125,36 @@ def args_and_kwargs_to_dict(fn, args: list, kwargs: dict, filter:list[str] = Non
     bound_args.apply_defaults()
     rv = dict(bound_args.arguments)
     rv.pop("self", None)

     if filter:
         for key in list(rv.keys()):
             if key not in filter:
                 rv.pop(key)

     return rv


 class store_context_state:
     """
     Flag to store a function's arguments in the agent's context state.

     Any arguments passed to the function will be stored in the agent's context.

     If no arguments are passed, all arguments will be stored.

     Keyword arguments can be passed to store additional values in the context state.
     """

     def __init__(self, *args, **kwargs):
         self.args = args
         self.kwargs = kwargs

     def __call__(self, fn):
         fn.store_context_state = self.args
         fn.store_context_state_kwargs = self.kwargs
         return fn


 def set_processing(fn):
     """
     decorator that emits the agent status as processing while the function
@@ -165,31 +168,38 @@ def set_processing(fn):
     async def wrapper(self, *args, **kwargs):
         with ClientContext():
             scene = active_scene.get()

             if scene:
                 scene.continue_actions()

                 if getattr(scene, "config", None):
-                    set_client_context_attribute("app_config_system_prompts", scene.config.get("system_prompts", {}))
+                    set_client_context_attribute(
+                        "app_config_system_prompts", scene.config.get("system_prompts", {})
+                    )

             with ActiveAgent(self, fn, args, kwargs) as active_agent_context:
                 try:
                     await self.emit_status(processing=True)

                     # Now pass the complete args list
                     if getattr(fn, "store_context_state", None) is not None:
                         all_args = args_and_kwargs_to_dict(
-                            fn, [self] + list(args), kwargs, getattr(fn, "store_context_state", [])
+                            fn,
+                            [self] + list(args),
+                            kwargs,
+                            getattr(fn, "store_context_state", []),
                         )
                         if getattr(fn, "store_context_state_kwargs", None) is not None:
-                            all_args.update(getattr(fn, "store_context_state_kwargs", {}))
+                            all_args.update(
+                                getattr(fn, "store_context_state_kwargs", {})
+                            )

                         all_args[f"fn_{fn.__name__}"] = True

                         active_agent_context.state_params = all_args

                         self.set_context_states(**all_args)

                     return await fn(self, *args, **kwargs)
                 finally:
                     try:
@@ -215,14 +225,16 @@ class Agent(ABC):
     websocket_handler = None
     essential = True
     ready_check_error = None

     @classmethod
-    def init_actions(cls, actions: dict[str, AgentAction] | None = None) -> dict[str, AgentAction]:
+    def init_actions(
+        cls, actions: dict[str, AgentAction] | None = None
+    ) -> dict[str, AgentAction]:
         if actions is None:
             actions = {}

         return actions

     @property
     def agent_details(self):
         if hasattr(self, "client"):
@@ -230,10 +242,6 @@ class Agent(ABC):
             return self.client.name
         return None

-    @property
-    def verbose_name(self):
-        return self.agent_type.capitalize()
-
     @property
     def ready(self):
         if not getattr(self.client, "enabled", True):
|
||||
@@ -315,67 +323,68 @@ class Agent(ABC):
|
||||
return {}
|
||||
|
||||
return {k: v.model_dump() for k, v in self.actions.items()}
|
||||
|
||||
|
||||
# scene state
|
||||
|
||||
|
||||
|
||||
def context_fingerpint(self, extra: list[str] = []) -> str | None:
|
||||
active_agent_context = active_agent.get()
|
||||
|
||||
|
||||
if not active_agent_context:
|
||||
return None
|
||||
|
||||
|
||||
if self.scene.history:
|
||||
fingerprint = f"{self.scene.history[-1].fingerprint}-{active_agent_context.first.fingerprint}"
|
||||
else:
|
||||
fingerprint = f"START-{active_agent_context.first.fingerprint}"
|
||||
|
||||
|
||||
for extra_key in extra:
|
||||
fingerprint += f"-{hash(extra_key)}"
|
||||
|
||||
|
||||
return fingerprint
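The resulting key is just the concatenated fingerprints, with one extra hash segment per entry in extra; illustratively (values hypothetical):

    # last history fingerprint + fingerprint of the outermost active agent call
    fingerprint = "8861220451-1077334"
    # each extra key appends "-{hash(key)}"
    fingerprint += f"-{hash('model_name')}"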

def get_scene_state(self, key:str, default=None):

def get_scene_state(self, key: str, default=None):
agent_state = self.scene.agent_state.get(self.agent_type, {})
return agent_state.get(key, default)

def set_scene_states(self, **kwargs):
agent_state = self.scene.agent_state.get(self.agent_type, {})
for key, value in kwargs.items():
agent_state[key] = value
self.scene.agent_state[self.agent_type] = agent_state

def dump_scene_state(self):
return self.scene.agent_state.get(self.agent_type, {})
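Scene state is a plain per-agent-type dict persisted on the scene; a hypothetical usage from inside an agent method:

    self.set_scene_states(last_topic="the storm", turns=3)
    self.get_scene_state("last_topic")          # -> "the storm"
    self.get_scene_state("missing", default=0)  # -> 0
    self.dump_scene_state()                     # -> {"last_topic": "the storm", "turns": 3}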

# active agent context state

def get_context_state(self, key:str, default=None):

def get_context_state(self, key: str, default=None):
key = f"{self.agent_type}__{key}"
try:
return active_agent.get().state.get(key, default)
except AttributeError:
log.warning("get_context_state error", agent=self.agent_type, key=key)
return default

def set_context_states(self, **kwargs):
try:
items = {f"{self.agent_type}__{k}": v for k, v in kwargs.items()}
active_agent.get().state.update(items)
log.debug("set_context_states", agent=self.agent_type, state=active_agent.get().state)
log.debug(
"set_context_states",
agent=self.agent_type,
state=active_agent.get().state,
)
except AttributeError:
log.error("set_context_states error", agent=self.agent_type, kwargs=kwargs)

def dump_context_state(self):
try:
return active_agent.get().state
except AttributeError:
return {}
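Context state differs from scene state in that keys are namespaced per agent type before they land in the shared ActiveAgent state dict; a hypothetical usage:

    # for an agent with agent_type == "conversation":
    self.set_context_states(instruction="wave hello")
    # stored under "conversation__instruction" in active_agent.get().state
    self.get_context_state("instruction")  # -> "wave hello"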

###

async def _handle_ready_check(self, fut: asyncio.Future):
callback_failure = getattr(self, "on_ready_check_failure", None)
if fut.cancelled():
@@ -425,41 +434,50 @@ class Agent(ABC):
if not action.config:
continue

for config_key, config in action.config.items():
for config_key, _config in action.config.items():
try:
config.value = (
_config.value = (
kwargs.get("actions", {})
.get(action_key, {})
.get("config", {})
.get(config_key, {})
.get("value", config.value)
.get("value", _config.value)
)
if config.value_migration and callable(config.value_migration):
config.value = config.value_migration(config.value)
if _config.value_migration and callable(_config.value_migration):
_config.value = _config.value_migration(_config.value)
except AttributeError:
pass

async def save_config(self, app_config: config.Config | None = None):
"""
Saves the agent config to the config file.

If no config object is provided, the config is loaded from the config file.
"""

if not app_config:
app_config:config.Config = config.load_config(as_model=True)

app_config: config.Config = config.load_config(as_model=True)

app_config.agents[self.agent_type] = config.Agent(
name=self.agent_type,
client=self.client.name if self.client else None,
enabled=self.enabled,
actions={action_key: config.AgentAction(
enabled=action.enabled,
config={config_key: config.AgentActionConfig(value=config_obj.value) for config_key, config_obj in action.config.items()}
) for action_key, action in self.actions.items()}
actions={
action_key: config.AgentAction(
enabled=action.enabled,
config={
config_key: config.AgentActionConfig(value=config_obj.value)
for config_key, config_obj in action.config.items()
},
)
for action_key, action in self.actions.items()
},
)
log.debug(
"saving agent config",
agent=self.agent_type,
config=app_config.agents[self.agent_type],
)
log.debug("saving agent config", agent=self.agent_type, config=app_config.agents[self.agent_type])
config.save_config(app_config)

async def on_game_loop_start(self, event: GameLoopStartEvent):
@@ -474,19 +492,19 @@ class Agent(ABC):
if not action.config:
continue

for _, config in action.config.items():
if config.scope == "scene":
for _, _config in action.config.items():
if _config.scope == "scene":
# if default_value is None, just use the `type` of the current
# value
if config.default_value is None:
default_value = type(config.value)()
if _config.default_value is None:
default_value = type(_config.value)()
else:
default_value = config.default_value
default_value = _config.default_value

log.debug(
"resetting config", config=config, default_value=default_value
"resetting config", config=_config, default_value=default_value
)
config.value = default_value
_config.value = default_value

await self.emit_status()

@@ -518,7 +536,9 @@ class Agent(ABC):

await asyncio.sleep(0.01)

async def _handle_background_processing(self, fut: asyncio.Future, error_handler = None):
async def _handle_background_processing(
self, fut: asyncio.Future, error_handler=None
):
try:
if fut.cancelled():
return
@@ -541,7 +561,7 @@ class Agent(ABC):
self.processing_bg -= 1
await self.emit_status()

async def set_background_processing(self, task: asyncio.Task, error_handler = None):
async def set_background_processing(self, task: asyncio.Task, error_handler=None):
log.info("set_background_processing", agent=self.agent_type)
if not hasattr(self, "processing_bg"):
self.processing_bg = 0
@@ -550,7 +570,9 @@ class Agent(ABC):

await self.emit_status()
task.add_done_callback(
lambda fut: asyncio.create_task(self._handle_background_processing(fut, error_handler))
lambda fut: asyncio.create_task(
self._handle_background_processing(fut, error_handler)
)
)
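A hypothetical call site, to show the bookkeeping: processing_bg is incremented here and decremented by _handle_background_processing once the done-callback fires. some_long_running_job is a placeholder coroutine, not part of the codebase:

    task = asyncio.create_task(self.some_long_running_job())
    await self.set_background_processing(task)
    # emit_status() has run with the incremented counter; when the task
    # finishes, the done-callback schedules _handle_background_processing,
    # which decrements processing_bg and emits status again.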

def connect(self, scene):
@@ -623,7 +645,6 @@ class Agent(ABC):
"""
return False

@set_processing
async def delegate(self, fn: Callable, *args, **kwargs):
"""
@@ -631,20 +652,22 @@ class Agent(ABC):
by the agent.
"""
return await fn(*args, **kwargs)

async def emit_message(self, header:str, message:str | list[dict], meta: dict = None, **data):

async def emit_message(
self, header: str, message: str | list[dict], meta: dict = None, **data
):
if not data:
data = {}

if not meta:
meta = {}

if "uuid" not in data:
data["uuid"] = str(uuid.uuid4())

if "agent" not in data:
data["agent"] = self.agent_type

data["header"] = header
emit(
"agent_message",
@@ -653,17 +676,22 @@
meta=meta,
websocket_passthrough=True,
)
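A hypothetical call, showing the defaults the method fills in (a fresh uuid and the agent type):

    await self.emit_message(
        "Summary ready",                       # header
        "The chapter summary was generated.",  # message
        meta={"kind": "summary"},
    )
    # data ends up as {"uuid": "...", "agent": "<agent_type>", "header": "Summary ready"}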

@dataclasses.dataclass
class AgentEmission:
agent: Agent

@dataclasses.dataclass
class AgentTemplateEmission(AgentEmission):
template_vars: dict = dataclasses.field(default_factory=dict)
response: str = None
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
default_factory=list
)

@dataclasses.dataclass
class RagBuildSubInstructionEmission(AgentEmission):
sub_instruction: str | None = None

@@ -1,13 +1,9 @@
import contextvars
import uuid
import hashlib
from typing import TYPE_CHECKING, Callable
from typing import Callable

import pydantic

if TYPE_CHECKING:
from talemate.tale_mate import Character

__all__ = [
"active_agent",
]
@@ -25,7 +21,6 @@ class ActiveAgentContext(pydantic.BaseModel):
state: dict = pydantic.Field(default_factory=dict)
state_params: dict = pydantic.Field(default_factory=dict)
previous: "ActiveAgentContext" = None

class Config:
arbitrary_types_allowed = True
@@ -35,29 +30,30 @@ class ActiveAgentContext(pydantic.BaseModel):
return self.previous.first if self.previous else self

@property
def action(self):
def action(self):
name = self.fn.__name__
if name == "delegate":
return self.fn_args[0].__name__
return name

@property
def fingerprint(self) -> int:
if hasattr(self, "_fingerprint"):
return self._fingerprint
self._fingerprint = hash(frozenset(self.state_params.items()))
return self._fingerprint
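The fingerprint is a cached hash over the stored call parameters, so two calls with identical state_params deliberately collide; in essence:

    state_params = {"instruction": "greet", "fn_converse": True}
    fingerprint = hash(frozenset(state_params.items()))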

def __str__(self):
return f"{self.agent.verbose_name}.{self.action}"

class ActiveAgent:
def __init__(self, agent, fn, args=None, kwargs=None):
self.agent = ActiveAgentContext(agent=agent, fn=fn, fn_args=args or tuple(), fn_kwargs=kwargs or {})
self.agent = ActiveAgentContext(
agent=agent, fn=fn, fn_args=args or tuple(), fn_kwargs=kwargs or {}
)

def __enter__(self):
previous_agent = active_agent.get()

if previous_agent:
@@ -70,7 +66,7 @@ class ActiveAgent:
self.agent.agent_stack_uid = str(uuid.uuid4())

self.token = active_agent.set(self.agent)

return self.agent

def __exit__(self, *args, **kwargs):

@@ -1,7 +1,6 @@
from __future__ import annotations

import dataclasses
import random
import re
from datetime import datetime
from typing import TYPE_CHECKING, Optional
@@ -10,14 +9,12 @@ import structlog

import talemate.client as client
import talemate.emit.async_signals
import talemate.instance as instance
import talemate.util as util
from talemate.client.context import (
client_context_attribute,
set_client_context_attribute,
set_conversation_context_attribute,
)
from talemate.events import GameLoopEvent
from talemate.exceptions import LLMAccuracyError
from talemate.prompts import Prompt
from talemate.scene_message import CharacterMessage, DirectorMessage
@@ -37,7 +34,7 @@ from talemate.agents.memory.rag import MemoryRAGMixin
from talemate.agents.context import active_agent

from .websocket_handler import ConversationWebsocketHandler
import talemate.agents.conversation.nodes
import talemate.agents.conversation.nodes  # noqa: F401

if TYPE_CHECKING:
from talemate.tale_mate import Actor, Character
@@ -50,21 +47,20 @@ class ConversationAgentEmission(AgentEmission):
actor: Actor
character: Character
response: str
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
default_factory=list
)

talemate.emit.async_signals.register(
"agent.conversation.before_generate",
"agent.conversation.before_generate",
"agent.conversation.inject_instructions",
"agent.conversation.generated"
"agent.conversation.generated",
)

@register()
class ConversationAgent(
MemoryRAGMixin,
Agent
):
class ConversationAgent(MemoryRAGMixin, Agent):
"""
An agent that can be used to have a conversation with the AI

@@ -135,8 +131,6 @@ class ConversationAgent(
max=20,
step=1,
),

},
),
"auto_break_repetition": AgentAction(
@@ -159,7 +153,7 @@ class ConversationAgent(
description="Use the writing style selected in the scene settings",
value=True,
),
}
},
),
}
MemoryRAGMixin.add_actions(actions)
@@ -205,7 +199,6 @@ class ConversationAgent(

@property
def agent_details(self) -> dict:
details = {
"client": AgentDetail(
icon="mdi-network-outline",
@@ -231,22 +224,24 @@ class ConversationAgent(

@property
def generation_settings_actor_instructions_offset(self):
return self.actions["generation_override"].config["actor_instructions_offset"].value

return (
self.actions["generation_override"]
.config["actor_instructions_offset"]
.value
)

@property
def generation_settings_response_length(self):
return self.actions["generation_override"].config["length"].value

@property
def generation_settings_override_enabled(self):
return self.actions["generation_override"].enabled

@property
def content_use_writing_style(self) -> bool:
return self.actions["content"].config["use_writing_style"].value

def connect(self, scene):
super().connect(scene)

@@ -276,7 +271,7 @@ class ConversationAgent(
main_character = scene.main_character.character

character_names = [c.name for c in scene.characters]

if main_character:
try:
character_names.remove(main_character.name)
@@ -296,21 +291,22 @@ class ConversationAgent(
director_message = isinstance(scene_and_dialogue[-1], DirectorMessage)
except IndexError:
director_message = False

inject_instructions_emission = ConversationAgentEmission(
agent=self,
response="",
actor=None,
character=character,
response="",
actor=None,
character=character,
)
await talemate.emit.async_signals.get(
"agent.conversation.inject_instructions"
).send(inject_instructions_emission)

agent_context = active_agent.get()
agent_context.state["dynamic_instructions"] = inject_instructions_emission.dynamic_instructions

agent_context.state["dynamic_instructions"] = (
inject_instructions_emission.dynamic_instructions
)

conversation_format = self.conversation_format
prompt = Prompt.get(
f"conversation.dialogue-{conversation_format}",
@@ -319,26 +315,30 @@ class ConversationAgent(
"max_tokens": self.client.max_token_length,
"scene_and_dialogue_budget": scene_and_dialogue_budget,
"scene_and_dialogue": scene_and_dialogue,
"memory": None, # DEPRECATED VARIABLE
"memory": None,  # DEPRECATED VARIABLE
"characters": list(scene.get_characters()),
"main_character": main_character,
"formatted_names": formatted_names,
"talking_character": character,
"partial_message": char_message,
"director_message": director_message,
"extra_instructions": self.generation_settings_task_instructions, #backward compatibility
"extra_instructions": self.generation_settings_task_instructions,  # backward compatibility
"task_instructions": self.generation_settings_task_instructions,
"actor_instructions": self.generation_settings_actor_instructions,
"actor_instructions_offset": self.generation_settings_actor_instructions_offset,
"direct_instruction": instruction,
"decensor": self.client.decensor_enabled,
"response_length": self.generation_settings_response_length if self.generation_settings_override_enabled else None,
"response_length": self.generation_settings_response_length
if self.generation_settings_override_enabled
else None,
},
)

return str(prompt)

async def build_prompt(self, character, char_message: str = "", instruction:str = None):
async def build_prompt(
self, character, char_message: str = "", instruction: str = None
):
fn = self.build_prompt_default

return await fn(character, char_message=char_message, instruction=instruction)
@@ -376,12 +376,12 @@ class ConversationAgent(
set_client_context_attribute("nuke_repetition", nuke_repetition)

@set_processing
@store_context_state('instruction')
@store_context_state("instruction")
async def converse(
self,
self,
actor,
instruction:str = None,
emit_signals:bool = True,
instruction: str = None,
emit_signals: bool = True,
) -> list[CharacterMessage]:
"""
Have a conversation with the AI
@@ -398,7 +398,9 @@ class ConversationAgent(

self.set_generation_overrides()

result = await self.client.send_prompt(await self.build_prompt(character, instruction=instruction))
result = await self.client.send_prompt(
await self.build_prompt(character, instruction=instruction)
)

result = self.clean_result(result, character)

@@ -454,7 +456,7 @@ class ConversationAgent(
# movie script format
# {uppercase character name}
# {dialogue}
total_result = total_result.replace(f"{character.name.upper()}\n", f"")
total_result = total_result.replace(f"{character.name.upper()}\n", "")

# chat format
# {character name}: {dialogue}
@@ -464,7 +466,7 @@ class ConversationAgent(
total_result = util.clean_dialogue(total_result, main_name=character.name)

# Check if total_result starts with character name, if not, prepend it
if not total_result.startswith(character.name+":"):
if not total_result.startswith(character.name + ":"):
total_result = f"{character.name}: {total_result}"

total_result = total_result.strip()
@@ -481,11 +483,11 @@ class ConversationAgent(

log.debug("conversation agent", response=response)
emission = ConversationAgentEmission(
agent=self,
actor=actor,
character=character,
agent=self,
actor=actor,
character=character,
response=response,
)
)
if emit_signals:
await talemate.emit.async_signals.get("agent.conversation.generated").send(
emission

@@ -1,74 +1,75 @@
import structlog
from typing import TYPE_CHECKING, ClassVar
from talemate.game.engine.nodes.core import Node, GraphState, UNRESOLVED
from talemate.game.engine.nodes.core import GraphState, UNRESOLVED
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentNode, AgentSettingsNode
from talemate.context import active_scene
from talemate.client.context import ConversationContext, ClientContext
import talemate.events as events

if TYPE_CHECKING:
from talemate.tale_mate import Scene, Character

log = structlog.get_logger("talemate.game.engine.nodes.agents.conversation")

@register("agents/conversation/Settings")
class ConversationSettings(AgentSettingsNode):
"""
Base node to render conversation agent settings.
"""
_agent_name:ClassVar[str] = "conversation"

_agent_name: ClassVar[str] = "conversation"

def __init__(self, title="Conversation Settings", **kwargs):
super().__init__(title=title, **kwargs)

@register("agents/conversation/Generate")
class GenerateConversation(AgentNode):
"""
Generate a conversation between two characters
"""

_agent_name:ClassVar[str] = "conversation"

_agent_name: ClassVar[str] = "conversation"

def __init__(self, title="Generate Conversation", **kwargs):
super().__init__(title=title, **kwargs)

def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_input("instruction", socket_type="str", optional=True)

self.set_property("trigger_conversation_generated", True)

self.add_output("generated", socket_type="str")
self.add_output("message", socket_type="message_object")

async def run(self, state: GraphState):
character:"Character" = self.get_input_value("character")
scene:"Scene" = active_scene.get()
character: "Character" = self.get_input_value("character")
scene: "Scene" = active_scene.get()
instruction = self.get_input_value("instruction")
trigger_conversation_generated = self.get_property("trigger_conversation_generated")

trigger_conversation_generated = self.get_property(
"trigger_conversation_generated"
)

other_characters = [c.name for c in scene.characters if c != character]

conversation_context = ConversationContext(
talking_character=character.name,
other_characters=other_characters,
)

if instruction == UNRESOLVED:
instruction = None

with ClientContext(conversation=conversation_context):
messages = await self.agent.converse(
character.actor,
character.actor,
instruction=instruction,
emit_signals=trigger_conversation_generated,
)

message = messages[0]

self.set_output_values({
"generated": message.message,
"message": message
})

self.set_output_values({"generated": message.message, "message": message})

@@ -18,23 +18,25 @@ __all__ = [

log = structlog.get_logger("talemate.server.conversation")

class RequestActorActionPayload(pydantic.BaseModel):
character:str = ""
instructions:str = ""
emit_signals:bool = True
instructions_through_director:bool = True
character: str = ""
instructions: str = ""
emit_signals: bool = True
instructions_through_director: bool = True

class ConversationWebsocketHandler(Plugin):
"""
Handles narrator actions
"""

router = "conversation"

@property
def agent(self) -> "ConversationAgent":
return get_agent("conversation")

@set_loading("Generating actor action", cancellable=True, as_async=True)
async def handle_request_actor_action(self, data: dict):
"""
@@ -43,38 +45,37 @@ class ConversationWebsocketHandler(Plugin):
payload = RequestActorActionPayload(**data)
character = None
actor = None

if payload.character:
character = self.scene.get_character(payload.character)
actor = character.actor
else:
actor = random.choice(list(self.scene.get_npc_characters())).actor

if not actor:
log.error("handle_request_actor_action: No actor found")
return

character = actor.character

if payload.instructions_through_director:
director_message = DirectorMessage(
payload.instructions,
source="player",
meta={"character": character.name}
meta={"character": character.name},
)
emit("director", message=director_message, character=character)
self.scene.push_history(director_message)
generated_messages = await self.agent.converse(
actor,
emit_signals=payload.emit_signals
actor, emit_signals=payload.emit_signals
)
else:
generated_messages = await self.agent.converse(
actor,
actor,
instruction=payload.instructions,
emit_signals=payload.emit_signals
emit_signals=payload.emit_signals,
)

for message in generated_messages:
self.scene.push_history(message)
emit("character", message=message, character=character)
emit("character", message=message, character=character)

@@ -1,13 +1,9 @@
from __future__ import annotations

import json
import os

import talemate.client as client
from talemate.agents.base import Agent, set_processing
from talemate.agents.registry import register
from talemate.agents.memory.rag import MemoryRAGMixin
from talemate.emit import emit
from talemate.prompts import Prompt

from .assistant import AssistantMixin
@@ -16,7 +12,8 @@ from .scenario import ScenarioCreatorMixin

from talemate.agents.base import AgentAction

import talemate.agents.creator.nodes
import talemate.agents.creator.nodes  # noqa: F401

@register()
class CreatorAgent(
@@ -51,7 +48,7 @@ class CreatorAgent(
@set_processing
async def generate_title(self, text: str):
title = await Prompt.request(
f"creator.generate-title",
"creator.generate-title",
self.client,
"create_short",
vars={

@@ -40,6 +40,7 @@ async_signals.register(
"agent.creator.autocomplete.after",
)

@dataclasses.dataclass
class ContextualGenerateEmission(AgentTemplateEmission):
"""
@@ -48,15 +49,16 @@ class ContextualGenerateEmission(AgentTemplateEmission):

content_generation_context: "ContentGenerationContext | None" = None
character: "Character | None" = None

@property
def context_type(self) -> str:
return self.content_generation_context.computed_context[0]

@property
def context_name(self) -> str:
return self.content_generation_context.computed_context[1]

@dataclasses.dataclass
class AutocompleteEmission(AgentTemplateEmission):
"""
@@ -67,6 +69,7 @@ class AutocompleteEmission(AgentTemplateEmission):
type: str = ""
character: "Character | None" = None

class ContentGenerationContext(pydantic.BaseModel):
"""
A context for generating content.
@@ -104,7 +107,6 @@ class ContentGenerationContext(pydantic.BaseModel):

@property
def spice(self) -> str:
spice_level = self.generation_options.spice_level

if self.template and not getattr(self.template, "supports_spice", False):
@@ -148,7 +150,6 @@ class ContentGenerationContext(pydantic.BaseModel):

@property
def style(self):
if self.template and not getattr(self.template, "supports_style", False):
# template supplied that doesn't support style
return ""
@@ -165,11 +166,12 @@ class ContentGenerationContext(pydantic.BaseModel):
def get_state(self, key: str) -> str | int | float | bool | None:
return self.state.get(key)

class AssistantMixin:
"""
Creator mixin that allows quick contextual generation of content.
"""

@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["autocomplete"] = AgentAction(
@@ -198,15 +200,15 @@ class AssistantMixin:
max=256,
step=16,
),
}
},
)

# property helpers

@property
def autocomplete_dialogue_suggestion_length(self):
return self.actions["autocomplete"].config["dialogue_suggestion_length"].value

@property
def autocomplete_narrative_suggestion_length(self):
return self.actions["autocomplete"].config["narrative_suggestion_length"].value
@@ -255,7 +257,7 @@ class AssistantMixin:
history_aware=history_aware,
information=information,
)

for key, value in kwargs.items():
generation_context.set_state(key, value)

@@ -279,9 +281,13 @@ class AssistantMixin:
f"Contextual generate: {context_typ} - {context_name}",
generation_context=generation_context,
)

character = self.scene.get_character(generation_context.character) if generation_context.character else None

character = (
self.scene.get_character(generation_context.character)
if generation_context.character
else None
)

template_vars = {
"scene": self.scene,
"max_tokens": self.client.max_token_length,
@@ -295,27 +301,29 @@ class AssistantMixin:
"character": character,
"template": generation_context.template,
}

emission = ContextualGenerateEmission(
agent=self,
content_generation_context=generation_context,
character=character,
template_vars=template_vars,
)

await async_signals.get("agent.creator.contextual_generate.before").send(emission)

await async_signals.get("agent.creator.contextual_generate.before").send(
emission
)

template_vars["dynamic_instructions"] = emission.dynamic_instructions

content = await Prompt.request(
f"creator.contextual-generate",
"creator.contextual-generate",
self.client,
kind,
vars=template_vars,
)

emission.response = content

if not generation_context.partial:
content = util.strip_partial_sentences(content)

@@ -329,22 +337,29 @@ class AssistantMixin:
if not content.startswith(generation_context.character + ":"):
content = generation_context.character + ": " + content
content = util.strip_partial_sentences(content)

character = self.scene.get_character(generation_context.character)

if not character:
log.warning("Character not found", character=generation_context.character)
return content

emission.response = await editor.cleanup_character_message(content, character)
await async_signals.get("agent.creator.contextual_generate.after").send(emission)
return emission.response

emission.response = content.strip().strip("*").strip()

await async_signals.get("agent.creator.contextual_generate.after").send(emission)
return emission.response

character = self.scene.get_character(generation_context.character)

if not character:
log.warning(
"Character not found", character=generation_context.character
)
return content

emission.response = await editor.cleanup_character_message(
content, character
)
await async_signals.get("agent.creator.contextual_generate.after").send(
emission
)
return emission.response

emission.response = content.strip().strip("*").strip()

await async_signals.get("agent.creator.contextual_generate.after").send(
emission
)
return emission.response

@set_processing
async def generate_character_attribute(
@@ -357,10 +372,10 @@ class AssistantMixin:
) -> str:
"""
Wrapper for contextual_generate that generates a character attribute.
"""
"""
if not generation_options:
generation_options = GenerationOptions()

return await self.contextual_generate_from_args(
context=f"character attribute:{attribute_name}",
character=character.name,
@@ -368,7 +383,7 @@ class AssistantMixin:
original=original,
**generation_options.model_dump(),
)

@set_processing
async def generate_character_detail(
self,
@@ -381,11 +396,11 @@ class AssistantMixin:
) -> str:
"""
Wrapper for contextual_generate that generates a character detail.
"""
"""

if not generation_options:
generation_options = GenerationOptions()

return await self.contextual_generate_from_args(
context=f"character detail:{detail_name}",
character=character.name,
@@ -394,7 +409,7 @@ class AssistantMixin:
length=length,
**generation_options.model_dump(),
)

@set_processing
async def generate_thematic_list(
self,
@@ -408,11 +423,11 @@ class AssistantMixin:
"""
if not generation_options:
generation_options = GenerationOptions()

i = 0

result = []

while i < iterations:
i += 1
_result = await self.contextual_generate_from_args(
@@ -420,14 +435,14 @@ class AssistantMixin:
instructions=instructions,
length=length,
original="\n".join(result) if result else None,
extend=i>1,
extend=i > 1,
**generation_options.model_dump(),
)

_result = json.loads(_result)

result = list(set(result + _result))

return result
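A worked, hypothetical run of the loop above with iterations=2:

    # iteration 1 -> ["ruined tower", "salt marsh"]
    # iteration 2 receives original="ruined tower\nsalt marsh" and extend=True
    #             -> ["salt marsh", "glass desert"]
    # list(set(result + _result)) -> ["ruined tower", "salt marsh", "glass desert"]
    # (set() dedupes, so ordering of the final list is not guaranteed)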

@set_processing
@@ -443,34 +458,34 @@ class AssistantMixin:
"""
if not response_length:
response_length = self.autocomplete_dialogue_suggestion_length

# continuing recent character message
non_anchor, anchor = util.split_anchor_text(input, 10)

self.scene.log.debug(
"autocomplete_anchor",
anchor=anchor,
non_anchor=non_anchor,
input=input
"autocomplete_anchor", anchor=anchor, non_anchor=non_anchor, input=input
)

continuing_message = False
message = None

try:
message = self.scene.history[-1]
if isinstance(message, CharacterMessage) and message.character_name == character.name:
if (
isinstance(message, CharacterMessage)
and message.character_name == character.name
):
continuing_message = input.strip() == message.without_name.strip()
except IndexError:
pass

if input.strip().endswith('"'):
prefix = ' *'
elif input.strip().endswith('*'):
prefix = " *"
elif input.strip().endswith("*"):
prefix = ' "'
else:
prefix = ''

prefix = ""

template_vars = {
"scene": self.scene,
"max_tokens": self.client.max_token_length,
@@ -484,7 +499,7 @@ class AssistantMixin:
"non_anchor": non_anchor,
"prefix": prefix,
}

emission = AutocompleteEmission(
agent=self,
input=input,
@@ -492,13 +507,13 @@ class AssistantMixin:
character=character,
template_vars=template_vars,
)

await async_signals.get("agent.creator.autocomplete.before").send(emission)

template_vars["dynamic_instructions"] = emission.dynamic_instructions

response = await Prompt.request(
f"creator.autocomplete-dialogue",
"creator.autocomplete-dialogue",
self.client,
f"create_{response_length}",
vars=template_vars,
@@ -506,21 +521,22 @@ class AssistantMixin:
dedupe_enabled=False,
)

response = response.replace("...", "").lstrip("").rstrip().replace("END-OF-LINE", "")

response = (
response.replace("...", "").lstrip("").rstrip().replace("END-OF-LINE", "")
)

if prefix:
response = prefix + response

emission.response = response

await async_signals.get("agent.creator.autocomplete.after").send(emission)

if not response:
if emit_signal:
emit("autocomplete_suggestion", "")
return ""

response = util.strip_partial_sentences(response).replace("*", "")

if response.startswith(input):
@@ -550,14 +566,14 @@ class AssistantMixin:

# Split the input text into non-anchor and anchor parts
non_anchor, anchor = util.split_anchor_text(input, 10)

self.scene.log.debug(
"autocomplete_narrative_anchor",
"autocomplete_narrative_anchor",
anchor=anchor,
non_anchor=non_anchor,
input=input
input=input,
)

template_vars = {
"scene": self.scene,
"max_tokens": self.client.max_token_length,
@@ -567,20 +583,20 @@ class AssistantMixin:
"anchor": anchor,
"non_anchor": non_anchor,
}

emission = AutocompleteEmission(
agent=self,
input=input,
type="narrative",
template_vars=template_vars,
)

await async_signals.get("agent.creator.autocomplete.before").send(emission)

template_vars["dynamic_instructions"] = emission.dynamic_instructions

response = await Prompt.request(
f"creator.autocomplete-narrative",
"creator.autocomplete-narrative",
self.client,
f"create_{response_length}",
vars=template_vars,
@@ -593,7 +609,7 @@ class AssistantMixin:
response = response[len(input) :]

emission.response = response

await async_signals.get("agent.creator.autocomplete.after").send(emission)

self.scene.log.debug(
@@ -614,78 +630,75 @@ class AssistantMixin:
"""
Allows to fork a new scene from a specific message
in the current scene.

All content after the message will be removed and the
context database will be re imported ensuring a clean state.

All state reinforcements will be reset to their most recent
state before the message.
"""

emit("status", "Creating scene fork ...", status="busy")
try:
if not save_name:
# build a save name
uuid_str = str(uuid.uuid4())[:8]
save_name = f"{uuid_str}-forked"

log.info(f"Forking scene", message_id=message_id, save_name=save_name)

log.info("Forking scene", message_id=message_id, save_name=save_name)

world_state = get_agent("world_state")

# does a message with the given id exist?
index = self.scene.message_index(message_id)
if index is None:
raise ValueError(f"Message with id {message_id} not found.")

# truncate scene.history keeping index as the last element
self.scene.history = self.scene.history[:index + 1]

self.scene.history = self.scene.history[: index + 1]

# truncate scene.archived_history keeping the element where `end` is < `index`
# as the last element
self.scene.archived_history = [
x for x in self.scene.archived_history if "end" not in x or x["end"] < index
x
for x in self.scene.archived_history
if "end" not in x or x["end"] < index
]

# the same needs to be done for layered history
# where each layer is truncated based on what's left in the previous layer
# using similar logic as above (checking `end` vs `index`)
# layer 0 checks archived_history

new_layered_history = []
for layer_number, layer in enumerate(self.scene.layered_history):
if layer_number == 0:
index = len(self.scene.archived_history) - 1
else:
index = len(new_layered_history[layer_number - 1]) - 1

new_layer = [
x for x in layer if x["end"] < index
]

new_layer = [x for x in layer if x["end"] < index]
new_layered_history.append(new_layer)

self.scene.layered_history = new_layered_history
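A small, hypothetical walk-through of the truncation cascade:

    # after truncating scene.history, suppose:
    archived_history = [{"end": 3}, {"end": 9}]        # 2 entries survived
    # layer 0 is filtered against index = len(archived_history) - 1 == 1:
    #   [{"end": 0}, {"end": 5}] -> [{"end": 0}]
    # layer 1 is then filtered against len(new layer 0) - 1 == 0, and so on.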

# save the scene
await self.scene.save(copy_name=save_name)

log.info(f"Scene forked", save_name=save_name)

log.info("Scene forked", save_name=save_name)

# re-emit history
await self.scene.emit_history()

emit("status", f"Updating world state ...", status="busy")

emit("status", "Updating world state ...", status="busy")

# reset state reinforcements
await world_state.update_reinforcements(force = True, reset= True)

await world_state.update_reinforcements(force=True, reset=True)

# update world state
await self.scene.world_state.request_update()

emit("status", f"Scene forked", status="success")
except Exception as e:

emit("status", "Scene forked", status="success")
except Exception:
log.error("Scene fork failed", exc=traceback.format_exc())
emit("status", "Scene fork failed", status="error")

@@ -7,8 +7,6 @@ import structlog
from talemate.agents.base import set_processing
from talemate.prompts import Prompt

import talemate.game.focal as focal

if TYPE_CHECKING:
from talemate.tale_mate import Character

@@ -18,14 +16,13 @@ DEFAULT_CONTENT_CONTEXT = "a fun and engaging adventure aimed at an adult audien

class CharacterCreatorMixin:

@set_processing
async def determine_content_context_for_character(
self,
character: Character,
):
content_context = await Prompt.request(
f"creator.determine-content-context",
"creator.determine-content-context",
self.client,
"create_192",
vars={
@@ -42,7 +39,7 @@ class CharacterCreatorMixin:
information: str = "",
):
instructions = await Prompt.request(
f"creator.determine-character-dialogue-instructions",
"creator.determine-character-dialogue-instructions",
self.client,
"create_concise",
vars={
@@ -63,7 +60,7 @@ class CharacterCreatorMixin:
character: Character,
):
attributes = await Prompt.request(
f"creator.determine-character-attributes",
"creator.determine-character-attributes",
self.client,
"analyze_long",
vars={
@@ -81,7 +78,7 @@ class CharacterCreatorMixin:
instructions: str = "",
) -> str:
name = await Prompt.request(
f"creator.determine-character-name",
"creator.determine-character-name",
self.client,
"analyze_freeform_short",
vars={
@@ -97,14 +94,14 @@ class CharacterCreatorMixin:

@set_processing
async def determine_character_description(
self,
self,
character: Character,
text: str = "",
instructions: str = "",
information: str = "",
):
description = await Prompt.request(
f"creator.determine-character-description",
"creator.determine-character-description",
self.client,
"create",
vars={
@@ -125,7 +122,7 @@ class CharacterCreatorMixin:
goal_instructions: str,
):
goals = await Prompt.request(
f"creator.determine-character-goals",
"creator.determine-character-goals",
self.client,
"create",
vars={
@@ -141,4 +138,4 @@ class CharacterCreatorMixin:
log.debug("determine_character_goals", goals=goals, character=character)
await character.set_detail("goals", goals.strip())

return goals.strip()
return goals.strip()

@@ -1,145 +1,153 @@
import structlog
from typing import ClassVar, TYPE_CHECKING
from typing import ClassVar
from talemate.context import active_scene
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
from talemate.game.engine.nodes.core import (
GraphState,
PropertyField,
UNRESOLVED,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode

if TYPE_CHECKING:
from talemate.tale_mate import Scene

log = structlog.get_logger("talemate.game.engine.nodes.agents.creator")

@register("agents/creator/Settings")
class CreatorSettings(AgentSettingsNode):
"""
Base node to render creator agent settings.
"""
_agent_name:ClassVar[str] = "creator"

_agent_name: ClassVar[str] = "creator"

def __init__(self, title="Creator Settings", **kwargs):
super().__init__(title=title, **kwargs)

@register("agents/creator/DetermineContentContext")
class DetermineContentContext(AgentNode):
"""
Determines the context for the content creation.
"""
_agent_name:ClassVar[str] = "creator"

_agent_name: ClassVar[str] = "creator"

class Fields:
description = PropertyField(
name="description",
description="Description of the context",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)

def __init__(self, title="Determine Content Context", **kwargs):
super().__init__(title=title, **kwargs)

def setup(self):
self.add_input("state")
self.add_input("description", socket_type="str", optional=True)

self.set_property("description", UNRESOLVED)

self.add_output("content_context", socket_type="str")

async def run(self, state: GraphState):
context = await self.agent.determine_content_context_for_description(
self.require_input("description")
)
self.set_output_values({
"content_context": context
})

self.set_output_values({"content_context": context})

@register("agents/creator/DetermineCharacterDescription")
class DetermineCharacterDescription(AgentNode):
"""
Determines the description for a character.

Inputs:

- state: The current state of the graph
- character: The character to determine the description for
- extra_context: Extra context to use in determining the

Outputs:

- description: The determined description
"""
_agent_name:ClassVar[str] = "creator"

_agent_name: ClassVar[str] = "creator"

def __init__(self, title="Determine Character Description", **kwargs):
super().__init__(title=title, **kwargs)

def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_input("extra_context", socket_type="str", optional=True)

self.add_output("description", socket_type="str")

async def run(self, state: GraphState):
character = self.require_input("character")
extra_context = self.get_input_value("extra_context")

if extra_context is UNRESOLVED:
extra_context = ""

description = await self.agent.determine_character_description(character, extra_context)

self.set_output_values({
"description": description
})

description = await self.agent.determine_character_description(
character, extra_context
)

self.set_output_values({"description": description})

@register("agents/creator/DetermineCharacterDialogueInstructions")
class DetermineCharacterDialogueInstructions(AgentNode):
"""
Determines the dialogue instructions for a character.
"""
_agent_name:ClassVar[str] = "creator"

_agent_name: ClassVar[str] = "creator"

class Fields:
instructions = PropertyField(
name="instructions",
description="Any additional instructions to use in determining the dialogue instructions",
type="text",
default=""
default="",
)

def __init__(self, title="Determine Character Dialogue Instructions", **kwargs):
super().__init__(title=title, **kwargs)

def setup(self):
self.add_input("state")
self.add_input("character", socket_type="character")
self.add_input("instructions", socket_type="str", optional=True)

self.set_property("instructions", "")

self.add_output("dialogue_instructions", socket_type="str")

async def run(self, state: GraphState):
character = self.require_input("character")
instructions = self.normalized_input_value("instructions")

dialogue_instructions = await self.agent.determine_character_dialogue_instructions(character, instructions)

self.set_output_values({
"dialogue_instructions": dialogue_instructions
})

dialogue_instructions = (
await self.agent.determine_character_dialogue_instructions(
character, instructions
)
)

self.set_output_values({"dialogue_instructions": dialogue_instructions})

@register("agents/creator/ContextualGenerate")
class ContextualGenerate(AgentNode):
"""
Generates text based on the given context.

Inputs:

- state: The current state of the graph
- context_type: The type of context to use in generating the text
- context_name: The name of the context to use in generating the text
@@ -150,9 +158,9 @@ class ContextualGenerate(AgentNode):
- partial: The partial text to use in generating the text
- uid: The uid to use in generating the text
- generation_options: The generation options to use in generating the text

Properties:

- context_type: The type of context to use in generating the text
- context_name: The name of the context to use in generating the text
- instructions: The instructions to use in generating the text
@@ -161,82 +169,82 @@ class ContextualGenerate(AgentNode):
- uid: The uid to use in generating the text
- context_aware: Whether to use the context in generating the text
- history_aware: Whether to use the history in generating the text

Outputs:

- state: The updated state of the graph
- text: The generated text
"""
_agent_name:ClassVar[str] = "creator"

_agent_name: ClassVar[str] = "creator"

class Fields:
context_type = PropertyField(
name="context_type",
description="The type of context to use in generating the text",
type="str",
choices=[
"character attribute",
"character detail",
"character dialogue",
"scene intro",
"scene intent",
"scene phase intent",
"character attribute",
"character detail",
"character dialogue",
"scene intro",
"scene intent",
"scene phase intent",
"scene type description",
"scene type instructions",
"general",
"list",
"scene",
"scene type instructions",
"general",
"list",
"scene",
"world context",
],
default="general"
default="general",
)
context_name = PropertyField(
name="context_name",
description="The name of the context to use in generating the text",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
instructions = PropertyField(
name="instructions",
description="The instructions to use in generating the text",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
length = PropertyField(
name="length",
description="The length of the text to generate",
type="int",
default=100
default=100,
)
character = PropertyField(
name="character",
description="The character to generate the text for",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
uid = PropertyField(
name="uid",
description="The uid to use in generating the text",
type="str",
default=UNRESOLVED
default=UNRESOLVED,
)
context_aware = PropertyField(
name="context_aware",
description="Whether to use the context in generating the text",
type="bool",
default=True
default=True,
)
history_aware = PropertyField(
name="history_aware",
description="Whether to use the history in generating the text",
type="bool",
default=True
default=True,
)

def __init__(self, title="Contextual Generate", **kwargs):
super().__init__(title=title, **kwargs)

def setup(self):
self.add_input("state")
self.add_input("context_type", socket_type="str", optional=True)
@@ -247,8 +255,10 @@ class ContextualGenerate(AgentNode):
self.add_input("original", socket_type="str", optional=True)
self.add_input("partial", socket_type="str", optional=True)
self.add_input("uid", socket_type="str", optional=True)
self.add_input("generation_options", socket_type="generation_options", optional=True)

self.add_input(
"generation_options", socket_type="generation_options", optional=True
)

self.set_property("context_type", "general")
self.set_property("context_name", UNRESOLVED)
self.set_property("instructions", UNRESOLVED)
@@ -257,10 +267,10 @@ class ContextualGenerate(AgentNode):
self.set_property("uid", UNRESOLVED)
self.set_property("context_aware", True)
self.set_property("history_aware", True)

self.add_output("state")
self.add_output("text", socket_type="str")

async def run(self, state: GraphState):
scene = active_scene.get()
context_type = self.require_input("context_type")
@@ -274,12 +284,12 @@ class ContextualGenerate(AgentNode):
generation_options = self.normalized_input_value("generation_options")
context_aware = self.normalized_input_value("context_aware")
history_aware = self.normalized_input_value("history_aware")

context = f"{context_type}:{context_name}" if context_name else context_type

if isinstance(character, scene.Character):
character = character.name

text = await self.agent.contextual_generate_from_args(
context=context,
instructions=instructions,
@@ -288,31 +298,32 @@ class ContextualGenerate(AgentNode):
original=original,
partial=partial or "",
uid=uid,
writing_style=generation_options.writing_style if generation_options else None,
writing_style=generation_options.writing_style
if generation_options
else None,
spices=generation_options.spices if generation_options else None,
spice_level=generation_options.spice_level if generation_options else 0.0,
context_aware=context_aware,
history_aware=history_aware
history_aware=history_aware,
)

self.set_output_values({
"state": state,
"text": text
})

self.set_output_values({"state": state, "text": text})
|
||||
|
||||
|
||||
@register("agents/creator/GenerateThematicList")
|
||||
class GenerateThematicList(AgentNode):
|
||||
"""
|
||||
Generates a list of thematic items based on the instructions.
|
||||
"""
|
||||
_agent_name:ClassVar[str] = "creator"
|
||||
|
||||
|
||||
_agent_name: ClassVar[str] = "creator"
|
||||
|
||||
class Fields:
|
||||
instructions = PropertyField(
|
||||
name="instructions",
|
||||
description="The instructions to use in generating the list",
|
||||
type="str",
|
||||
default=UNRESOLVED
|
||||
default=UNRESOLVED,
|
||||
)
|
||||
iterations = PropertyField(
|
||||
name="iterations",
|
||||
@@ -323,27 +334,24 @@ class GenerateThematicList(AgentNode):
|
||||
min=1,
|
||||
max=10,
|
||||
)
|
||||
|
||||
|
||||
def __init__(self, title="Generate Thematic List", **kwargs):
|
||||
super().__init__(title=title, **kwargs)
|
||||
|
||||
|
||||
def setup(self):
|
||||
self.add_input("state")
|
||||
self.add_input("instructions", socket_type="str", optional=True)
|
||||
|
||||
self.set_property("instructions", UNRESOLVED)
|
||||
self.set_property("iterations", 1)
|
||||
|
||||
|
||||
self.add_output("state")
|
||||
self.add_output("list", socket_type="list")
|
||||
|
||||
|
||||
async def run(self, state: GraphState):
|
||||
instructions = self.normalized_input_value("instructions")
|
||||
iterations = self.require_number_input("iterations")
|
||||
|
||||
|
||||
list = await self.agent.generate_thematic_list(instructions, iterations)
|
||||
|
||||
self.set_output_values({
|
||||
"state": state,
|
||||
"list": list
|
||||
})
|
||||
|
||||
self.set_output_values({"state": state, "list": list})
|
||||
|
||||
@@ -10,7 +10,7 @@ class ScenarioCreatorMixin:
@set_processing
async def determine_scenario_description(self, text: str):
description = await Prompt.request(
f"creator.determine-scenario-description",
"creator.determine-scenario-description",
self.client,
"analyze_long",
vars={
@@ -25,7 +25,7 @@ class ScenarioCreatorMixin:
description: str,
):
content_context = await Prompt.request(
f"creator.determine-content-context",
"creator.determine-content-context",
self.client,
"create_short",
vars={

@@ -1,16 +1,12 @@
from __future__ import annotations

import random
from typing import TYPE_CHECKING, List

import structlog
import traceback

import talemate.emit.async_signals
import talemate.instance as instance
from talemate.agents.conversation import ConversationAgentEmission
from talemate.emit import emit
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage
from talemate.util import random_color
from talemate.character import deactivate_character
@@ -27,7 +23,7 @@ from .legacy_scene_instructions import LegacySceneInstructionsMixin
from .auto_direct import AutoDirectMixin
from .websocket_handler import DirectorWebsocketHandler

import talemate.agents.director.nodes
import talemate.agents.director.nodes # noqa: F401

if TYPE_CHECKING:
from talemate import Character, Scene
@@ -42,7 +38,7 @@ class DirectorAgent(
GenerateChoicesMixin,
AutoDirectMixin,
LegacySceneInstructionsMixin,
Agent
Agent,
):
agent_type = "director"
verbose_name = "Director"
@@ -74,7 +70,7 @@ class DirectorAgent(
],
),
},
),
),
}
MemoryRAGMixin.add_actions(actions)
GenerateChoicesMixin.add_actions(actions)
@@ -112,7 +108,6 @@ class DirectorAgent(
created_characters = []

for character_name in self.scene.world_state.characters.keys():

if exclude and character_name.lower() in exclude:
continue

@@ -148,13 +143,12 @@ class DirectorAgent(
memory = instance.get_agent("memory")
scene: "Scene" = self.scene
any_attribute_templates = False

loading_status = LoadingStatus(max_steps=None, cancellable=True)

# Start of character creation
log.debug("persist_character", name=name)

# Determine the character's name (or clarify if it's already set)
if determine_name:
loading_status("Determining character name")
@@ -162,7 +156,7 @@ class DirectorAgent(
log.debug("persist_character", adjusted_name=name)

# Create the blank character
character:Character = self.scene.Character(name=name)
character: Character = self.scene.Character(name=name)

# Add the character to the scene
character.color = random_color()
@@ -170,32 +164,44 @@ class DirectorAgent(
character=character, agent=instance.get_agent("conversation")
)
await self.scene.add_actor(actor)

try:

try:
# Apply any character generation templates
if templates:
loading_status("Applying character generation templates")
templates = scene.world_state_manager.template_collection.collect_all(templates)
templates = scene.world_state_manager.template_collection.collect_all(
templates
)
log.debug("persist_character", applying_templates=templates)
await scene.world_state_manager.apply_templates(
templates.values(),
templates.values(),
character_name=character.name,
information=content
information=content,
)

# if any of the templates are attribute templates, then we no longer need to
# generate a character sheet
any_attribute_templates = any(template.template_type == "character_attribute" for template in templates.values())
log.debug("persist_character", any_attribute_templates=any_attribute_templates)

if any_attribute_templates and augment_attributes and generate_attributes:
log.debug("persist_character", augmenting_attributes=augment_attributes)
any_attribute_templates = any(
template.template_type == "character_attribute"
for template in templates.values()
)
log.debug(
"persist_character", any_attribute_templates=any_attribute_templates
)

if (
any_attribute_templates
and augment_attributes
and generate_attributes
):
log.debug(
"persist_character", augmenting_attributes=augment_attributes
)
loading_status("Augmenting character attributes")
additional_attributes = await world_state.extract_character_sheet(
name=name,
text=content,
augmentation_instructions=augment_attributes
augmentation_instructions=augment_attributes,
)
character.base_attributes.update(additional_attributes)

@@ -212,25 +218,26 @@ class DirectorAgent(

log.debug("persist_character", attributes=attributes)
character.base_attributes = attributes

# Generate a description for the character
if not description:
loading_status("Generating character description")
description = await creator.determine_character_description(character, information=content)
description = await creator.determine_character_description(
character, information=content
)
character.description = description
log.debug("persist_character", description=description)

# Generate a dialogue instructions for the character
loading_status("Generating acting instructions")
dialogue_instructions = await creator.determine_character_dialogue_instructions(
character,
information=content
dialogue_instructions = (
await creator.determine_character_dialogue_instructions(
character, information=content
)
)
character.dialogue_instructions = dialogue_instructions
log.debug(
"persist_character", dialogue_instructions=dialogue_instructions
)

log.debug("persist_character", dialogue_instructions=dialogue_instructions)

# Narrate the character's entry if the option is selected
if active and narrate_entry:
loading_status("Narrating character entry")
@@ -240,24 +247,26 @@ class DirectorAgent(
"narrate_character_entry",
emit_message=True,
character=character,
narrative_direction=narrate_entry_direction
narrative_direction=narrate_entry_direction,
)

# Deactivate the character if not active
if not active:
await deactivate_character(scene, character)

# Commit the character's details to long term memory
await character.commit_to_memory(memory)
self.scene.emit_status()
self.scene.world_state.emit()

loading_status.done(message=f"{character.name} added to scene", status="success")

loading_status.done(
message=f"{character.name} added to scene", status="success"
)
return character
except GenerationCancelled:
loading_status.done(message="Character creation cancelled", status="idle")
await scene.remove_actor(actor)
except Exception as e:
except Exception:
loading_status.done(message="Character creation failed", status="error")
await scene.remove_actor(actor)
log.error("Error persisting character", error=traceback.format_exc())
@@ -276,4 +285,4 @@ class DirectorAgent(
def allow_repetition_break(
self, kind: str, agent_function_name: str, auto: bool = False
):
return False
return False

@@ -1,23 +1,22 @@
from typing import TYPE_CHECKING
import structlog
import pydantic
from talemate.agents.base import (
set_processing,
AgentAction,
AgentActionConfig,
AgentEmission,
AgentTemplateEmission,
)
from talemate.status import set_loading
import talemate.game.focal as focal
from talemate.prompts import Prompt
import talemate.emit.async_signals as async_signals
from talemate.scene_message import CharacterMessage, TimePassageMessage, DirectorMessage, NarratorMessage
from talemate.scene.schema import ScenePhase, SceneType, SceneIntent
from talemate.scene_message import (
CharacterMessage,
TimePassageMessage,
NarratorMessage,
)
from talemate.scene.schema import ScenePhase, SceneType
from talemate.scene.intent import set_scene_phase
from talemate.world_state.manager import WorldStateManager
from talemate.world_state.templates.scene import SceneType as TemplateSceneType
import talemate.agents.director.auto_direct_nodes
import talemate.agents.director.auto_direct_nodes # noqa: F401
from talemate.world_state.templates.base import TypedCollection

if TYPE_CHECKING:
@@ -26,15 +25,15 @@ if TYPE_CHECKING:
log = structlog.get_logger("talemate.agents.conversation.direct")

#talemate.emit.async_signals.register(
#)
# talemate.emit.async_signals.register(
# )

class AutoDirectMixin:

"""
Director agent mixin for automatic scene direction.
"""

@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["auto_direct"] = AgentAction(
@@ -108,113 +107,117 @@ class AutoDirectMixin:
),
},
)

# config property helpers

@property
def auto_direct_enabled(self) -> bool:
return self.actions["auto_direct"].enabled

@property
def auto_direct_max_auto_turns(self) -> int:
return self.actions["auto_direct"].config["max_auto_turns"].value

@property
def auto_direct_max_idle_turns(self) -> int:
return self.actions["auto_direct"].config["max_idle_turns"].value

@property
def auto_direct_max_repeat_turns(self) -> int:
return self.actions["auto_direct"].config["max_repeat_turns"].value

@property
def auto_direct_instruct_actors(self) -> bool:
return self.actions["auto_direct"].config["instruct_actors"].value

@property
def auto_direct_instruct_narrator(self) -> bool:
return self.actions["auto_direct"].config["instruct_narrator"].value

@property
def auto_direct_instruct_frequency(self) -> int:
return self.actions["auto_direct"].config["instruct_frequency"].value

@property
def auto_direct_evaluate_scene_intention(self) -> int:
return self.actions["auto_direct"].config["evaluate_scene_intention"].value

@property
def auto_direct_instruct_any(self) -> bool:
"""
Will check whether actor or narrator instructions are enabled.

For narrator instructions to be enabled instruct_narrator needs to be enabled as well.

Returns:
bool: True if either actor or narrator instructions are enabled.
"""

return self.auto_direct_instruct_actors or self.auto_direct_instruct_narrator

# signal connect

def connect(self, scene):
super().connect(scene)
async_signals.get("game_loop").connect(self.on_game_loop)

async def on_game_loop(self, event):
if not self.auto_direct_enabled:
return

if self.auto_direct_evaluate_scene_intention > 0:
evaluation_due = self.get_scene_state("evaluated_scene_intention", 0)
if evaluation_due == 0:
await self.auto_direct_set_scene_intent()
self.set_scene_states(evaluated_scene_intention=self.auto_direct_evaluate_scene_intention)
self.set_scene_states(
evaluated_scene_intention=self.auto_direct_evaluate_scene_intention
)
else:
self.set_scene_states(evaluated_scene_intention=evaluation_due - 1)

# helpers

def auto_direct_is_due_for_instruction(self, actor_name:str) -> bool:

def auto_direct_is_due_for_instruction(self, actor_name: str) -> bool:
"""
Check if the actor is due for instruction.
"""

if self.auto_direct_instruct_frequency == 1:
return True

messages_since_last_instruction = 0

def count_messages(message):
nonlocal messages_since_last_instruction
if message.typ in ["character", "narrator"]:
messages_since_last_instruction += 1

last_instruction = self.scene.last_message_of_type(
"director",
character_name=actor_name,
max_iterations=25,
on_iterate=count_messages,
)

log.debug("auto_direct_is_due_for_instruction", messages_since_last_instruction=messages_since_last_instruction, last_instruction=last_instruction.id if last_instruction else None)

log.debug(
"auto_direct_is_due_for_instruction",
messages_since_last_instruction=messages_since_last_instruction,
last_instruction=last_instruction.id if last_instruction else None,
)

if not last_instruction:
return True

return messages_since_last_instruction >= self.auto_direct_instruct_frequency

def auto_direct_candidates(self) -> list["Character"]:
"""
Returns a list of characters who are valid candidates to speak next.
based on the max_idle_turns, max_repeat_turns, and the most recent character.
"""

scene:"Scene" = self.scene

scene: "Scene" = self.scene

most_recent_character = None
repeat_count = 0
last_player_turn = None
@@ -223,89 +226,105 @@ class AutoDirectMixin:
active_charcters = list(scene.characters)
active_character_names = [character.name for character in active_charcters]
instruct_narrator = self.auto_direct_instruct_narrator

# if there is only one character then they are the only candidate
if len(active_charcters) == 1:
return active_charcters

BACKLOG_LIMIT = 50

player_character_active = scene.player_character_exists

# check the last BACKLOG_LIMIT entries in the scene history and collect into
# a dictionary of character names and the number of turns since they last spoke.

len_history = len(scene.history) - 1
num = 0
for idx in range(len_history, -1, -1):
message = scene.history[idx]
turns = len_history - idx

num += 1

if num > BACKLOG_LIMIT:
break

if isinstance(message, TimePassageMessage):
break

if not isinstance(message, (CharacterMessage, NarratorMessage)):
continue

# if character message but character is not in the active characters list then skip
if isinstance(message, CharacterMessage) and message.character_name not in active_character_names:
if (
isinstance(message, CharacterMessage)
and message.character_name not in active_character_names
):
continue

if isinstance(message, NarratorMessage):
if not instruct_narrator:
continue
character = scene.narrator_character_object
else:
character = scene.get_character(message.character_name)

if not character:
continue

if character.is_player and last_player_turn is None:
last_player_turn = turns
elif not character.is_player and last_player_turn is None:
consecutive_auto_turns += 1

if not most_recent_character:
most_recent_character = character
repeat_count += 1
elif character == most_recent_character:
repeat_count += 1

if character.name not in candidates:
candidates[character.name] = turns

# add any characters that have not spoken yet
for character in active_charcters:
if character.name not in candidates:
candidates[character.name] = 0

# explicitly add narrator if enabled and not already in candidates
if instruct_narrator and scene.narrator_character_object:
narrator = scene.narrator_character_object
if narrator.name not in candidates:
candidates[narrator.name] = 0

log.debug(f"auto_direct_candidates: {candidates}", most_recent_character=most_recent_character, repeat_count=repeat_count, last_player_turn=last_player_turn, consecutive_auto_turns=consecutive_auto_turns)

log.debug(
f"auto_direct_candidates: {candidates}",
most_recent_character=most_recent_character,
repeat_count=repeat_count,
last_player_turn=last_player_turn,
consecutive_auto_turns=consecutive_auto_turns,
)

if not most_recent_character:
log.debug("auto_direct_candidates: No most recent character found.")
return list(scene.characters)

# if player has not spoken in a while then they are favored
if player_character_active and consecutive_auto_turns >= self.auto_direct_max_auto_turns:
log.debug("auto_direct_candidates: User controlled character has not spoken in a while.")
if (
player_character_active
and consecutive_auto_turns >= self.auto_direct_max_auto_turns
):
log.debug(
"auto_direct_candidates: User controlled character has not spoken in a while."
)
return [scene.get_player_character()]

# check if most recent character has spoken too many times in a row
# if so then remove from candidates
if repeat_count >= self.auto_direct_max_repeat_turns:
log.debug("auto_direct_candidates: Most recent character has spoken too many times in a row.", most_recent_character=most_recent_character
|
||||
log.debug(
"auto_direct_candidates: Most recent character has spoken too many times in a row.",
most_recent_character=most_recent_character,
)
candidates.pop(most_recent_character.name, None)

@@ -314,27 +333,34 @@ class AutoDirectMixin:
favored_candidates = []
for name, turns in candidates.items():
if turns > self.auto_direct_max_idle_turns:
log.debug("auto_direct_candidates: Character has gone too long without speaking.", character_name=name, turns=turns)
log.debug(
"auto_direct_candidates: Character has gone too long without speaking.",
character_name=name,
turns=turns,
)
favored_candidates.append(scene.get_character(name))

if favored_candidates:
return favored_candidates

return [scene.get_character(character_name) for character_name in candidates.keys()]

return [
scene.get_character(character_name) for character_name in candidates.keys()
]

# actions

@set_processing
async def auto_direct_set_scene_intent(self, require:bool=False) -> ScenePhase | None:

async def set_scene_intention(type:str, intention:str) -> ScenePhase:
async def auto_direct_set_scene_intent(
self, require: bool = False
) -> ScenePhase | None:
async def set_scene_intention(type: str, intention: str) -> ScenePhase:
await set_scene_phase(self.scene, type, intention)
self.scene.emit_status()
return self.scene.intent_state.phase

async def do_nothing(*args, **kwargs) -> None:
return None

focal_handler = focal.Focal(
self.client,
callbacks=[
@@ -355,57 +381,63 @@ class AutoDirectMixin:
],
max_calls=1,
scene=self.scene,
scene_type_ids=", ".join([f'"{scene_type.id}"' for scene_type in self.scene.intent_state.scene_types.values()]),
scene_type_ids=", ".join(
[
f'"{scene_type.id}"'
for scene_type in self.scene.intent_state.scene_types.values()
]
),
retries=1,
require=require,
)

await focal_handler.request(
"director.direct-determine-scene-intent",
)

return self.scene.intent_state.phase

@set_processing
async def auto_direct_generate_scene_types(
self,
instructions:str,
max_scene_types:int=1,
self,
instructions: str,
max_scene_types: int = 1,
):

world_state_manager:WorldStateManager = self.scene.world_state_manager

scene_type_templates:TypedCollection = await world_state_manager.get_templates(types=["scene_type"])

async def add_from_template(id:str) -> SceneType:
template:TemplateSceneType | None = scene_type_templates.find_by_name(id)
world_state_manager: WorldStateManager = self.scene.world_state_manager

scene_type_templates: TypedCollection = await world_state_manager.get_templates(
types=["scene_type"]
)

async def add_from_template(id: str) -> SceneType:
template: TemplateSceneType | None = scene_type_templates.find_by_name(id)
if not template:
log.warning("auto_direct_generate_scene_types: Template not found.", name=id)
log.warning(
"auto_direct_generate_scene_types: Template not found.", name=id
)
return None
return template.apply_to_scene(self.scene)

async def generate_scene_type(
id:str = None,
name:str = None,
description:str = None,
instructions:str = None,
id: str = None,
name: str = None,
description: str = None,
instructions: str = None,
) -> SceneType:

if not id or not name:
return None

scene_type = SceneType(
id=id,
name=name,
description=description,
instructions=instructions,
)

self.scene.intent_state.scene_types[id] = scene_type

return scene_type

focal_handler = focal.Focal(
self.client,
callbacks=[
@@ -435,7 +467,7 @@ class AutoDirectMixin:
instructions=instructions,
scene_type_templates=scene_type_templates.templates,
)

await focal_handler.request(
"director.generate-scene-types",
)
)

@@ -1,17 +1,13 @@
import structlog
from typing import TYPE_CHECKING, ClassVar
from typing import ClassVar
from talemate.game.engine.nodes.core import GraphState, PropertyField
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentNode
from talemate.scene.schema import ScenePhase
from talemate.context import active_scene

if TYPE_CHECKING:
from talemate.tale_mate import Scene
from talemate.agents.director import DirectorAgent

log = structlog.get_logger("talemate.game.engine.nodes.agents.director")

@register("agents/director/auto-direct/Candidates")
class AutoDirectCandidates(AgentNode):
"""
@@ -19,52 +15,51 @@ class AutoDirectCandidates(AgentNode):
next action, based on the director's auto-direct settings and
the recent scene history.
"""
_agent_name:ClassVar[str] = "director"

_agent_name: ClassVar[str] = "director"

def __init__(self, title="Auto Direct Candidates", **kwargs):
super().__init__(title=title, **kwargs)

def setup(self):
self.add_input("state")
self.add_output("characters", socket_type="list")

async def run(self, state: GraphState):
candidates = self.agent.auto_direct_candidates()
self.set_output_values({
"characters": candidates
})

self.set_output_values({"characters": candidates})

@register("agents/director/auto-direct/DetermineSceneIntent")
class DetermineSceneIntent(AgentNode):
"""
Determines the scene intent based on the current scene state.
"""
_agent_name:ClassVar[str] = "director"

_agent_name: ClassVar[str] = "director"

def __init__(self, title="Determine Scene Intent", **kwargs):
super().__init__(title=title, **kwargs)

def setup(self):
self.add_input("state")
self.add_output("state")
self.add_output("scene_phase", socket_type="scene_intent/scene_phase")

async def run(self, state: GraphState):
phase:ScenePhase = await self.agent.auto_direct_set_scene_intent()

self.set_output_values({
"state": state,
"scene_phase": phase
})
phase: ScenePhase = await self.agent.auto_direct_set_scene_intent()

self.set_output_values({"state": state, "scene_phase": phase})

@register("agents/director/auto-direct/GenerateSceneTypes")
class GenerateSceneTypes(AgentNode):
"""
Generates scene types based on the current scene state.
"""
_agent_name:ClassVar[str] = "director"

_agent_name: ClassVar[str] = "director"

class Fields:
instructions = PropertyField(
name="instructions",
@@ -72,17 +67,17 @@ class GenerateSceneTypes(AgentNode):
description="The instructions for the scene types",
default="",
)

max_scene_types = PropertyField(
name="max_scene_types",
type="int",
description="The maximum number of scene types to generate",
default=1,
)

def __init__(self, title="Generate Scene Types", **kwargs):
super().__init__(title=title, **kwargs)

def setup(self):
self.add_input("state")
self.add_input("instructions", socket_type="str", optional=True)
@@ -90,29 +85,26 @@ class GenerateSceneTypes(AgentNode):
self.set_property("instructions", "")
self.set_property("max_scene_types", 1)
self.add_output("state")

async def run(self, state: GraphState):
instructions = self.normalized_input_value("instructions")
max_scene_types = self.normalized_input_value("max_scene_types")

scene_types = await self.agent.auto_direct_generate_scene_types(
instructions=instructions,
max_scene_types=max_scene_types
instructions=instructions, max_scene_types=max_scene_types
)

self.set_output_values({
"state": state,
"scene_types": scene_types
})

self.set_output_values({"state": state, "scene_types": scene_types})

@register("agents/director/auto-direct/IsDueForInstruction")
class IsDueForInstruction(AgentNode):
"""
Checks if the actor is due for instruction based on the auto-direct settings.
"""
_agent_name:ClassVar[str] = "director"

_agent_name: ClassVar[str] = "director"

class Fields:
actor_name = PropertyField(
name="actor_name",
@@ -120,24 +112,21 @@ class IsDueForInstruction(AgentNode):
description="The name of the actor to check instruction timing for",
default="",
)

def __init__(self, title="Is Due For Instruction", **kwargs):
super().__init__(title=title, **kwargs)

def setup(self):
self.add_input("actor_name", socket_type="str")

self.set_property("actor_name", "")

self.add_output("is_due", socket_type="bool")
self.add_output("actor_name", socket_type="str")

async def run(self, state: GraphState):
actor_name = self.require_input("actor_name")

is_due = self.agent.auto_direct_is_due_for_instruction(actor_name)

self.set_output_values({
"is_due": is_due,
"actor_name": actor_name
})

self.set_output_values({"is_due": is_due, "actor_name": actor_name})

@@ -1,14 +1,12 @@
from typing import TYPE_CHECKING
import random
import structlog
from functools import wraps
import dataclasses
from talemate.agents.base import (
set_processing,
AgentAction,
AgentActionConfig,
AgentTemplateEmission,
DynamicInstruction,
)
from talemate.events import GameLoopStartEvent
from talemate.scene_message import NarratorMessage, CharacterMessage
@@ -24,7 +22,7 @@ __all__ = [
log = structlog.get_logger()

talemate.emit.async_signals.register(
"agent.director.generate_choices.before_generate",
"agent.director.generate_choices.before_generate",
"agent.director.generate_choices.inject_instructions",
"agent.director.generate_choices.generated",
)
@@ -32,18 +30,19 @@ talemate.emit.async_signals.register(
if TYPE_CHECKING:
from talemate.tale_mate import Character

@dataclasses.dataclass
class GenerateChoicesEmission(AgentTemplateEmission):
character: "Character | None" = None
choices: list[str] = dataclasses.field(default_factory=list)

class GenerateChoicesMixin:

"""
Director agent mixin that provides functionality for automatically guiding
the actors or the narrator during the scene progression.
"""

@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["_generate_choices"] = AgentAction(
@@ -65,7 +64,6 @@ class GenerateChoicesMixin:
max=1,
step=0.1,
),

"num_choices": AgentActionConfig(
type="number",
label="Number of Actions",
@@ -75,33 +73,31 @@ class GenerateChoicesMixin:
max=10,
step=1,
),

"never_auto_progress": AgentActionConfig(
type="bool",
label="Never Auto Progress on Action Selection",
description="If enabled, the scene will not auto progress after you select an action.",
value=False,
),

"instructions": AgentActionConfig(
type="blob",
label="Instructions",
description="Provide some instructions to the director for generating actions.",
value="",
),
}
},
)

# config property helpers

@property
def generate_choices_enabled(self):
return self.actions["_generate_choices"].enabled

@property
def generate_choices_chance(self):
return self.actions["_generate_choices"].config["chance"].value

@property
def generate_choices_num_choices(self):
return self.actions["_generate_choices"].config["num_choices"].value
@@ -115,24 +111,25 @@ class GenerateChoicesMixin:
return self.actions["_generate_choices"].config["instructions"].value

# signal connect

def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("player_turn_start").connect(self.on_player_turn_start)

talemate.emit.async_signals.get("player_turn_start").connect(
self.on_player_turn_start
)

async def on_player_turn_start(self, event: GameLoopStartEvent):
if not self.enabled:
return

if self.generate_choices_enabled:

# look backwards through history and abort if we encounter
# a character message with source "player" before either
# a character message with a different source or a narrator message
#
# this is so choices aren't generated when the player message was
# the most recent content in the scene

for i in range(len(self.scene.history) - 1, -1, -1):
message = self.scene.history[i]
if isinstance(message, NarratorMessage):
@@ -141,12 +138,11 @@ class GenerateChoicesMixin:
if message.source == "player":
return
break

if random.random() < self.generate_choices_chance:
await self.generate_choices()

await self.generate_choices()

# methods

@set_processing
async def generate_choices(
@@ -154,20 +150,23 @@ class GenerateChoicesMixin:
instructions: str = None,
character: "Character | str | None" = None,
):

emission: GenerateChoicesEmission = GenerateChoicesEmission(agent=self)

if isinstance(character, str):
character = self.scene.get_character(character)

if not character:
character = self.scene.get_player_character()

emission.character = character

await talemate.emit.async_signals.get("agent.director.generate_choices.before_generate").send(emission)
await talemate.emit.async_signals.get("agent.director.generate_choices.inject_instructions").send(emission)

await talemate.emit.async_signals.get(
"agent.director.generate_choices.before_generate"
).send(emission)
await talemate.emit.async_signals.get(
"agent.director.generate_choices.inject_instructions"
).send(emission)

response = await Prompt.request(
"director.generate-choices",
self.client,
@@ -178,7 +177,9 @@ class GenerateChoicesMixin:
"character": character,
"num_choices": self.generate_choices_num_choices,
"instructions": instructions or self.generate_choices_instructions,
"dynamic_instructions": emission.dynamic_instructions if emission else None,
"dynamic_instructions": emission.dynamic_instructions
if emission
else None,
},
)

@@ -187,10 +188,10 @@ class GenerateChoicesMixin:
choices = util.extract_list(choice_text)
# strip quotes
choices = [choice.strip().strip('"') for choice in choices]

# limit to num_choices
choices = choices[:self.generate_choices_num_choices]

choices = choices[: self.generate_choices_num_choices]

except Exception as e:
log.error("generate_choices failed", error=str(e), response=response)
return
@@ -198,15 +199,17 @@ class GenerateChoicesMixin:
emit(
"player_choice",
response,
data = {
data={
"choices": choices,
"character": character.name,
},
websocket_passthrough=True
websocket_passthrough=True,
)

emission.response = response
emission.choices = choices
await talemate.emit.async_signals.get("agent.director.generate_choices.generated").send(emission)

return emission.response
await talemate.emit.async_signals.get(
"agent.director.generate_choices.generated"
).send(emission)

return emission.response

@@ -17,12 +17,11 @@ from talemate.util import strip_partial_sentences
if TYPE_CHECKING:
from talemate.tale_mate import Character
from talemate.agents.summarize.analyze_scene import SceneAnalysisEmission
from talemate.agents.editor.revision import RevisionAnalysisEmission

log = structlog.get_logger()

talemate.emit.async_signals.register(
"agent.director.guide.before_generate",
"agent.director.guide.before_generate",
"agent.director.guide.inject_instructions",
"agent.director.guide.generated",
)
@@ -32,6 +31,7 @@ talemate.emit.async_signals.register(
class DirectorGuidanceEmission(AgentTemplateEmission):
pass

def set_processing(fn):
"""
Custom decorator that emits the agent status as processing while the function
@@ -42,29 +42,33 @@ def set_processing(fn):
@wraps(fn)
async def wrapper(self, *args, **kwargs):
emission: DirectorGuidanceEmission = DirectorGuidanceEmission(agent=self)

await talemate.emit.async_signals.get("agent.director.guide.before_generate").send(emission)
await talemate.emit.async_signals.get("agent.director.guide.inject_instructions").send(emission)

await talemate.emit.async_signals.get(
"agent.director.guide.before_generate"
).send(emission)
await talemate.emit.async_signals.get(
"agent.director.guide.inject_instructions"
).send(emission)

agent_context = active_agent.get()
agent_context.state["dynamic_instructions"] = emission.dynamic_instructions

response = await fn(self, *args, **kwargs)
emission.response = response
await talemate.emit.async_signals.get("agent.director.guide.generated").send(emission)
await talemate.emit.async_signals.get("agent.director.guide.generated").send(
emission
)
return emission.response

return wrapper

class GuideSceneMixin:

"""
Director agent mixin that provides functionality for automatically guiding
the actors or the narrator during the scene progression.
"""

@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["guide_scene"] = AgentAction(
@@ -81,13 +85,13 @@ class GuideSceneMixin:
type="bool",
label="Guide actors",
description="Guide the actors in the scene. This happens during every actor turn.",
value=True
value=True,
),
"guide_narrator": AgentActionConfig(
type="bool",
label="Guide narrator",
description="Guide the narrator during the scene. This happens during the narrator's turn.",
value=True
value=True,
),
"guidance_length": AgentActionConfig(
type="text",
@@ -101,7 +105,7 @@ class GuideSceneMixin:
{"label": "Medium (512)", "value": "512"},
{"label": "Medium Long (768)", "value": "768"},
{"label": "Long (1024)", "value": "1024"},
]
],
),
"cache_guidance": AgentActionConfig(
type="bool",
@@ -109,57 +113,57 @@ class GuideSceneMixin:
description="Will not regenerate the guidance until the scene moves forward or the analysis changes.",
value=False,
quick_toggle=True,
)
}
),
},
)

# config property helpers

@property
def guide_scene(self) -> bool:
return self.actions["guide_scene"].enabled

@property
def guide_actors(self) -> bool:
return self.actions["guide_scene"].config["guide_actors"].value

@property
def guide_narrator(self) -> bool:
return self.actions["guide_scene"].config["guide_narrator"].value

@property
def guide_scene_guidance_length(self) -> int:
return int(self.actions["guide_scene"].config["guidance_length"].value)

@property
def guide_scene_cache_guidance(self) -> bool:
return self.actions["guide_scene"].config["cache_guidance"].value

# signal connect

def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.summarization.scene_analysis.after").connect(
self.on_summarization_scene_analysis_after
)
talemate.emit.async_signals.get("agent.summarization.scene_analysis.cached").connect(
self.on_summarization_scene_analysis_after
)
talemate.emit.async_signals.get("agent.editor.revision-analysis.before").connect(
self.on_editor_revision_analysis_before
)

async def on_summarization_scene_analysis_after(self, emission: "SceneAnalysisEmission"):

talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.after"
).connect(self.on_summarization_scene_analysis_after)
talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.cached"
).connect(self.on_summarization_scene_analysis_after)
talemate.emit.async_signals.get(
"agent.editor.revision-analysis.before"
).connect(self.on_editor_revision_analysis_before)

async def on_summarization_scene_analysis_after(
self, emission: "SceneAnalysisEmission"
):
if not self.guide_scene:
return

guidance = None

cached_guidance = await self.get_cached_guidance(emission.response)

if emission.analysis_type == "narration" and self.guide_narrator:

if cached_guidance:
guidance = cached_guidance
else:
@@ -167,15 +171,14 @@ class GuideSceneMixin:
emission.response,
response_length=self.guide_scene_guidance_length,
)

if not guidance:
log.warning("director.guide_scene.narration: Empty resonse")
|
||||
return

self.set_context_states(narrator_guidance=guidance)

elif emission.analysis_type == "conversation" and self.guide_actors:

elif emission.analysis_type == "conversation" and self.guide_actors:
if cached_guidance:
guidance = cached_guidance
else:
@@ -184,94 +187,110 @@ class GuideSceneMixin:
emission.template_vars.get("character"),
response_length=self.guide_scene_guidance_length,
)

if not guidance:
log.warning("director.guide_scene.conversation: Empty resonse")
|
||||
return

self.set_context_states(actor_guidance=guidance)

if guidance:
await self.set_cached_guidance(
emission.response,
guidance,
emission.analysis_type,
emission.template_vars.get("character")
emission.response,
guidance,
emission.analysis_type,
emission.template_vars.get("character"),
)

async def on_editor_revision_analysis_before(self, emission: AgentTemplateEmission):
cached_guidance = await self.get_cached_guidance(emission.response)
if cached_guidance:
emission.dynamic_instructions.append(DynamicInstruction(
title="Guidance",
content=cached_guidance
))

emission.dynamic_instructions.append(
DynamicInstruction(title="Guidance", content=cached_guidance)
)

# helpers

def _cache_key(self) -> str:
return f"cached_guidance"

async def get_cached_guidance(self, analysis:str | None = None) -> str | None:
return "cached_guidance"

async def get_cached_guidance(self, analysis: str | None = None) -> str | None:
"""
Returns the cached guidance for the given analysis.

If analysis is not provided, it will return the cached guidance for the last analysis regardless
of the fingerprint.
"""

if not self.guide_scene_cache_guidance:
return None

key = self._cache_key()
cached_guidance = self.get_scene_state(key)

if cached_guidance:
if not analysis:
return cached_guidance.get("guidance")
elif cached_guidance.get("fp") == self.context_fingerpint(extra=[analysis]):
return cached_guidance.get("guidance")

return None

async def set_cached_guidance(self, analysis:str, guidance: str, analysis_type: str, character: "Character | None" = None):

async def set_cached_guidance(
self,
analysis: str,
guidance: str,
analysis_type: str,
character: "Character | None" = None,
):
"""
Sets the cached guidance for the given analysis.
"""
key = self._cache_key()
self.set_scene_states(**{
key: {
"fp": self.context_fingerpint(extra=[analysis]),
"guidance": guidance,
"analysis_type": analysis_type,
"character": character.name if character else None,
self.set_scene_states(
**{
key: {
"fp": self.context_fingerpint(extra=[analysis]),
"guidance": guidance,
"analysis_type": analysis_type,
"character": character.name if character else None,
}
}
})

)

async def get_cached_character_guidance(self, character_name: str) -> str | None:
"""
Returns the cached guidance for the given character.
"""
key = self._cache_key()
cached_guidance = self.get_scene_state(key)

if not cached_guidance:
return None

if cached_guidance.get("character") == character_name and cached_guidance.get("analysis_type") == "conversation":

if (
cached_guidance.get("character") == character_name
and cached_guidance.get("analysis_type") == "conversation"
):
return cached_guidance.get("guidance")

return None

# methods

@set_processing
async def guide_actor_off_of_scene_analysis(self, analysis: str, character: "Character", response_length: int = 256):
async def guide_actor_off_of_scene_analysis(
self, analysis: str, character: "Character", response_length: int = 256
):
"""
Guides the actor based on the scene analysis.
"""

log.debug("director.guide_actor_off_of_scene_analysis", analysis=analysis, character=character)

log.debug(
"director.guide_actor_off_of_scene_analysis",
analysis=analysis,
character=character,
)
response = await Prompt.request(
"director.guide-conversation",
self.client,
@@ -285,19 +304,17 @@ class GuideSceneMixin:
},
)
return strip_partial_sentences(response).strip()

@set_processing
async def guide_narrator_off_of_scene_analysis(
self,
analysis: str,
response_length: int = 256
self, analysis: str, response_length: int = 256
):
"""
Guides the narrator based on the scene analysis.
"""

log.debug("director.guide_narrator_off_of_scene_analysis", analysis=analysis)

response = await Prompt.request(
"director.guide-narration",
self.client,
@@ -309,4 +326,4 @@ class GuideSceneMixin:
"max_tokens": self.client.max_token_length,
},
)
return strip_partial_sentences(response).strip()
return strip_partial_sentences(response).strip()

@@ -1,4 +1,3 @@
from typing import TYPE_CHECKING
import structlog
from talemate.agents.base import (
set_processing,
@@ -9,24 +8,27 @@ from talemate.events import GameLoopActorIterEvent, SceneStateEvent

log = structlog.get_logger("talemate.agents.conversation.legacy_scene_instructions")

class LegacySceneInstructionsMixin(
GameInstructionsMixin,
):
"""
Legacy support for scoped api instructions in scenes.

This is being replaced by node based instructions, but kept for backwards compatibility.

THIS WILL BE DEPRECATED IN THE FUTURE.
"""

# signal connect

def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.LSI_on_player_dialog)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(
self.LSI_on_player_dialog
)
talemate.emit.async_signals.get("scene_init").connect(self.LSI_on_scene_init)

async def LSI_on_scene_init(self, event: SceneStateEvent):
"""
LEGACY: If game state instructions specify to be run at the start of the game loop
@@ -57,8 +59,10 @@ class LegacySceneInstructionsMixin(

if not event.actor.character.is_player:
return

log.warning(f"LSI_on_player_dialog is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.")

log.warning(
"LSI_on_player_dialog is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future."
)

if event.game_loop.had_passive_narration:
log.debug(
@@ -77,17 +81,17 @@ class LegacySceneInstructionsMixin(
not self.scene.npc_character_names
or self.scene.game_state.ops.always_direct
)

log.warning(f"LSI_direct is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.", always_direct=always_direct)

log.warning(
"LSI_direct is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.",
always_direct=always_direct,
)

next_direct = self.next_direct_scene

TURNS = 5

if (
next_direct % TURNS != 0
or next_direct == 0
):
if next_direct % TURNS != 0 or next_direct == 0:
if not always_direct:
log.info("direct", skip=True, next_direct=next_direct)
self.next_direct_scene += 1
@@ -112,8 +116,10 @@ class LegacySceneInstructionsMixin(
async def LSI_direct_scene(self):
"""
LEGACY: Direct the scene based scoped api scene instructions.
This is being replaced by node based instructions, but kept for
This is being replaced by node based instructions, but kept for
backwards compatibility.
"""
log.warning(f"Direct python scene instructions are being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.")
await self.run_scene_instructions(self.scene)
log.warning(
"Direct python scene instructions are being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future."
)
await self.run_scene_instructions(self.scene)

@@ -1,29 +1,34 @@
import structlog
from typing import ClassVar, TYPE_CHECKING
from talemate.context import active_scene
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
from typing import ClassVar
from talemate.game.engine.nodes.core import (
    GraphState,
    PropertyField,
    TYPE_CHOICES,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode

if TYPE_CHECKING:
    from talemate.tale_mate import Scene

TYPE_CHOICES.extend([
    "director/direction",
])
TYPE_CHOICES.extend(
    [
        "director/direction",
    ]
)

log = structlog.get_logger("talemate.game.engine.nodes.agents.director")


@register("agents/director/Settings")
class DirectorSettings(AgentSettingsNode):
    """
    Base node to render director agent settings.
    """
    _agent_name:ClassVar[str] = "director"

    _agent_name: ClassVar[str] = "director"

    def __init__(self, title="Director Settings", **kwargs):
        super().__init__(title=title, **kwargs)


@register("agents/director/PersistCharacter")
class PersistCharacter(AgentNode):
@@ -31,45 +36,44 @@ class PersistCharacter(AgentNode):
    Persists a character that currently only exists as part of the given context
    as a real character that can actively participate in the scene.
    """
    _agent_name:ClassVar[str] = "director"

    _agent_name: ClassVar[str] = "director"

    class Fields:
        determine_name = PropertyField(
            name="determine_name",
            type="bool",
            description="Whether to determine the name of the character",
            default=True
            default=True,
        )

    def __init__(self, title="Persist Character", **kwargs):
        super().__init__(title=title, **kwargs)

    def setup(self):
        self.add_input("state")
        self.add_input("character_name", socket_type="str")
        self.add_input("context", socket_type="str", optional=True)
        self.add_input("attributes", socket_type="dict,str", optional=True)

        self.set_property("determine_name", True)

        self.add_output("state")
        self.add_output("character", socket_type="character")

    async def run(self, state: GraphState):

        character_name = self.get_input_value("character_name")
        context = self.normalized_input_value("context")
        attributes = self.normalized_input_value("attributes")
        determine_name = self.normalized_input_value("determine_name")

        character = await self.agent.persist_character(
            name=character_name,
            content=context,
            attributes="\n".join([f"{k}: {v}" for k, v in attributes.items()]) if attributes else None,
            determine_name=determine_name
            attributes="\n".join([f"{k}: {v}" for k, v in attributes.items()])
            if attributes
            else None,
            determine_name=determine_name,
        )

        self.set_output_values({
            "state": state,
            "character": character
        })

        self.set_output_values({"state": state, "character": character})
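The attributes expression that ruff reflowed above just renders a dict as "key: value" lines; a small standalone sketch with hypothetical input:

attributes = {"age": "30", "mood": "wary"}  # hypothetical input dict
formatted = "\n".join(f"{k}: {v}" for k, v in attributes.items()) if attributes else None
# formatted -> "age: 30\nmood: wary"; None when attributes is empty or missing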
@@ -1,7 +1,6 @@
import pydantic
import asyncio
import structlog
import traceback
from typing import TYPE_CHECKING

from talemate.instance import get_agent
@@ -18,27 +17,31 @@ __all__ = [

log = structlog.get_logger("talemate.server.director")


class InstructionPayload(pydantic.BaseModel):
    instructions:str = ""
    instructions: str = ""


class SelectChoicePayload(pydantic.BaseModel):
    choice: str
    character:str = ""
    character: str = ""


class CharacterPayload(InstructionPayload):
    character:str = ""
    character: str = ""


class PersistCharacterPayload(pydantic.BaseModel):
    name: str
    templates: list[str] | None = None
    narrate_entry: bool = True
    narrate_entry_direction: str = ""

    active: bool = True
    determine_name: bool = True
    augment_attributes: str = ""
    generate_attributes: bool = True

    content: str = ""
    description: str = ""

@@ -47,13 +50,13 @@ class DirectorWebsocketHandler(Plugin):
    """
    Handles director actions
    """

    router = "director"

    @property
    def director(self):
        return get_agent("director")

    @set_loading("Generating dynamic actions", cancellable=True, as_async=True)
    async def handle_request_dynamic_choices(self, data: dict):
        """
@@ -61,21 +64,21 @@ class DirectorWebsocketHandler(Plugin):
        """
        payload = CharacterPayload(**data)
        await self.director.generate_choices(**payload.model_dump())

    async def handle_select_choice(self, data: dict):
        payload = SelectChoicePayload(**data)

        log.debug("selecting choice", payload=payload)

        if payload.character:
            character = self.scene.get_character(payload.character)
        else:
            character = self.scene.get_player_character()

        if not character:
            log.error("handle_select_choice: could not find character", payload=payload)
            return

        # hijack the interaction state
        try:
            interaction_state = interaction.get()
@@ -83,20 +86,23 @@ class DirectorWebsocketHandler(Plugin):
            # no interaction state
            log.error("handle_select_choice: no interaction state", payload=payload)
            return

        interaction_state.from_choice = payload.choice
        interaction_state.act_as = character.name if not character.is_player else None
        interaction_state.input = f"@{payload.choice}"

    async def handle_persist_character(self, data: dict):
        payload = PersistCharacterPayload(**data)
        scene: "Scene" = self.scene

        if not payload.content:
            payload.content = scene.snapshot(lines=15)

        # add as asyncio task
        task = asyncio.create_task(self.director.persist_character(**payload.model_dump()))
        task = asyncio.create_task(
            self.director.persist_character(**payload.model_dump())
        )

        async def handle_task_done(task):
            if task.exception():
                log.error("Error persisting character", error=task.exception())
@@ -112,4 +118,3 @@ class DirectorWebsocketHandler(Plugin):
            await self.signal_operation_done()

        task.add_done_callback(lambda task: asyncio.create_task(handle_task_done(task)))
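The task wiring above attaches an async completion handler via a done callback; a self-contained hedged sketch of the same pattern outside Talemate:

import asyncio

async def work():
    return 42

async def on_done(task: asyncio.Task):
    # Exceptions surface through task.exception() once the task is done.
    if task.exception():
        print("failed:", task.exception())

async def main():
    task = asyncio.create_task(work())
    # add_done_callback takes a sync callable, so wrap the async handler in a new task.
    task.add_done_callback(lambda t: asyncio.create_task(on_done(t)))
    await task
    await asyncio.sleep(0.1)  # give the scheduled handler a chance to run

asyncio.run(main())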
@@ -40,7 +40,7 @@ class EditorAgent(
    agent_type = "editor"
    verbose_name = "Editor"
    websocket_handler = EditorWebsocketHandler

    @classmethod
    def init_actions(cls) -> dict[str, AgentAction]:
        actions = {
@@ -56,9 +56,9 @@ class EditorAgent(
                description="The formatting to use for exposition.",
                value="novel",
                choices=[
                    {"label": "Chat RP: \"Speech\" *narration*", "value": "chat"},
                    {"label": "Novel: \"Speech\" narration", "value": "novel"},
                ]
                    {"label": 'Chat RP: "Speech" *narration*', "value": "chat"},
                    {"label": 'Novel: "Speech" narration', "value": "novel"},
                ],
            ),
            "narrator": AgentActionConfig(
                type="bool",
@@ -81,7 +81,7 @@ class EditorAgent(
                description="Attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.",
            ),
        }

        MemoryRAGMixin.add_actions(actions)
        RevisionMixin.add_actions(actions)
        return actions
@@ -90,7 +90,7 @@ class EditorAgent(
        self.client = client
        self.is_enabled = True
        self.actions = EditorAgent.init_actions()

    @property
    def enabled(self):
        return self.is_enabled
@@ -102,11 +102,11 @@ class EditorAgent(
    @property
    def experimental(self):
        return True

    @property
    def fix_exposition_enabled(self):
        return self.actions["fix_exposition"].enabled

    @property
    def fix_exposition_formatting(self):
        return self.actions["fix_exposition"].config["formatting"].value
@@ -114,11 +114,10 @@ class EditorAgent(
    @property
    def fix_exposition_narrator(self):
        return self.actions["fix_exposition"].config["narrator"].value

    @property
    def fix_exposition_user_input(self):
        return self.actions["fix_exposition"].config["user_input"].value

    def connect(self, scene):
        super().connect(scene)
@@ -134,7 +133,7 @@ class EditorAgent(
            formatting = "md"
        else:
            formatting = None

        if self.fix_exposition_formatting == "chat":
            text = text.replace("**", "*")
            text = text.replace("[", "*").replace("]", "*")
@@ -143,15 +142,14 @@ class EditorAgent(
            text = text.replace("*", "")
            text = text.replace("[", "").replace("]", "")
            text = text.replace("(", "").replace(")", "")

        cleaned = util.ensure_dialog_format(
            text,
            talking_character=character.name if character else None,
            formatting=formatting
        )

        return cleaned

        cleaned = util.ensure_dialog_format(
            text,
            talking_character=character.name if character else None,
            formatting=formatting,
        )

        return cleaned

    async def on_conversation_generated(self, emission: ConversationAgentEmission):
        """
@@ -181,11 +179,13 @@ class EditorAgent(
        emission.response = edit

    @set_processing
    async def cleanup_character_message(self, content: str, character: Character, force: bool = False):
    async def cleanup_character_message(
        self, content: str, character: Character, force: bool = False
    ):
        """
        Edits a text to make sure all narrative exposition and emotes is encased in *
        """

        # if not content was generated, return it as is
        if not content:
            return content
@@ -210,14 +210,14 @@ class EditorAgent(

        content = util.clean_dialogue(content, main_name=character.name)
        content = util.strip_partial_sentences(content)

        # if there are uneven quotation marks, fix them by adding a closing quote
        if '"' in content and content.count('"') % 2 != 0:
            content += '"'

        if not self.fix_exposition_enabled and not exposition_fixed:
            return content

        content = self.fix_exposition_in_text(content, character)

        return content
@@ -225,7 +225,7 @@ class EditorAgent(
    @set_processing
    async def clean_up_narration(self, content: str, force: bool = False):
        content = util.strip_partial_sentences(content)
        if (self.fix_exposition_enabled and self.fix_exposition_narrator or force):
        if self.fix_exposition_enabled and self.fix_exposition_narrator or force:
            content = self.fix_exposition_in_text(content, None)
            if self.fix_exposition_formatting == "chat":
                if '"' not in content and "*" not in content:
@@ -234,25 +234,27 @@ class EditorAgent(
        return content

    @set_processing
    async def cleanup_user_input(self, text: str, as_narration: bool = False, force: bool = False):

    async def cleanup_user_input(
        self, text: str, as_narration: bool = False, force: bool = False
    ):
        # special prefix characters - when found, never edit
        PREFIX_CHARACTERS = ("!", "@", "/")
        if text.startswith(PREFIX_CHARACTERS):
            return text

        if (not self.fix_exposition_user_input or not self.fix_exposition_enabled) and not force:

        if (
            not self.fix_exposition_user_input or not self.fix_exposition_enabled
        ) and not force:
            return text

        if not as_narration:
            if self.fix_exposition_formatting == "chat":
                if '"' not in text and "*" not in text:
                    text = f'"{text}"'
        else:
            return await self.clean_up_narration(text)

        return self.fix_exposition_in_text(text)

    @set_processing
    async def add_detail(self, content: str, character: Character):
@@ -279,4 +281,4 @@ class EditorAgent(
        response = util.clean_dialogue(response, main_name=character.name)
        response = util.strip_partial_sentences(response)

        return response
        return response
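The quote-balancing step in cleanup_character_message is simple enough to isolate; a hedged sketch of just that rule:

def close_dangling_quote(content: str) -> str:
    # An odd number of double quotes means one was left open; append a closing one.
    if '"' in content and content.count('"') % 2 != 0:
        content += '"'
    return content

# close_dangling_quote('She said "hello')  ->  'She said "hello"'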
@@ -1,70 +1,37 @@
import structlog
from typing import ClassVar, TYPE_CHECKING
from talemate.context import active_scene
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
from talemate.game.engine.nodes.core import (
    GraphState,
    PropertyField,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode

if TYPE_CHECKING:
    from talemate.tale_mate import Scene
    from talemate.agents.editor import EditorAgent

log = structlog.get_logger("talemate.game.engine.nodes.agents.editor")


@register("agents/editor/Settings")
class EditorSettings(AgentSettingsNode):
    """
    Base node to render editor agent settings.
    """
    _agent_name:ClassVar[str] = "editor"

    _agent_name: ClassVar[str] = "editor"

    def __init__(self, title="Editor Settings", **kwargs):
        super().__init__(title=title, **kwargs)


@register("agents/editor/CleanUpUserInput")
class CleanUpUserInput(AgentNode):
    """
    Cleans up user input.
    """
    _agent_name:ClassVar[str] = "editor"

    class Fields:
        force = PropertyField(
            name="force",
            description="Force clean up",
            type="bool",
            default=False,
        )

    def __init__(self, title="Clean Up User Input", **kwargs):
        super().__init__(title=title, **kwargs)

    def setup(self):
        self.add_input("user_input", socket_type="str")
        self.add_input("as_narration", socket_type="bool")

        self.set_property("force", False)

        self.add_output("cleaned_user_input", socket_type="str")

    async def run(self, state: GraphState):
        editor:"EditorAgent" = self.agent
        user_input = self.get_input_value("user_input")
        force = self.get_property("force")
        as_narration = self.get_input_value("as_narration")
        cleaned_user_input = await editor.cleanup_user_input(user_input, as_narration=as_narration, force=force)
        self.set_output_values({
            "cleaned_user_input": cleaned_user_input,
        })

@register("agents/editor/CleanUpNarration")
class CleanUpNarration(AgentNode):
    """
    Cleans up narration.
    """
    _agent_name:ClassVar[str] = "editor"

    _agent_name: ClassVar[str] = "editor"

    class Fields:
        force = PropertyField(
@@ -73,31 +40,41 @@ class CleanUpNarration(AgentNode):
            type="bool",
            default=False,
        )

    def __init__(self, title="Clean Up Narration", **kwargs):

    def __init__(self, title="Clean Up User Input", **kwargs):
        super().__init__(title=title, **kwargs)

    def setup(self):
        self.add_input("narration", socket_type="str")
        self.add_input("user_input", socket_type="str")
        self.add_input("as_narration", socket_type="bool")

        self.set_property("force", False)
        self.add_output("cleaned_narration", socket_type="str")

        self.add_output("cleaned_user_input", socket_type="str")

    async def run(self, state: GraphState):
        editor:"EditorAgent" = self.agent
        narration = self.get_input_value("narration")
        editor: "EditorAgent" = self.agent
        user_input = self.get_input_value("user_input")
        force = self.get_property("force")
        cleaned_narration = await editor.cleanup_narration(narration, force=force)
        self.set_output_values({
            "cleaned_narration": cleaned_narration,
        })

@register("agents/editor/CleanUoCharacterMessage")
class CleanUpCharacterMessage(AgentNode):
        as_narration = self.get_input_value("as_narration")
        cleaned_user_input = await editor.cleanup_user_input(
            user_input, as_narration=as_narration, force=force
        )
        self.set_output_values(
            {
                "cleaned_user_input": cleaned_user_input,
            }
        )


@register("agents/editor/CleanUpNarration")
class CleanUpNarration(AgentNode):
    """
    Cleans up character message.
    Cleans up narration.
    """
    _agent_name:ClassVar[str] = "editor"

    _agent_name: ClassVar[str] = "editor"

    class Fields:
        force = PropertyField(
            name="force",
@@ -105,22 +82,62 @@ class CleanUpCharacterMessage(AgentNode):
            type="bool",
            default=False,
        )

    def __init__(self, title="Clean Up Narration", **kwargs):
        super().__init__(title=title, **kwargs)

    def setup(self):
        self.add_input("narration", socket_type="str")
        self.set_property("force", False)
        self.add_output("cleaned_narration", socket_type="str")

    async def run(self, state: GraphState):
        editor: "EditorAgent" = self.agent
        narration = self.get_input_value("narration")
        force = self.get_property("force")
        cleaned_narration = await editor.cleanup_narration(narration, force=force)
        self.set_output_values(
            {
                "cleaned_narration": cleaned_narration,
            }
        )


@register("agents/editor/CleanUoCharacterMessage")
class CleanUpCharacterMessage(AgentNode):
    """
    Cleans up character message.
    """

    _agent_name: ClassVar[str] = "editor"

    class Fields:
        force = PropertyField(
            name="force",
            description="Force clean up",
            type="bool",
            default=False,
        )

    def __init__(self, title="Clean Up Character Message", **kwargs):
        super().__init__(title=title, **kwargs)

    def setup(self):
        self.add_input("text", socket_type="str")
        self.add_input("character", socket_type="character")
        self.set_property("force", False)
        self.add_output("cleaned_character_message", socket_type="str")

    async def run(self, state: GraphState):
        editor:"EditorAgent" = self.agent
        editor: "EditorAgent" = self.agent
        text = self.get_input_value("text")
        force = self.get_property("force")
        character = self.get_input_value("character")
        cleaned_character_message = await editor.cleanup_character_message(text, character, force=force)
        self.set_output_values({
            "cleaned_character_message": cleaned_character_message,
        })
        cleaned_character_message = await editor.cleanup_character_message(
            text, character, force=force
        )
        self.set_output_values(
            {
                "cleaned_character_message": cleaned_character_message,
            }
        )
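These editor nodes all share the same setup/run shape; a condensed hedged sketch of the pattern, with the node id invented and the Talemate base classes (AgentNode, GraphState, register) assumed from context:

@register("agents/editor/ExampleNode")  # hypothetical node id
class ExampleNode(AgentNode):
    _agent_name: ClassVar[str] = "editor"

    def setup(self):
        # declare sockets once; the graph engine wires values at runtime
        self.add_input("text", socket_type="str")
        self.add_output("result", socket_type="str")

    async def run(self, state: GraphState):
        text = self.get_input_value("text")
        self.set_output_values({"result": text.strip()})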
File diff suppressed because it is too large
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING

from talemate.instance import get_agent
from talemate.server.websocket_plugin import Plugin
from talemate.status import set_loading
from talemate.scene_message import CharacterMessage
from talemate.agents.editor.revision import RevisionContext, RevisionInformation

@@ -17,47 +16,49 @@ __all__ = [

log = structlog.get_logger("talemate.server.editor")


class RevisionPayload(pydantic.BaseModel):
    message_id: int


class EditorWebsocketHandler(Plugin):
    """
    Handles editor actions
    """

    router = "editor"

    @property
    def editor(self):
        return get_agent("editor")

    async def handle_request_revision(self, data: dict):
        """
        Generate clickable actions for the user
        """

        editor = self.editor
        scene:"Scene" = self.scene

        scene: "Scene" = self.scene

        if not editor.revision_enabled:
            raise Exception("Revision is not enabled")

        payload = RevisionPayload(**data)
        message = scene.get_message(payload.message_id)

        character = None

        if isinstance(message, CharacterMessage):
            character = scene.get_character(message.character_name)

        if not message:
            raise Exception("Message not found")

        with RevisionContext(message.id):
            info = RevisionInformation(
                text=message.message,
                character=character,
            )
            revised = await editor.revision_revise(info)

            scene.edit_message(message.id, revised)

        scene.edit_message(message.id, revised)
@@ -75,10 +75,9 @@ class MemoryAgent(Agent):

    @classmethod
    def init_actions(cls, presets: list[dict] | None = None) -> dict[str, AgentAction]:

        if presets is None:
            presets = []

        actions = {
            "_config": AgentAction(
                enabled=True,
@@ -101,23 +100,25 @@ class MemoryAgent(Agent):
                        choices=[
                            {"value": "cpu", "label": "CPU"},
                            {"value": "cuda", "label": "CUDA"},
                        ]
                        ],
                    ),
                },
            ),
        }
        return actions

    def __init__(self, scene, **kwargs):
        self.db = None
        self.scene = scene
        self.memory_tracker = {}
        self.config = load_config()
        self._ready_to_add = False

        handlers["config_saved"].connect(self.on_config_saved)
        async_signals.get("client.embeddings_available").connect(self.on_client_embeddings_available)

        async_signals.get("client.embeddings_available").connect(
            self.on_client_embeddings_available
        )

        self.actions = MemoryAgent.init_actions(presets=self.get_presets)

    @property
@@ -132,30 +133,32 @@ class MemoryAgent(Agent):
    @property
    def db_name(self):
        raise NotImplementedError()

    @property
    def get_presets(self):

        def _label(embedding:dict):
            prefix = embedding['client'] if embedding['client'] else embedding['embeddings']
            if embedding['model']:
        def _label(embedding: dict):
            prefix = (
                embedding["client"] if embedding["client"] else embedding["embeddings"]
            )
            if embedding["model"]:
                return f"{prefix}: {embedding['model']}"
            else:
                return f"{prefix}"

        return [
            {"value": k, "label": _label(v)} for k,v in self.config.get("presets", {}).get("embeddings", {}).items()
            {"value": k, "label": _label(v)}
            for k, v in self.config.get("presets", {}).get("embeddings", {}).items()
        ]

    @property
    def embeddings_config(self):
        _embeddings = self.actions["_config"].config["embeddings"].value
        return self.config.get("presets", {}).get("embeddings", {}).get(_embeddings, {})

    @property
    def embeddings(self):
        return self.embeddings_config.get("embeddings", "sentence-transformer")

    @property
    def using_openai_embeddings(self):
        return self.embeddings == "openai"
@@ -163,39 +166,34 @@ class MemoryAgent(Agent):
    @property
    def using_instructor_embeddings(self):
        return self.embeddings == "instructor"

    @property
    def using_sentence_transformer_embeddings(self):
        return self.embeddings == "default" or self.embeddings == "sentence-transformer"

    @property
    def using_client_api_embeddings(self):
        return self.embeddings == "client-api"

    @property
    def using_local_embeddings(self):
        return self.embeddings in [
            "instructor",
            "sentence-transformer",
            "default"
        ]

        return self.embeddings in ["instructor", "sentence-transformer", "default"]

    @property
    def embeddings_client(self):
        return self.embeddings_config.get("client")

    @property
    def max_distance(self) -> float:
        distance = float(self.embeddings_config.get("distance", 1.0))
        distance_mod = float(self.embeddings_config.get("distance_mod", 1.0))

        return distance * distance_mod

    @property
    def model(self):
        return self.embeddings_config.get("model")

    @property
    def distance_function(self):
        return self.embeddings_config.get("distance_function", "l2")
@@ -213,25 +211,28 @@ class MemoryAgent(Agent):
        """
        Returns a unique fingerprint for the current configuration
        """

        model_name = self.model.replace('/','-') if self.model else "none"

        return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()

        model_name = self.model.replace("/", "-") if self.model else "none"

        return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower()

    async def apply_config(self, *args, **kwargs):

        _fingerprint = self.fingerprint

        await super().apply_config(*args, **kwargs)

        fingerprint_changed = _fingerprint != self.fingerprint

        # have embeddings or device changed?
        if fingerprint_changed:
            log.warning("memory agent", status="embedding function changed", old=_fingerprint, new=self.fingerprint)
            log.warning(
                "memory agent",
                status="embedding function changed",
                old=_fingerprint,
                new=self.fingerprint,
            )
            await self.handle_embeddings_change()

    @set_processing
    async def handle_embeddings_change(self):
        scene = active_scene.get()
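The fingerprint is just a normalized join of the embedding settings, used to detect configuration changes; a hedged sketch of the idea as a free function:

def fingerprint(embeddings: str, model: str | None, distance_fn: str, device: str, trust_remote_code: bool) -> str:
    # Slashes in model names would leak into downstream identifiers, so normalize them.
    model_name = model.replace("/", "-") if model else "none"
    return f"{embeddings}-{model_name}-{distance_fn}-{device}-{trust_remote_code}".lower()

# fingerprint("sentence-transformer", "BAAI/bge-small", "l2", "cpu", False)
# -> "sentence-transformer-baai-bge-small-l2-cpu-false"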
@@ -239,10 +240,10 @@ class MemoryAgent(Agent):
        # if sentence-transformer and no model-name, set embeddings to default
        if self.using_sentence_transformer_embeddings and not self.model:
            self.actions["_config"].config["embeddings"].value = "default"

        if not scene or not scene.get_helper("memory"):
            return

        self.close_db(scene)
        emit("status", "Re-importing context database", status="busy")
        await scene.commit_to_memory()
@@ -257,38 +258,49 @@ class MemoryAgent(Agent):
    def on_config_saved(self, event):
        loop = asyncio.get_running_loop()
        openai_key = self.openai_api_key

        fingerprint = self.fingerprint

        old_presets = self.actions["_config"].config["embeddings"].choices.copy()

        self.config = load_config()
        new_presets = self.sync_presets()
        if fingerprint != self.fingerprint:
            log.warning("memory agent", status="embedding function changed", old=fingerprint, new=self.fingerprint)
            log.warning(
                "memory agent",
                status="embedding function changed",
                old=fingerprint,
                new=self.fingerprint,
            )
            loop.run_until_complete(self.handle_embeddings_change())

        emit_status = False

        if openai_key != self.openai_api_key:
            emit_status = True

        if old_presets != new_presets:
            emit_status = True

        if emit_status:
            loop.run_until_complete(self.emit_status())

    async def on_client_embeddings_available(self, event: "ClientEmbeddingsStatus"):
        current_embeddings = self.actions["_config"].config["embeddings"].value

        if current_embeddings == event.client.embeddings_identifier:
            return

        if not self.using_client_api_embeddings or not self.ready:
            log.warning("memory agent - client embeddings available", status="changing embeddings", old=current_embeddings, new=event.client.embeddings_identifier)
            self.actions["_config"].config["embeddings"].value = event.client.embeddings_identifier
            log.warning(
                "memory agent - client embeddings available",
                status="changing embeddings",
                old=current_embeddings,
                new=event.client.embeddings_identifier,
            )
            self.actions["_config"].config[
                "embeddings"
            ].value = event.client.embeddings_identifier
            await self.emit_status()
            await self.handle_embeddings_change()
            await self.save_config()
@@ -301,13 +313,16 @@ class MemoryAgent(Agent):
        except EmbeddingsModelLoadError:
            raise
        except Exception as e:
            log.error("memory agent", error="failed to set db", details=traceback.format_exc())

            if "torchvision::nms does not exist" in str(e):
                raise SetDBError("The embeddings you are trying to use require the `torchvision` package to be installed")

            raise SetDBError(str(e))
            log.error(
                "memory agent", error="failed to set db", details=traceback.format_exc()
            )

            if "torchvision::nms does not exist" in str(e):
                raise SetDBError(
                    "The embeddings you are trying to use require the `torchvision` package to be installed"
                )

            raise SetDBError(str(e))

    def close_db(self):
        raise NotImplementedError()
@@ -426,9 +441,9 @@ class MemoryAgent(Agent):
        with MemoryRequest(query=text, query_params=query) as active_memory_request:
            active_memory_request.max_distance = self.max_distance
            return await asyncio.to_thread(self._get, text, character, **query)
            #return await loop.run_in_executor(
            # return await loop.run_in_executor(
            #     None, functools.partial(self._get, text, character, **query)
            #)
            # )

    def _get(self, text, character=None, **query):
        raise NotImplementedError()
@@ -528,9 +543,7 @@ class MemoryAgent(Agent):
                continue

            # Fetch potential memories for this query.
            raw_results = await self.get(
                formatter(query), limit=limit, **where
            )
            raw_results = await self.get(formatter(query), limit=limit, **where)

            # Apply filter and respect the `iterate` limit for this query.
            accepted: list[str] = []
@@ -591,9 +604,9 @@ class MemoryAgent(Agent):

        Returns a dictionary with 'cosine_similarity' and 'euclidean_distance'.
        """

        embed_fn = self.embedding_function

        # Embed the two strings
        vec1 = np.array(embed_fn([string1])[0])
        vec2 = np.array(embed_fn([string2])[0])
@@ -604,17 +617,14 @@ class MemoryAgent(Agent):
        # Compute Euclidean distance
        euclidean_dist = np.linalg.norm(vec1 - vec2)

        return {
            "cosine_similarity": cosine_sim,
            "euclidean_distance": euclidean_dist
        }
        return {"cosine_similarity": cosine_sim, "euclidean_distance": euclidean_dist}

    async def compare_string_lists(
        self,
        list_a: list[str],
        list_b: list[str],
        similarity_threshold: float = None,
        distance_threshold: float = None
        distance_threshold: float = None,
    ) -> dict:
        """
        Compare two lists of strings using the current embedding function without touching the database.
@@ -625,15 +635,21 @@ class MemoryAgent(Agent):
        - 'similarity_matches': list of (i, j, score) (filtered if threshold set, otherwise all)
        - 'distance_matches': list of (i, j, distance) (filtered if threshold set, otherwise all)
        """
        if not self.db or not hasattr(self.db, "_embedding_function") or self.db._embedding_function is None:
            raise RuntimeError("Embedding function is not initialized. Make sure the database is set.")
        if (
            not self.db
            or not hasattr(self.db, "_embedding_function")
            or self.db._embedding_function is None
        ):
            raise RuntimeError(
                "Embedding function is not initialized. Make sure the database is set."
            )

        if not list_a or not list_b:
            return {
                "cosine_similarity_matrix": np.array([]),
                "euclidean_distance_matrix": np.array([]),
                "similarity_matches": [],
                "distance_matches": []
                "distance_matches": [],
            }

        embed_fn = self.db._embedding_function
@@ -653,21 +669,29 @@ class MemoryAgent(Agent):
        cosine_similarity_matrix = np.dot(vecs_a_norm, vecs_b_norm.T)

        # Euclidean distance matrix
        a_squared = np.sum(vecs_a ** 2, axis=1).reshape(-1, 1)
        b_squared = np.sum(vecs_b ** 2, axis=1).reshape(1, -1)
        euclidean_distance_matrix = np.sqrt(a_squared + b_squared - 2 * np.dot(vecs_a, vecs_b.T))
        a_squared = np.sum(vecs_a**2, axis=1).reshape(-1, 1)
        b_squared = np.sum(vecs_b**2, axis=1).reshape(1, -1)
        euclidean_distance_matrix = np.sqrt(
            a_squared + b_squared - 2 * np.dot(vecs_a, vecs_b.T)
        )

        # Prepare matches
        similarity_matches = []
        distance_matches = []

        # Populate similarity matches
        sim_indices = np.argwhere(cosine_similarity_matrix >= (similarity_threshold if similarity_threshold is not None else -np.inf))
        sim_indices = np.argwhere(
            cosine_similarity_matrix
            >= (similarity_threshold if similarity_threshold is not None else -np.inf)
        )
        for i, j in sim_indices:
            similarity_matches.append((i, j, cosine_similarity_matrix[i, j]))

        # Populate distance matches
        dist_indices = np.argwhere(euclidean_distance_matrix <= (distance_threshold if distance_threshold is not None else np.inf))
        dist_indices = np.argwhere(
            euclidean_distance_matrix
            <= (distance_threshold if distance_threshold is not None else np.inf)
        )
        for i, j in dist_indices:
            distance_matches.append((i, j, euclidean_distance_matrix[i, j]))

@@ -675,9 +699,10 @@ class MemoryAgent(Agent):
            "cosine_similarity_matrix": cosine_similarity_matrix,
            "euclidean_distance_matrix": euclidean_distance_matrix,
            "similarity_matches": similarity_matches,
            "distance_matches": distance_matches
            "distance_matches": distance_matches,
        }


@register(condition=lambda: chromadb is not None)
class ChromaDBMemoryAgent(MemoryAgent):
    requires_llm_client = False
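The pairwise math reflowed in compare_string_lists above is standard NumPy; a self-contained hedged sketch of the two matrices, with random vectors standing in for real embeddings:

import numpy as np

vecs_a = np.random.rand(3, 8)  # hypothetical embeddings for list_a
vecs_b = np.random.rand(4, 8)  # hypothetical embeddings for list_b

# Cosine similarity: normalize rows, then one matrix product covers all pairs.
a_norm = vecs_a / np.linalg.norm(vecs_a, axis=1, keepdims=True)
b_norm = vecs_b / np.linalg.norm(vecs_b, axis=1, keepdims=True)
cosine = a_norm @ b_norm.T  # shape (3, 4)

# Euclidean distance via the ||a||^2 + ||b||^2 - 2*a.b expansion.
a_sq = np.sum(vecs_a**2, axis=1).reshape(-1, 1)
b_sq = np.sum(vecs_b**2, axis=1).reshape(1, -1)
euclidean = np.sqrt(a_sq + b_sq - 2 * vecs_a @ vecs_b.T)  # shape (3, 4)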
@@ -690,32 +715,34 @@ class ChromaDBMemoryAgent(MemoryAgent):
        if getattr(self, "db_client", None):
            return True
        return False

    @property
    def client_api_ready(self) -> bool:
        if self.using_client_api_embeddings:
            embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
            embeddings_client: ClientBase | None = instance.get_client(
                self.embeddings_client
            )
            if not embeddings_client:
                return False

            if not embeddings_client.supports_embeddings:
                return False

            if not embeddings_client.embeddings_status:
                return False

            if embeddings_client.current_status not in ["idle", "busy"]:
                return False

            return True

        return False

    @property
    def status(self):
        if self.using_client_api_embeddings and not self.client_api_ready:
            return "error"

        if self.ready:
            return "active" if not getattr(self, "processing", False) else "busy"

@@ -726,7 +753,6 @@ class ChromaDBMemoryAgent(MemoryAgent):

    @property
    def agent_details(self):

        details = {
            "backend": AgentDetail(
                icon="mdi-server-outline",
@@ -738,23 +764,22 @@ class ChromaDBMemoryAgent(MemoryAgent):
                value=self.embeddings,
                description="The embeddings type.",
            ).model_dump(),

        }

        if self.model:
            details["model"] = AgentDetail(
                icon="mdi-brain",
                value=self.model,
                description="The embeddings model.",
            ).model_dump()

        if self.embeddings_client:
            details["client"] = AgentDetail(
                icon="mdi-network-outline",
                value=self.embeddings_client,
                description="The client to use for embeddings.",
            ).model_dump()

        if self.using_local_embeddings:
            details["device"] = AgentDetail(
                icon="mdi-memory",
@@ -770,9 +795,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
                "description": "You must provide an OpenAI API key to use OpenAI embeddings",
                "color": "error",
            }

        if self.using_client_api_embeddings:
            embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)
            embeddings_client: ClientBase | None = instance.get_client(
                self.embeddings_client
            )

            if not embeddings_client:
                details["error"] = {
@@ -784,7 +811,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
                return details

            client_name = embeddings_client.name

            if not embeddings_client.supports_embeddings:
                error_message = f"{client_name} does not support embeddings"
            elif embeddings_client.current_status not in ["idle", "busy"]:
@@ -793,7 +820,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
                error_message = f"{client_name} has no embeddings model loaded"
            else:
                error_message = None

            if error_message:
                details["error"] = {
                    "icon": "mdi-alert",
@@ -814,8 +841,14 @@ class ChromaDBMemoryAgent(MemoryAgent):

    @property
    def embedding_function(self) -> Callable:
        if not self.db or not hasattr(self.db, "_embedding_function") or self.db._embedding_function is None:
            raise RuntimeError("Embedding function is not initialized. Make sure the database is set.")
        if (
            not self.db
            or not hasattr(self.db, "_embedding_function")
            or self.db._embedding_function is None
        ):
            raise RuntimeError(
                "Embedding function is not initialized. Make sure the database is set."
            )

        embed_fn = self.db._embedding_function
        return embed_fn
@@ -823,18 +856,18 @@ class ChromaDBMemoryAgent(MemoryAgent):
    def make_collection_name(self, scene) -> str:
        # generate plain text collection name
        collection_name = f"{self.fingerprint}"

        # chromadb collection names have the following rules:
        # Expected collection name that (1) contains 3-63 characters, (2) starts and ends with an alphanumeric character, (3) otherwise contains only alphanumeric characters, underscores or hyphens (-), (4) contains no two consecutive periods (..) and (5) is not a valid IPv4 address

        # Step 1: Hash the input string using MD5
        md5_hash = hashlib.md5(collection_name.encode()).hexdigest()

        # Step 2: Ensure the result is exactly 32 characters long
        hashed_collection_name = md5_hash[:32]

        return f"{scene.memory_id}-tm-{hashed_collection_name}"

    async def count(self):
        await asyncio.sleep(0)
        return self.db.count()
@@ -853,14 +886,17 @@ class ChromaDBMemoryAgent(MemoryAgent):
        self.collection_name = collection_name = self.make_collection_name(self.scene)

        log.info(
            "chromadb agent", status="setting up db", collection_name=collection_name, embeddings=self.embeddings
            "chromadb agent",
            status="setting up db",
            collection_name=collection_name,
            embeddings=self.embeddings,
        )

        distance_function = self.distance_function
        collection_metadata = {"hnsw:space": distance_function}
        device = self.actions["_config"].config["device"].value
        device = self.actions["_config"].config["device"].value
        model_name = self.model

        if self.using_openai_embeddings:
            if not openai_key:
                raise ValueError(
@@ -878,7 +914,9 @@ class ChromaDBMemoryAgent(MemoryAgent):
                model_name=model_name,
            )
            self.db = self.db_client.get_or_create_collection(
                collection_name, embedding_function=openai_ef, metadata=collection_metadata
                collection_name,
                embedding_function=openai_ef,
                metadata=collection_metadata,
            )
        elif self.using_client_api_embeddings:
            log.info(
@@ -886,20 +924,26 @@ class ChromaDBMemoryAgent(MemoryAgent):
                embeddings="Client API",
                client=self.embeddings_client,
            )

            embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client)

            embeddings_client: ClientBase | None = instance.get_client(
                self.embeddings_client
            )
            if not embeddings_client:
                raise ValueError(f"Client API embeddings client {self.embeddings_client} not found")

                raise ValueError(
                    f"Client API embeddings client {self.embeddings_client} not found"
                )

            if not embeddings_client.supports_embeddings:
                raise ValueError(f"Client API embeddings client {self.embeddings_client} does not support embeddings")

                raise ValueError(
                    f"Client API embeddings client {self.embeddings_client} does not support embeddings"
                )

            ef = embeddings_client.embeddings_function

            self.db = self.db_client.get_or_create_collection(
                collection_name, embedding_function=ef, metadata=collection_metadata
            )

        elif self.using_instructor_embeddings:
            log.info(
                "chromadb",
@@ -909,7 +953,9 @@ class ChromaDBMemoryAgent(MemoryAgent):
            )

            ef = embedding_functions.InstructorEmbeddingFunction(
                model_name=model_name, device=device, instruction="Represent the document for retrieval:"
                model_name=model_name,
                device=device,
                instruction="Represent the document for retrieval:",
            )

            log.info("chromadb", status="embedding function ready")
@@ -919,25 +965,26 @@ class ChromaDBMemoryAgent(MemoryAgent):
            log.info("chromadb", status="instructor db ready")
        else:
            log.info(
                "chromadb",
                embeddings="SentenceTransformer",
                "chromadb",
                embeddings="SentenceTransformer",
                model=model_name,
                device=device,
                distance_function=distance_function
                distance_function=distance_function,
            )

            try:
                ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                    model_name=model_name,
                    trust_remote_code=self.trust_remote_code,
                    device=device
                    device=device,
                )
            except ValueError as e:
                if "`trust_remote_code=True` to remove this error" in str(e):
                    raise EmbeddingsModelLoadError(model_name, "Model requires `Trust remote code` to be enabled")
                    raise EmbeddingsModelLoadError(
                        model_name, "Model requires `Trust remote code` to be enabled"
                    )
                raise EmbeddingsModelLoadError(model_name, str(e))

            self.db = self.db_client.get_or_create_collection(
                collection_name, embedding_function=ef, metadata=collection_metadata
            )
@@ -989,9 +1036,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
        try:
            self.db_client.delete_collection(collection_name)
        except chromadb.errors.NotFoundError as exc:
            log.error(
                "chromadb agent", error="collection not found", details=exc
            )
            log.error("chromadb agent", error="collection not found", details=exc)
        except ValueError as exc:
            log.error(
                "chromadb agent", error="failed to delete collection", details=exc
@@ -1049,16 +1094,18 @@ class ChromaDBMemoryAgent(MemoryAgent):

        if not objects:
            return

        # track seen documents by id
        seen_ids = set()

        for obj in objects:
            if obj["id"] in seen_ids:
                log.warning("chromadb agent", status="duplicate id discarded", id=obj["id"])
                log.warning(
                    "chromadb agent", status="duplicate id discarded", id=obj["id"]
                )
                continue
            seen_ids.add(obj["id"])

            documents.append(obj["text"])
            meta = obj.get("meta", {})
            source = meta.get("source", "talemate")
@@ -1071,7 +1118,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
            metadatas.append(meta)
            uid = obj.get("id", f"{character}-{self.memory_tracker[character]}")
            ids.append(uid)

        self.db.upsert(documents=documents, metadatas=metadatas, ids=ids)

    def _delete(self, meta: dict):
@@ -1081,11 +1128,11 @@ class ChromaDBMemoryAgent(MemoryAgent):
            return

        where = {"$and": [{k: v} for k, v in meta.items()]}

        # if there is only one item in $and reduce it to the key value pair
        if len(where["$and"]) == 1:
            where = where["$and"][0]

        self.db.delete(where=where)
        log.debug("chromadb agent delete", meta=meta, where=where)

@@ -1122,16 +1169,16 @@ class ChromaDBMemoryAgent(MemoryAgent):
            log.error("chromadb agent", error="failed to query", details=e)
            return []

        #import json
        #print(json.dumps(_results["ids"], indent=2))
        #print(json.dumps(_results["distances"], indent=2))
        # import json
        # print(json.dumps(_results["ids"], indent=2))
        # print(json.dumps(_results["distances"], indent=2))

        results = []

        max_distance = self.max_distance

        closest = None

        active_memory_request = memory_request.get()

        for i in range(len(_results["distances"][0])):
@@ -1139,7 +1186,7 @@ class ChromaDBMemoryAgent(MemoryAgent):

            doc = _results["documents"][0][i]
            meta = _results["metadatas"][0][i]

            active_memory_request.add_result(doc, distance, meta)

            if not meta:
@@ -1150,7 +1197,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
            # skip pin_only entries
            if meta.get("pin_only", False):
                continue

            if closest is None:
                closest = {"distance": distance, "doc": doc}
            elif distance < closest["distance"]:
@@ -1172,13 +1219,13 @@ class ChromaDBMemoryAgent(MemoryAgent):

            if len(results) > limit:
                break

        log.debug("chromadb agent get", closest=closest, max_distance=max_distance)
        self.last_query = {
            "query": text,
            "closest": closest,
        }

        return results

    def convert_ts_to_date_prefix(self, ts):
@@ -1197,10 +1244,9 @@ class ChromaDBMemoryAgent(MemoryAgent):
        return None

    def _get_document(self, id) -> dict:

        if not id:
            return {}

        result = self.db.get(ids=[id] if isinstance(id, str) else id)
        documents = {}
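make_collection_name sidesteps ChromaDB's collection-naming rules by hashing the fingerprint; a hedged sketch of the scheme as a free function:

import hashlib

def make_collection_name(memory_id: str, fingerprint: str) -> str:
    # MD5 yields 32 hex characters: alphanumeric, fixed length, and safe
    # under ChromaDB's 3-63 character naming constraints.
    hashed = hashlib.md5(fingerprint.encode()).hexdigest()
    return f"{memory_id}-tm-{hashed}"

# make_collection_name("scene-1", "openai-text-embedding-3-small-l2-cpu-false")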
@@ -1,5 +1,5 @@
"""
Context manager that collects and tracks memory agent requests
Context manager that collects and tracks memory agent requests
for profiling and debugging purposes
"""

@@ -11,91 +11,118 @@ import time

from talemate.emit import emit
from talemate.agents.context import active_agent

__all__ = [
    "MemoryRequest",
    "start_memory_request"
    "MemoryRequestState"
    "memory_request"
]
__all__ = ["MemoryRequest"]

log = structlog.get_logger()

DEBUG_MEMORY_REQUESTS = False


class MemoryRequestResult(pydantic.BaseModel):
    doc: str
    distance: float
    meta: dict = pydantic.Field(default_factory=dict)


class MemoryRequestState(pydantic.BaseModel):
    query:str
    query: str
    results: list[MemoryRequestResult] = pydantic.Field(default_factory=list)
    accepted_results: list[MemoryRequestResult] = pydantic.Field(default_factory=list)
    query_params: dict = pydantic.Field(default_factory=dict)
    closest_distance: float | None = None
    furthest_distance: float | None = None
    max_distance: float | None = None

    def add_result(self, doc:str, distance:float, meta:dict):

    def add_result(self, doc: str, distance: float, meta: dict):
        if doc is None:
            return

        self.results.append(MemoryRequestResult(doc=doc, distance=distance, meta=meta))
        self.closest_distance = min(self.closest_distance, distance) if self.closest_distance is not None else distance
        self.furthest_distance = max(self.furthest_distance, distance) if self.furthest_distance is not None else distance

    def accept_result(self, doc:str, distance:float, meta:dict):

        self.closest_distance = (
            min(self.closest_distance, distance)
            if self.closest_distance is not None
            else distance
        )
        self.furthest_distance = (
            max(self.furthest_distance, distance)
            if self.furthest_distance is not None
            else distance
        )

    def accept_result(self, doc: str, distance: float, meta: dict):
        if doc is None:
            return

        self.accepted_results.append(MemoryRequestResult(doc=doc, distance=distance, meta=meta))

        self.accepted_results.append(
            MemoryRequestResult(doc=doc, distance=distance, meta=meta)
        )

    @property
    def closest_text(self):
        return str(self.results[0].doc) if self.results else None


memory_request = contextvars.ContextVar("memory_request", default=None)


class MemoryRequest:

    def __init__(self, query:str, query_params:dict=None):
    def __init__(self, query: str, query_params: dict = None):
        self.query = query
        self.query_params = query_params

    def __enter__(self):
        self.state = MemoryRequestState(query=self.query, query_params=self.query_params)
        self.state = MemoryRequestState(
            query=self.query, query_params=self.query_params
        )
        self.token = memory_request.set(self.state)
        self.time_start = time.time()
        return self.state

    def __exit__(self, *args):

        self.time_end = time.time()

        if DEBUG_MEMORY_REQUESTS:
            max_length = 50
            query = self.state.query[:max_length]+"..." if len(self.state.query) > max_length else self.state.query
            log.debug("MemoryRequest", number_of_results=len(self.state.results), query=query)
            log.debug("MemoryRequest", number_of_accepted_results=len(self.state.accepted_results), query=query)

            query = (
                self.state.query[:max_length] + "..."
                if len(self.state.query) > max_length
                else self.state.query
            )
            log.debug(
                "MemoryRequest", number_of_results=len(self.state.results), query=query
            )
            log.debug(
                "MemoryRequest",
                number_of_accepted_results=len(self.state.accepted_results),
                query=query,
            )

            for result in self.state.results:
                # distance to 2 decimal places
                log.debug("MemoryRequest RESULT", distance=f"{result.distance:.2f}", doc=result.doc[:max_length]+"...")

                log.debug(
                    "MemoryRequest RESULT",
                    distance=f"{result.distance:.2f}",
                    doc=result.doc[:max_length] + "...",
                )

        agent_context = active_agent.get()

        emit("memory_request", data=self.state.model_dump(), meta={
            "agent_stack": agent_context.agent_stack if agent_context else [],
            "agent_stack_uid": agent_context.agent_stack_uid if agent_context else None,
            "duration": self.time_end - self.time_start,
        }, websocket_passthrough=True)

        emit(
            "memory_request",
            data=self.state.model_dump(),
            meta={
                "agent_stack": agent_context.agent_stack if agent_context else [],
                "agent_stack_uid": agent_context.agent_stack_uid
                if agent_context
                else None,
                "duration": self.time_end - self.time_start,
            },
            websocket_passthrough=True,
        )

        memory_request.reset(self.token)
        return False


# decorator that opens a memory request context
async def start_memory_request(query):
@@ -103,5 +130,7 @@ async def start_memory_request(query):
        async def wrapper(*args, **kwargs):
            with MemoryRequest(query):
                return await fn(*args, **kwargs)

        return wrapper
    return decorator

    return decorator
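MemoryRequest is a plain context manager layered over a contextvar, so nested code can see the active request; a hedged, standalone sketch of that pattern with invented names:

import contextvars
import time

current = contextvars.ContextVar("current", default=None)

class TimedRequest:
    def __init__(self, query: str):
        self.query = query

    def __enter__(self):
        self.token = current.set(self)       # expose state to nested calls
        self.start = time.time()
        return self

    def __exit__(self, *exc):
        print(f"{self.query!r} took {time.time() - self.start:.3f}s")
        current.reset(self.token)            # restore the previous state
        return False                         # never swallow exceptions

with TimedRequest("who is in the room?"):
    pass  # anything here can read current.get()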
@@ -1,18 +1,17 @@
__all__ = [
    'EmbeddingsModelLoadError',
    'MemoryAgentError',
    'SetDBError'
]
__all__ = ["EmbeddingsModelLoadError", "MemoryAgentError", "SetDBError"]


class MemoryAgentError(Exception):
    pass


class SetDBError(OSError, MemoryAgentError):

    def __init__(self, details:str):
    def __init__(self, details: str):
        super().__init__(f"Memory Agent - Failed to set up the database: {details}")


class EmbeddingsModelLoadError(ValueError, MemoryAgentError):

    def __init__(self, model_name:str, details:str):
        super().__init__(f"Memory Agent - Failed to load embeddings model {model_name}: {details}")
    def __init__(self, model_name: str, details: str):
        super().__init__(
            f"Memory Agent - Failed to load embeddings model {model_name}: {details}"
        )
@@ -4,7 +4,6 @@ from talemate.agents.base import (
|
||||
AgentAction,
|
||||
AgentActionConfig,
|
||||
)
|
||||
from talemate.emit import emit
|
||||
import talemate.instance as instance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -14,11 +13,10 @@ __all__ = ["MemoryRAGMixin"]
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
class MemoryRAGMixin:
|
||||
|
||||
@classmethod
|
||||
def add_actions(cls, actions: dict[str, AgentAction]):

actions["use_long_term_memory"] = AgentAction(
enabled=True,
container=True,
@@ -44,7 +42,7 @@ class MemoryRAGMixin:
{
"label": "AI compiled question and answers (slow)",
"value": "questions",
}
},
],
),
"number_of_queries": AgentActionConfig(
@@ -65,7 +63,7 @@ class MemoryRAGMixin:
{"label": "Short (256)", "value": "256"},
{"label": "Medium (512)", "value": "512"},
{"label": "Long (1024)", "value": "1024"},
]
],
),
"cache": AgentActionConfig(
type="bool",
@@ -73,16 +71,16 @@ class MemoryRAGMixin:
description="Cache the long term memory for faster retrieval.",
note="This is a cross-agent cache, assuming they use the same options.",
value=True,
)
),
},
)

# config property helpers

@property
def long_term_memory_enabled(self):
return self.actions["use_long_term_memory"].enabled

@property
def long_term_memory_retrieval_method(self):
return self.actions["use_long_term_memory"].config["retrieval_method"].value
@@ -90,60 +88,60 @@ class MemoryRAGMixin:
@property
def long_term_memory_number_of_queries(self):
return self.actions["use_long_term_memory"].config["number_of_queries"].value

@property
def long_term_memory_answer_length(self):
return int(self.actions["use_long_term_memory"].config["answer_length"].value)

@property
def long_term_memory_cache(self):
return self.actions["use_long_term_memory"].config["cache"].value

@property
def long_term_memory_cache_key(self):
"""
Build the key from the various options
"""

parts = [
self.long_term_memory_retrieval_method,
self.long_term_memory_number_of_queries,
self.long_term_memory_answer_length
self.long_term_memory_answer_length,
]

return "-".join(map(str, parts))

def connect(self, scene):
super().connect(scene)

# new scene, reset cache
scene.rag_cache = {}

# methods

async def rag_set_cache(self, content:list[str]):

async def rag_set_cache(self, content: list[str]):
self.scene.rag_cache[self.long_term_memory_cache_key] = {
"content": content,
"fingerprint": self.scene.history[-1].fingerprint if self.scene.history else 0
"fingerprint": self.scene.history[-1].fingerprint
if self.scene.history
else 0,
}

async def rag_get_cache(self) -> list[str] | None:

if not self.long_term_memory_cache:
return None

fingerprint = self.scene.history[-1].fingerprint if self.scene.history else 0
cache = self.scene.rag_cache.get(self.long_term_memory_cache_key)

if cache and cache["fingerprint"] == fingerprint:
return cache["content"]

return None

async def rag_build(
self,
character: "Character | None" = None,
self,
character: "Character | None" = None,
prompt: str = "",
sub_instruction: str = "",
) -> list[str]:
@@ -153,37 +151,41 @@ class MemoryRAGMixin:

if not self.long_term_memory_enabled:
return []

cached = await self.rag_get_cache()

if cached:
log.debug(f"Using cached long term memory", agent=self.agent_type, key=self.long_term_memory_cache_key)
log.debug(
"Using cached long term memory",
agent=self.agent_type,
key=self.long_term_memory_cache_key,
)
return cached

memory_context = ""
retrieval_method = self.long_term_memory_retrieval_method

if not sub_instruction:
if character:
sub_instruction = f"continue the scene as {character.name}"
elif hasattr(self, "rag_build_sub_instruction"):
sub_instruction = await self.rag_build_sub_instruction()

if not sub_instruction:
sub_instruction = "continue the scene"

if retrieval_method != "direct":
world_state = instance.get_agent("world_state")

if not prompt:
prompt = self.scene.context_history(
keep_director=False,
budget=int(self.client.max_token_length * 0.75),
)

if isinstance(prompt, list):
prompt = "\n".join(prompt)

log.debug(
"memory_rag_mixin.build_prompt_default_memory",
direct=False,
@@ -193,20 +195,21 @@ class MemoryRAGMixin:
if retrieval_method == "questions":
memory_context = (
await world_state.analyze_text_and_extract_context(
prompt, sub_instruction,
prompt,
sub_instruction,
include_character_context=True,
response_length=self.long_term_memory_answer_length,
num_queries=self.long_term_memory_number_of_queries
num_queries=self.long_term_memory_number_of_queries,
)
).split("\n")
elif retrieval_method == "queries":
memory_context = (
await world_state.analyze_text_and_extract_context_via_queries(
prompt, sub_instruction,
prompt,
sub_instruction,
include_character_context=True,
response_length=self.long_term_memory_answer_length,
num_queries=self.long_term_memory_number_of_queries

num_queries=self.long_term_memory_number_of_queries,
)
)

@@ -223,4 +226,4 @@ class MemoryRAGMixin:

await self.rag_set_cache(memory_context)

return memory_context
return memory_context
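
The cache helpers above key the RAG cache by the joined option string and invalidate it against the fingerprint of the most recent history message. A minimal, self-contained sketch of the same guard follows; the Entry class and fingerprint values are illustrative stand-ins, not Talemate's actual types:

from dataclasses import dataclass


@dataclass
class Entry:
    text: str
    fingerprint: int  # hypothetical stand-in for the message fingerprint


def cache_key(retrieval_method: str, num_queries: int, answer_length: int) -> str:
    # Same idea as long_term_memory_cache_key: join the options into one key.
    return "-".join(map(str, [retrieval_method, num_queries, answer_length]))


def get_cached(cache: dict, key: str, history: list[Entry]) -> list[str] | None:
    # Only return cached content if the scene has not advanced since
    # the cache entry was written.
    fingerprint = history[-1].fingerprint if history else 0
    entry = cache.get(key)
    if entry and entry["fingerprint"] == fingerprint:
        return entry["content"]
    return None


history = [Entry("intro", 1), Entry("reply", 2)]
cache = {cache_key("questions", 3, 512): {"content": ["fact"], "fingerprint": 2}}
print(get_cached(cache, cache_key("questions", 3, 512), history))  # ['fact']
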
@@ -3,7 +3,6 @@ from __future__ import annotations
import dataclasses
import random
from functools import wraps
from inspect import signature
from typing import TYPE_CHECKING

import structlog
@@ -50,15 +49,18 @@ log = structlog.get_logger("talemate.agents.narrator")
class NarratorAgentEmission(AgentEmission):
generation: list[str] = dataclasses.field(default_factory=list)
response: str = dataclasses.field(default="")
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
default_factory=list
)

talemate.emit.async_signals.register(
"agent.narrator.before_generate",
"agent.narrator.before_generate",
"agent.narrator.inject_instructions",
"agent.narrator.generated",
)

def set_processing(fn):
"""
Custom decorator that emits the agent status as processing while the function
@@ -74,11 +76,15 @@ def set_processing(fn):
if self.content_use_writing_style:
self.set_context_states(writing_style=self.scene.writing_style)

await talemate.emit.async_signals.get("agent.narrator.before_generate").send(emission)
await talemate.emit.async_signals.get("agent.narrator.inject_instructions").send(emission)

await talemate.emit.async_signals.get("agent.narrator.before_generate").send(
emission
)
await talemate.emit.async_signals.get(
"agent.narrator.inject_instructions"
).send(emission)

agent_context.state["dynamic_instructions"] = emission.dynamic_instructions

response = await fn(self, *args, **kwargs)
emission.response = response
await talemate.emit.async_signals.get("agent.narrator.generated").send(emission)
@@ -88,10 +94,7 @@ def set_processing(fn):

@register()
class NarratorAgent(
MemoryRAGMixin,
Agent
):
class NarratorAgent(MemoryRAGMixin, Agent):
"""
Handles narration of the story
"""
@@ -99,7 +102,7 @@ class NarratorAgent(
agent_type = "narrator"
verbose_name = "Narrator"
set_processing = set_processing

websocket_handler = NarratorWebsocketHandler

@classmethod
@@ -117,7 +120,7 @@ class NarratorAgent(
min=32,
max=1024,
step=32,
),
),
"instructions": AgentActionConfig(
type="text",
label="Instructions",
@@ -154,7 +157,7 @@ class NarratorAgent(
description="Use the writing style selected in the scene settings",
value=True,
),
}
},
),
"narrate_time_passage": AgentAction(
enabled=True,
@@ -201,7 +204,7 @@ class NarratorAgent(
},
),
}

MemoryRAGMixin.add_actions(actions)
return actions

@@ -232,28 +235,33 @@ class NarratorAgent(
if self.actions["generation_override"].enabled:
return self.actions["generation_override"].config["length"].value
return 128

@property
def narrate_time_passage_enabled(self) -> bool:
return self.actions["narrate_time_passage"].enabled

@property
def narrate_dialogue_enabled(self) -> bool:
return self.actions["narrate_dialogue"].enabled

@property
def narrate_dialogue_ai_chance(self) -> float:
def narrate_dialogue_ai_chance(self) -> float:
return self.actions["narrate_dialogue"].config["ai_dialog"].value

@property
def narrate_dialogue_player_chance(self) -> float:
return self.actions["narrate_dialogue"].config["player_dialog"].value

@property
def content_use_writing_style(self) -> bool:
return self.actions["content"].config["use_writing_style"].value

def clean_result(self, result:str, ensure_dialog_format:bool=True, force_narrative:bool=True) -> str:

def clean_result(
self,
result: str,
ensure_dialog_format: bool = True,
force_narrative: bool = True,
) -> str:
"""
Cleans the result of a narration
"""
@@ -264,15 +272,14 @@ class NarratorAgent(

cleaned = []
for line in result.split("\n"):

# skip lines that start with a #
if line.startswith("#"):
continue

log.debug("clean_result", line=line)

character_dialogue_detected = False

for character_name in character_names:
if not character_name:
continue
@@ -280,25 +287,24 @@ class NarratorAgent(
character_dialogue_detected = True
elif line.startswith(f"{character_name.upper()}"):
character_dialogue_detected = True

if character_dialogue_detected:
break

if character_dialogue_detected:
break

cleaned.append(line)

result = "\n".join(cleaned)

result = util.strip_partial_sentences(result)
editor = get_agent("editor")

if ensure_dialog_format or force_narrative:
if editor.fix_exposition_enabled and editor.fix_exposition_narrator:
result = editor.fix_exposition_in_text(result)

return result

def connect(self, scene):
@@ -324,16 +330,16 @@ class NarratorAgent(
event.duration, event.human_duration, event.narrative
)
narrator_message = NarratorMessage(
response,
meta = {
response,
meta={
"agent": "narrator",
"function": "narrate_time_passage",
"arguments": {
"duration": event.duration,
"time_passed": event.human_duration,
"narrative_direction": event.narrative,
}
}
},
},
)
emit("narrator", narrator_message)
self.scene.push_history(narrator_message)
@@ -345,7 +351,7 @@ class NarratorAgent(

if not self.narrate_dialogue_enabled:
return

if event.game_loop.had_passive_narration:
log.debug(
"narrate on dialog",
@@ -375,14 +381,14 @@ class NarratorAgent(

response = await self.narrate_after_dialogue(event.actor.character)
narrator_message = NarratorMessage(
response,
response,
meta={
"agent": "narrator",
"function": "narrate_after_dialogue",
"arguments": {
"character": event.actor.character.name,
}
}
},
},
)
emit("narrator", narrator_message)
self.scene.push_history(narrator_message)
@@ -390,7 +396,7 @@ class NarratorAgent(
event.game_loop.had_passive_narration = True

@set_processing
@store_context_state('narrative_direction', visual_narration=True)
@store_context_state("narrative_direction", visual_narration=True)
async def narrate_scene(self, narrative_direction: str | None = None):
"""
Narrate the scene
@@ -413,13 +419,13 @@ class NarratorAgent(
return response

@set_processing
@store_context_state('narrative_direction')
@store_context_state("narrative_direction")
async def progress_story(self, narrative_direction: str | None = None):
"""
Narrate scene progression, moving the plot forward.

Arguments:

- narrative_direction: A string describing the direction the narrative should take. If not provided, will attempt to subtly move the story forward.
"""

@@ -431,10 +437,8 @@ class NarratorAgent(
if narrative_direction is None:
narrative_direction = "Slightly move the current scene forward."

log.debug(
"narrative_direction", narrative_direction=narrative_direction
)

log.debug("narrative_direction", narrative_direction=narrative_direction)

response = await Prompt.request(
"narrator.narrate-progress",
self.client,
@@ -453,13 +457,17 @@ class NarratorAgent(
log.debug("progress_story", response=response)

response = self.clean_result(response.strip())

return response

@set_processing
@store_context_state('query', query_narration=True)
@store_context_state("query", query_narration=True)
async def narrate_query(
self, query: str, at_the_end: bool = False, as_narrative: bool = True, extra_context: str = None
self,
query: str,
at_the_end: bool = False,
as_narrative: bool = True,
extra_context: str = None,
):
"""
Narrate a specific query
@@ -479,20 +487,20 @@ class NarratorAgent(
},
)
response = self.clean_result(
response.strip(),
ensure_dialog_format=False,
force_narrative=as_narrative
response.strip(), ensure_dialog_format=False, force_narrative=as_narrative
)

return response

@set_processing
@store_context_state('character', 'narrative_direction', visual_narration=True)
async def narrate_character(self, character:"Character", narrative_direction: str = None):
@store_context_state("character", "narrative_direction", visual_narration=True)
async def narrate_character(
self, character: "Character", narrative_direction: str = None
):
"""
Narrate a specific character
"""

response = await Prompt.request(
"narrator.narrate-character",
self.client,
@@ -506,12 +514,14 @@ class NarratorAgent(
},
)

response = self.clean_result(response.strip(), ensure_dialog_format=False, force_narrative=True)
response = self.clean_result(
response.strip(), ensure_dialog_format=False, force_narrative=True
)

return response

@set_processing
@store_context_state('narrative_direction', time_narration=True)
@store_context_state("narrative_direction", time_narration=True)
async def narrate_time_passage(
self, duration: str, time_passed: str, narrative_direction: str
):
@@ -528,7 +538,7 @@ class NarratorAgent(
"max_tokens": self.client.max_token_length,
"duration": duration,
"time_passed": time_passed,
"narrative": narrative_direction, # backwards compatibility
"narrative": narrative_direction,  # backwards compatibility
"narrative_direction": narrative_direction,
"extra_instructions": self.extra_instructions,
},
@@ -541,7 +551,7 @@ class NarratorAgent(
return response

@set_processing
@store_context_state('narrative_direction', sensory_narration=True)
@store_context_state("narrative_direction", sensory_narration=True)
async def narrate_after_dialogue(
self,
character: Character,
@@ -572,16 +582,16 @@ class NarratorAgent(
async def narrate_environment(self, narrative_direction: str = None):
"""
Narrate the environment

Wraps narrate_after_dialogue with the player character
as the perspective character
"""

pc = self.scene.get_player_character()
return await self.narrate_after_dialogue(pc, narrative_direction)

@set_processing
@store_context_state('narrative_direction', 'character')
@store_context_state("narrative_direction", "character")
async def narrate_character_entry(
self, character: Character, narrative_direction: str = None
):
@@ -607,11 +617,9 @@ class NarratorAgent(
return response

@set_processing
@store_context_state('narrative_direction', 'character')
@store_context_state("narrative_direction", "character")
async def narrate_character_exit(
self,
character: Character,
narrative_direction: str = None
self, character: Character, narrative_direction: str = None
):
"""
Narrate a character exiting the scene
@@ -679,8 +687,10 @@ class NarratorAgent(
later on
"""
args = parameters.copy()

if args.get("character") and isinstance(args["character"], self.scene.Character):

if args.get("character") and isinstance(
args["character"], self.scene.Character
):
args["character"] = args["character"].name

return {
@@ -701,7 +711,9 @@ class NarratorAgent(
fn = getattr(self, action_name)
narration = await fn(**kwargs)

narrator_message = NarratorMessage(narration, meta=self.action_to_meta(action_name, kwargs))
narrator_message = NarratorMessage(
narration, meta=self.action_to_meta(action_name, kwargs)
)
self.scene.push_history(narrator_message)

if emit_message:
@@ -720,20 +732,22 @@ class NarratorAgent(
kind=kind,
agent_function_name=agent_function_name,
)

# depending on conversation format in the context, stopping strings
# for character names may change format
conversation_agent = get_agent("conversation")

if conversation_agent.conversation_format == "movie_script":
character_names = [f"\n{c.name.upper()}\n" for c in self.scene.get_characters()]
else:
character_names = [
f"\n{c.name.upper()}\n" for c in self.scene.get_characters()
]
else:
character_names = [f"\n{c.name}:" for c in self.scene.get_characters()]

if prompt_param.get("extra_stopping_strings") is None:
prompt_param["extra_stopping_strings"] = []
prompt_param["extra_stopping_strings"] += character_names

self.set_generation_overrides(prompt_param)

def allow_repetition_break(
@@ -748,7 +762,9 @@ class NarratorAgent(
if not self.actions["generation_override"].enabled:
return

prompt_param["max_tokens"] = min(prompt_param.get("max_tokens", 256), self.max_generation_length)
prompt_param["max_tokens"] = min(
prompt_param.get("max_tokens", 256), self.max_generation_length
)

if self.jiggle > 0.0:
nuke_repetition = client_context_attribute("nuke_repetition")
@@ -756,4 +772,4 @@ class NarratorAgent(
# we only apply the agent override if some other mechanism isn't already
# setting the nuke_repetition value
nuke_repetition = self.jiggle
set_client_context_attribute("nuke_repetition", nuke_repetition)
set_client_context_attribute("nuke_repetition", nuke_repetition)
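
The reformatted set_processing decorator above wraps every narration call in a before/inject/generated signal sequence. The sketch below reproduces that shape with a toy async signal registry; the registry functions here are assumptions for illustration only, not the talemate.emit.async_signals API:

import asyncio
from functools import wraps

_signals: dict[str, list] = {}


def register_signal(name: str):
    _signals.setdefault(name, [])


def connect(name: str, handler):
    _signals[name].append(handler)


async def send(name: str, payload):
    for handler in _signals.get(name, []):
        await handler(payload)


register_signal("before_generate")
register_signal("generated")


def emits_signals(fn):
    # Mirror of the decorator pattern: signal before, run, signal after.
    @wraps(fn)
    async def wrapper(*args, **kwargs):
        emission = {"response": None}
        await send("before_generate", emission)
        emission["response"] = await fn(*args, **kwargs)
        await send("generated", emission)
        return emission["response"]

    return wrapper


@emits_signals
async def narrate():
    return "The rain keeps falling."


async def main():
    async def log_it(emission):
        print("signal:", emission)

    connect("before_generate", log_it)
    connect("generated", log_it)
    print(await narrate())


asyncio.run(main())
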
@@ -8,15 +8,16 @@ from talemate.util import iso8601_duration_to_human

log = structlog.get_logger("talemate.game.engine.nodes.agents.narrator")

class GenerateNarrationBase(AgentNode):
"""
Generate a narration message
"""

_agent_name:ClassVar[str] = "narrator"
_action_name:ClassVar[str] = ""
_title:ClassVar[str] = "Generate Narration"

_agent_name: ClassVar[str] = "narrator"
_action_name: ClassVar[str] = ""
_title: ClassVar[str] = "Generate Narration"

class Fields:
narrative_direction = PropertyField(
name="narrative_direction",
@@ -24,151 +25,170 @@ class GenerateNarrationBase(AgentNode):
default="",
type="str",
)

def __init__(self, **kwargs):

if "title" not in kwargs:
kwargs["title"] = self._title

super().__init__(**kwargs)

def setup(self):
self.add_input("state")
self.add_input("narrative_direction", socket_type="str", optional=True)

self.add_output("generated", socket_type="str")
self.add_output("message", socket_type="message_object")

async def prepare_input_values(self) -> dict:
input_values = self.get_input_values()
input_values.pop("state", None)
return input_values

async def run(self, state: GraphState):
input_values = await self.prepare_input_values()
try:
agent_fn = getattr(self.agent, self._action_name)
except AttributeError:
raise InputValueError(self, "_action_name", f"Agent does not have a function named {self._action_name}")

raise InputValueError(
self,
"_action_name",
f"Agent does not have a function named {self._action_name}",
)

narration = await agent_fn(**input_values)

message = NarratorMessage(
message=narration,
meta=self.agent.action_to_meta(self._action_name, input_values),
)

self.set_output_values({
"generated": narration,
"message": message
})

self.set_output_values({"generated": narration, "message": message})

@register("agents/narrator/GenerateProgress")
class GenerateProgressNarration(GenerateNarrationBase):
"""
Generate a progress narration message
"""
_action_name:ClassVar[str] = "progress_story"
_title:ClassVar[str] = "Generate Progress Narration"
"""

_action_name: ClassVar[str] = "progress_story"
_title: ClassVar[str] = "Generate Progress Narration"

@register("agents/narrator/GenerateSceneNarration")
class GenerateSceneNarration(GenerateNarrationBase):
"""
Generate a scene narration message
"""
_action_name:ClassVar[str] = "narrate_scene"
_title:ClassVar[str] = "Generate Scene Narration"
"""

@register("agents/narrator/GenerateAfterDialogNarration")
_action_name: ClassVar[str] = "narrate_scene"
_title: ClassVar[str] = "Generate Scene Narration"

@register("agents/narrator/GenerateAfterDialogNarration")
class GenerateAfterDialogNarration(GenerateNarrationBase):
"""
Generate an after dialog narration message
"""
_action_name:ClassVar[str] = "narrate_after_dialogue"
_title:ClassVar[str] = "Generate After Dialog Narration"

"""

_action_name: ClassVar[str] = "narrate_after_dialogue"
_title: ClassVar[str] = "Generate After Dialog Narration"

def setup(self):
super().setup()
self.add_input("character", socket_type="character")

@register("agents/narrator/GenerateEnvironmentNarration")
class GenerateEnvironmentNarration(GenerateNarrationBase):
"""
Generate an environment narration message
"""
_action_name:ClassVar[str] = "narrate_environment"
_title:ClassVar[str] = "Generate Environment Narration"

"""

_action_name: ClassVar[str] = "narrate_environment"
_title: ClassVar[str] = "Generate Environment Narration"

@register("agents/narrator/GenerateQueryNarration")
class GenerateQueryNarration(GenerateNarrationBase):
"""
Generate a query narration message
"""
_action_name:ClassVar[str] = "narrate_query"
_title:ClassVar[str] = "Generate Query Narration"

"""

_action_name: ClassVar[str] = "narrate_query"
_title: ClassVar[str] = "Generate Query Narration"

def setup(self):
super().setup()
self.add_input("query", socket_type="str")
self.add_input("extra_context", socket_type="str", optional=True)
self.remove_input("narrative_direction")

@register("agents/narrator/GenerateCharacterNarration")
class GenerateCharacterNarration(GenerateNarrationBase):
"""
Generate a character narration message
"""
_action_name:ClassVar[str] = "narrate_character"
_title:ClassVar[str] = "Generate Character Narration"

"""

_action_name: ClassVar[str] = "narrate_character"
_title: ClassVar[str] = "Generate Character Narration"

def setup(self):
super().setup()
self.add_input("character", socket_type="character")

@register("agents/narrator/GenerateTimeNarration")
class GenerateTimeNarration(GenerateNarrationBase):
"""
Generate a time narration message
"""
_action_name:ClassVar[str] = "narrate_time_passage"
_title:ClassVar[str] = "Generate Time Narration"

"""

_action_name: ClassVar[str] = "narrate_time_passage"
_title: ClassVar[str] = "Generate Time Narration"

def setup(self):
super().setup()
self.add_input("duration", socket_type="str")
self.set_property("duration", "P0T1S")

async def prepare_input_values(self) -> dict:
input_values = await super().prepare_input_values()
input_values["time_passed"] = iso8601_duration_to_human(input_values["duration"])
input_values["time_passed"] = iso8601_duration_to_human(
input_values["duration"]
)
return input_values

@register("agents/narrator/GenerateCharacterEntryNarration")
class GenerateCharacterEntryNarration(GenerateNarrationBase):
"""
Generate a character entry narration message
"""
_action_name:ClassVar[str] = "narrate_character_entry"
_title:ClassVar[str] = "Generate Character Entry Narration"

"""

_action_name: ClassVar[str] = "narrate_character_entry"
_title: ClassVar[str] = "Generate Character Entry Narration"

def setup(self):
super().setup()
self.add_input("character", socket_type="character")

@register("agents/narrator/GenerateCharacterExitNarration")
class GenerateCharacterExitNarration(GenerateNarrationBase):
"""
Generate a character exit narration message
"""
_action_name:ClassVar[str] = "narrate_character_exit"
_title:ClassVar[str] = "Generate Character Exit Narration"

"""

_action_name: ClassVar[str] = "narrate_character_exit"
_title: ClassVar[str] = "Generate Character Exit Narration"

def setup(self):
super().setup()
self.add_input("character", socket_type="character")

@register("agents/narrator/UnpackSource")
class UnpackSource(AgentNode):
"""
@@ -176,25 +196,19 @@ class UnpackSource(AgentNode):
into action name and arguments
DEPRECATED
"""

_agent_name:ClassVar[str] = "narrator"

_agent_name: ClassVar[str] = "narrator"

def __init__(self, title="Unpack Source", **kwargs):
super().__init__(title=title, **kwargs)

def setup(self):
self.add_input("source", socket_type="str")
self.add_output("action_name", socket_type="str")
self.add_output("arguments", socket_type="dict")

async def run(self, state: GraphState):
source = self.get_input_value("source")
action_name = ""
arguments = {}

self.set_output_values({
"action_name": action_name,
"arguments": arguments
})

self.set_output_values({"action_name": action_name, "arguments": arguments})
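
Each node class in this file is a thin subclass that points _action_name at an agent coroutine and registers itself under a node path. A self-contained toy version of that registry pattern, with names that are illustrative rather than the actual game engine types:

from typing import ClassVar

NODE_REGISTRY: dict[str, type] = {}


def register(path: str):
    # Store the class under its node path, mirroring @register("agents/...").
    def decorator(cls):
        NODE_REGISTRY[path] = cls
        return cls

    return decorator


class NarrationNodeBase:
    _action_name: ClassVar[str] = ""
    _title: ClassVar[str] = "Generate Narration"


@register("agents/narrator/GenerateProgress")
class ProgressNode(NarrationNodeBase):
    _action_name: ClassVar[str] = "progress_story"
    _title: ClassVar[str] = "Generate Progress Narration"


print(NODE_REGISTRY["agents/narrator/GenerateProgress"]._title)
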
@@ -15,27 +15,31 @@ __all__ = [

log = structlog.get_logger("talemate.server.narrator")

class QueryPayload(pydantic.BaseModel):
query:str
at_the_end:bool=True

query: str
at_the_end: bool = True

class NarrativeDirectionPayload(pydantic.BaseModel):
narrative_direction:str = ""
narrative_direction: str = ""

class CharacterPayload(NarrativeDirectionPayload):
character:str = ""
character: str = ""

class NarratorWebsocketHandler(Plugin):
"""
Handles narrator actions
"""

router = "narrator"

@property
def narrator(self):
return get_agent("narrator")

@set_loading("Progressing the story", cancellable=True, as_async=True)
async def handle_progress(self, data: dict):
"""
@@ -47,7 +51,7 @@ class NarratorWebsocketHandler(Plugin):
narrative_direction=payload.narrative_direction,
emit_message=True,
)

@set_loading("Narrating the environment", cancellable=True, as_async=True)
async def handle_narrate_environment(self, data: dict):
"""
@@ -59,8 +63,7 @@ class NarratorWebsocketHandler(Plugin):
narrative_direction=payload.narrative_direction,
emit_message=True,
)

@set_loading("Working on a query", cancellable=True, as_async=True)
async def handle_query(self, data: dict):
"""
@@ -68,56 +71,55 @@ class NarratorWebsocketHandler(Plugin):
message.
"""
payload = QueryPayload(**data)

narration = await self.narrator.narrate_query(**payload.model_dump())
message: ContextInvestigationMessage = ContextInvestigationMessage(
narration, sub_type="query"
narration, sub_type="query"
)
message.set_source("narrator", "narrate_query", **payload.model_dump())

emit("context_investigation", message=message)
self.scene.push_history(message)

@set_loading("Looking at the scene", cancellable=True, as_async=True)
async def handle_look_at_scene(self, data: dict):
"""
Look at the scene (optionally to a specific direction)

This will result in a context investigation message.
"""
payload = NarrativeDirectionPayload(**data)

narration = await self.narrator.narrate_scene(narrative_direction=payload.narrative_direction)

narration = await self.narrator.narrate_scene(
narrative_direction=payload.narrative_direction
)

message: ContextInvestigationMessage = ContextInvestigationMessage(
narration, sub_type="visual-scene"
)
message.set_source("narrator", "narrate_scene", **payload.model_dump())

emit("context_investigation", message=message)
self.scene.push_history(message)

@set_loading("Looking at a character", cancellable=True, as_async=True)
async def handle_look_at_character(self, data: dict):
"""
Look at a character (optionally to a specific direction)

This will result in a context investigation message.
"""
payload = CharacterPayload(**data)

narration = await self.narrator.narrate_character(
character = self.scene.get_character(payload.character),
character=self.scene.get_character(payload.character),
narrative_direction=payload.narrative_direction,
)

message: ContextInvestigationMessage = ContextInvestigationMessage(
narration, sub_type="visual-character"
)
message.set_source("narrator", "narrate_character", **payload.model_dump())

emit("context_investigation", message=message)
self.scene.push_history(message)
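
The websocket payloads above are small pydantic models, so each handler can validate the incoming dict and pass the fields straight through with model_dump(). A runnable sketch using the QueryPayload fields shown in this diff:

import pydantic


class QueryPayload(pydantic.BaseModel):
    query: str
    at_the_end: bool = True


# Missing fields fall back to their defaults; bad types raise a ValidationError.
data = {"query": "What time is it?"}
payload = QueryPayload(**data)
print(payload.model_dump())  # {'query': 'What time is it?', 'at_the_end': True}
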
@@ -24,4 +24,4 @@ def get_agent_class(name):

def get_agent_types() -> list[str]:
return list(AGENT_CLASSES.keys())
return list(AGENT_CLASSES.keys())
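
get_agent_types is a plain lookup over the module-level registry of agent classes. A minimal sketch of the pattern, with placeholder registry contents:

# Placeholder registry; the real AGENT_CLASSES maps names to agent classes.
AGENT_CLASSES: dict[str, type] = {"narrator": object, "summarizer": object}


def get_agent_types() -> list[str]:
    return list(AGENT_CLASSES.keys())


print(get_agent_types())  # ['narrator', 'summarizer']
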
@@ -7,26 +7,21 @@ import structlog
from typing import TYPE_CHECKING, Literal
import talemate.emit.async_signals
import talemate.util as util
from talemate.emit import emit
from talemate.events import GameLoopEvent
from talemate.prompts import Prompt
from talemate.scene_message import (
DirectorMessage,
TimePassageMessage,
ContextInvestigationMessage,
DirectorMessage,
TimePassageMessage,
ContextInvestigationMessage,
ReinforcementMessage,
)
from talemate.world_state.templates import GenerationOptions
from talemate.instance import get_agent
from talemate.exceptions import GenerationCancelled
import talemate.game.focal as focal
import talemate.emit.async_signals

from talemate.agents.base import (
Agent,
AgentAction,
AgentActionConfig,
set_processing,
Agent,
AgentAction,
AgentActionConfig,
set_processing,
AgentEmission,
AgentTemplateEmission,
RagBuildSubInstructionEmission,
@@ -53,10 +48,12 @@ talemate.emit.async_signals.register(
"agent.summarization.summarize.after",
)

@dataclasses.dataclass
class BuildArchiveEmission(AgentEmission):
generation_options: GenerationOptions | None = None

@dataclasses.dataclass
class SummarizeEmission(AgentTemplateEmission):
text: str = ""
@@ -66,6 +63,7 @@ class SummarizeEmission(AgentTemplateEmission):
summarization_history: list[str] | None = None
summarization_type: Literal["dialogue", "events"] = "dialogue"

@register()
class SummarizeAgent(
MemoryRAGMixin,
@@ -73,7 +71,7 @@ class SummarizeAgent(
ContextInvestigationMixin,
# Needs to be after ContextInvestigationMixin so signals are connected in the right order
SceneAnalyzationMixin,
Agent
Agent,
):
"""
An agent that can be used to summarize text
@@ -82,7 +80,7 @@ class SummarizeAgent(
agent_type = "summarizer"
verbose_name = "Summarizer"
auto_squish = False

@classmethod
def init_actions(cls) -> dict[str, AgentAction]:
actions = {
@@ -148,11 +146,11 @@ class SummarizeAgent(
@property
def archive_threshold(self):
return self.actions["archive"].config["threshold"].value

@property
def archive_method(self):
return self.actions["archive"].config["method"].value

@property
def archive_include_previous(self):
return self.actions["archive"].config["include_previous"].value
@@ -179,7 +177,7 @@ class SummarizeAgent(
return result

# RAG HELPERS

async def rag_build_sub_instruction(self):
# Fire event to get the sub instruction from mixins
emission = RagBuildSubInstructionEmission(
@@ -188,37 +186,42 @@ class SummarizeAgent(
await talemate.emit.async_signals.get(
"agent.summarization.rag_build_sub_instruction"
).send(emission)

return emission.sub_instruction

# SUMMARIZATION HELPERS

async def previous_summaries(self, entry: ArchiveEntry) -> list[str]:

num_previous = self.archive_include_previous

# find entry by .id
entry_index = next((i for i, e in enumerate(self.scene.archived_history) if e["id"] == entry.id), None)
entry_index = next(
(
i
for i, e in enumerate(self.scene.archived_history)
if e["id"] == entry.id
),
None,
)
if entry_index is None:
raise ValueError("Entry not found")
end = entry_index - 1

previous_summaries = []

if entry and num_previous > 0:
if self.layered_history_available:
previous_summaries = self.compile_layered_history(
include_base_layer=True,
base_layer_end_id=entry.id
include_base_layer=True, base_layer_end_id=entry.id
)[-num_previous:]
else:
previous_summaries = [
entry.text for entry in self.scene.archived_history[end-num_previous:end]
entry.text
for entry in self.scene.archived_history[end - num_previous : end]
]

return previous_summaries
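
previous_summaries locates the archive entry by id and then gathers the entries that precede it as extra context. A simplified, runnable sketch of that windowing step, with the data shapes reduced to plain dicts:

def previous_texts(archived: list[dict], entry_id: str, num_previous: int) -> list[str]:
    # Find the target entry by id, then take the entries immediately before it.
    idx = next((i for i, e in enumerate(archived) if e["id"] == entry_id), None)
    if idx is None:
        raise ValueError("Entry not found")
    return [e["text"] for e in archived[max(0, idx - num_previous) : idx]]


archived = [
    {"id": "a", "text": "one"},
    {"id": "b", "text": "two"},
    {"id": "c", "text": "three"},
]
print(previous_texts(archived, "c", 2))  # ['one', 'two']
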
# SUMMARIZE

@set_processing
@@ -226,13 +229,15 @@ class SummarizeAgent(
self, scene, generation_options: GenerationOptions | None = None
):
end = None

emission = BuildArchiveEmission(
agent=self,
generation_options=generation_options,
)

await talemate.emit.async_signals.get("agent.summarization.before_build_archive").send(emission)

await talemate.emit.async_signals.get(
"agent.summarization.before_build_archive"
).send(emission)

if not self.actions["archive"].enabled:
return
@@ -260,7 +265,7 @@ class SummarizeAgent(
extra_context = [
entry["text"] for entry in scene.archived_history[-num_previous:]
]

else:
extra_context = None

@@ -283,7 +288,10 @@ class SummarizeAgent(

# log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...")

if isinstance(dialogue, (DirectorMessage, ContextInvestigationMessage, ReinforcementMessage)):
if isinstance(
dialogue,
(DirectorMessage, ContextInvestigationMessage, ReinforcementMessage),
):
# these messages are not part of the dialogue and should not be summarized
if i == start:
start += 1
@@ -292,7 +300,7 @@ class SummarizeAgent(
if isinstance(dialogue, TimePassageMessage):
log.debug("build_archive", time_passage_message=dialogue)
ts = util.iso8601_add(ts, dialogue.ts)

if i == start:
log.debug(
"build_archive",
@@ -347,23 +355,28 @@ class SummarizeAgent(
if str(line) in terminating_line:
break
adjusted_dialogue.append(line)

# if difference start and end is less than 4, ignore the termination
if len(adjusted_dialogue) > 4:
dialogue_entries = adjusted_dialogue
end = start + len(dialogue_entries) - 1
else:
log.debug("build_archive", message="Ignoring termination", start=start, end=end, adjusted_dialogue=adjusted_dialogue)
log.debug(
"build_archive",
message="Ignoring termination",
start=start,
end=end,
adjusted_dialogue=adjusted_dialogue,
)

if dialogue_entries:

if not extra_context:
# prepend scene intro to dialogue
dialogue_entries.insert(0, scene.intro)

summarized = None
retries = 5

while not summarized and retries > 0:
summarized = await self.summarize(
"\n".join(map(str, dialogue_entries)),
@@ -371,7 +384,7 @@ class SummarizeAgent(
generation_options=generation_options,
)
retries -= 1

if not summarized:
raise IOError("Failed to summarize dialogue", dialogue=dialogue_entries)

@@ -382,13 +395,17 @@ class SummarizeAgent(

# determine the appropriate timestamp for the summarization

await scene.push_archive(ArchiveEntry(text=summarized, start=start, end=end, ts=ts))

scene.ts=ts
await scene.push_archive(
ArchiveEntry(text=summarized, start=start, end=end, ts=ts)
)

scene.ts = ts
scene.emit_status()

await talemate.emit.async_signals.get("agent.summarization.after_build_archive").send(emission)

await talemate.emit.async_signals.get(
"agent.summarization.after_build_archive"
).send(emission)

return True

@set_processing
@@ -408,23 +425,23 @@ class SummarizeAgent(
return response

@set_processing
async def find_natural_scene_termination(self, event_chunks:list[str]) -> list[list[str]]:
async def find_natural_scene_termination(
self, event_chunks: list[str]
) -> list[list[str]]:
"""
Will analyze a list of events and return a list of events that
have been separated at natural scene termination points.
"""

# scan through event chunks and split into paragraphs
rebuilt_chunks = []

for chunk in event_chunks:
paragraphs = [
p.strip() for p in chunk.split("\n") if p.strip()
]
paragraphs = [p.strip() for p in chunk.split("\n") if p.strip()]
rebuilt_chunks.extend(paragraphs)

event_chunks = rebuilt_chunks

response = await Prompt.request(
"summarizer.find-natural-scene-termination-events",
self.client,
@@ -436,38 +453,42 @@ class SummarizeAgent(
},
)
response = response.strip()

items = util.extract_list(response)

# will be a list of

# will be a list of
# ["Progress 1", "Progress 12", "Progress 323", ...]
# convert to a list of just numbers

# convert to a list of just numbers

numbers = []

for item in items:
match = re.match(r"Progress (\d+)", item.strip())
if match:
numbers.append(int(match.group(1)))

# make sure its unique and sorted
numbers = sorted(list(set(numbers)))

result = []
prev_number = 0
for number in numbers:
result.append(event_chunks[prev_number:number+1])
prev_number = number+1

#result = {
result.append(event_chunks[prev_number : number + 1])
prev_number = number + 1

# result = {
# "selected": event_chunks[:number+1],
# "remaining": event_chunks[number+1:]
#}

log.debug("find_natural_scene_termination", response=response, result=result, numbers=numbers)

# }

log.debug(
"find_natural_scene_termination",
response=response,
result=result,
numbers=numbers,
)

return result

@set_processing
async def summarize(
@@ -481,9 +502,9 @@ class SummarizeAgent(
"""
Summarize the given text
"""

response_length = 1024

template_vars = {
"dialogue": text,
"scene": self.scene,
@@ -500,7 +521,7 @@ class SummarizeAgent(
"analyze_chunks": self.layered_history_analyze_chunks,
"response_length": response_length,
}

emission = SummarizeEmission(
agent=self,
text=text,
@@ -511,43 +532,46 @@ class SummarizeAgent(
summarization_history=extra_context or [],
summarization_type="dialogue",
)

await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)

await talemate.emit.async_signals.get(
"agent.summarization.summarize.before"
).send(emission)

template_vars["dynamic_instructions"] = emission.dynamic_instructions

response = await Prompt.request(
f"summarizer.summarize-dialogue",
"summarizer.summarize-dialogue",
self.client,
f"summarize_{response_length}",
vars=template_vars,
dedupe_enabled=False
dedupe_enabled=False,
)

log.debug(
"summarize", dialogue_length=len(text), summarized_length=len(response)
)

try:
summary = response.split("SUMMARY:")[1].strip()
except Exception as e:
log.error("summarize failed", response=response, exc=e)
return ""

# capitalize first letter
try:
summary = summary[0].upper() + summary[1:]
except IndexError:
pass

emission.response = self.clean_result(summary)

await talemate.emit.async_signals.get("agent.summarization.summarize.after").send(emission)

await talemate.emit.async_signals.get(
"agent.summarization.summarize.after"
).send(emission)

summary = emission.response

return self.clean_result(summary)
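
summarize extracts everything after a SUMMARY: marker from the model response and capitalizes the first letter. A standalone sketch of that parsing step, assuming the same marker convention:

def extract_summary(response: str) -> str:
    # Keep only the text after "SUMMARY:"; return "" when the marker is absent.
    try:
        summary = response.split("SUMMARY:")[1].strip()
    except IndexError:
        return ""
    # Capitalize the first letter when the summary is non-empty.
    try:
        summary = summary[0].upper() + summary[1:]
    except IndexError:
        pass
    return summary


print(extract_summary("ANALYSIS: ...\nSUMMARY: the party reaches the gate."))
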
@set_processing
async def summarize_events(
@@ -563,15 +587,14 @@ class SummarizeAgent(
"""
Summarize the given text
"""

if not extra_context:
extra_context = ""

mentioned_characters: list["Character"] = self.scene.parse_characters_from_text(
text + extra_context,
exclude_active=True
text + extra_context, exclude_active=True
)

template_vars = {
"dialogue": text,
"scene": self.scene,
@@ -585,7 +608,7 @@ class SummarizeAgent(
"response_length": response_length,
"mentioned_characters": mentioned_characters,
}

emission = SummarizeEmission(
agent=self,
text=text,
@@ -596,46 +619,62 @@ class SummarizeAgent(
summarization_history=[extra_context] if extra_context else [],
summarization_type="events",
)

await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission)

await talemate.emit.async_signals.get(
"agent.summarization.summarize.before"
).send(emission)

template_vars["dynamic_instructions"] = emission.dynamic_instructions

response = await Prompt.request(
f"summarizer.summarize-events",
"summarizer.summarize-events",
self.client,
f"summarize_{response_length}",
vars=template_vars,
dedupe_enabled=False
dedupe_enabled=False,
)

response = response.strip()
response = response.replace('"', "")

log.debug(
"layered_history_summarize", original_length=len(text), summarized_length=len(response)
"layered_history_summarize",
original_length=len(text),
summarized_length=len(response),
)

# clean up analyzation (remove analyzation text)
if self.layered_history_analyze_chunks:
# remove all lines that begin with "ANALYSIS OF CHUNK \d+:"
response = "\n".join([line for line in response.split("\n") if not line.startswith("ANALYSIS OF CHUNK")])

response = "\n".join(
[
line
for line in response.split("\n")
if not line.startswith("ANALYSIS OF CHUNK")
]
)

# strip all occurrences of "CHUNK \d+: " from the summary
response = re.sub(r"(CHUNK|CHAPTER) \d+:\s+", "", response)

# capitalize first letter
try:
response = response[0].upper() + response[1:]
except IndexError:
pass

emission.response = self.clean_result(response)

await talemate.emit.async_signals.get("agent.summarization.summarize.after").send(emission)

await talemate.emit.async_signals.get(
"agent.summarization.summarize.after"
).send(emission)

response = emission.response

log.debug("summarize_events", original_length=len(text), summarized_length=len(response))

log.debug(
"summarize_events",
original_length=len(text),
summarized_length=len(response),
)

return self.clean_result(response)
@@ -28,13 +28,15 @@ talemate.emit.async_signals.register(
"agent.summarization.scene_analysis.after",
"agent.summarization.scene_analysis.cached",
"agent.summarization.scene_analysis.before_deep_analysis",
"agent.summarization.scene_analysis.after_deep_analysis",
"agent.summarization.scene_analysis.after_deep_analysis",
)

@dataclasses.dataclass
class SceneAnalysisEmission(AgentTemplateEmission):
analysis_type: str | None = None

@dataclasses.dataclass
class SceneAnalysisDeepAnalysisEmission(AgentEmission):
analysis: str
@@ -42,15 +44,16 @@ class SceneAnalysisDeepAnalysisEmission(AgentEmission):
analysis_sub_type: str | None = None
max_content_investigations: int = 1
character: "Character" = None
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list)
dynamic_instructions: list[DynamicInstruction] = dataclasses.field(
default_factory=list
)

class SceneAnalyzationMixin:

"""
Summarizer agent mixin that provides functionality for scene analyzation.
"""

@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["analyze_scene"] = AgentAction(
@@ -71,8 +74,8 @@ class SceneAnalyzationMixin:
choices=[
{"label": "Short (256)", "value": "256"},
{"label": "Medium (512)", "value": "512"},
{"label": "Long (1024)", "value": "1024"}
]
{"label": "Long (1024)", "value": "1024"},
],
),
"for_conversation": AgentActionConfig(
type="bool",
@@ -109,196 +112,204 @@ class SceneAnalyzationMixin:
value=True,
quick_toggle=True,
),
}
},
)

# config property helpers

@property
def analyze_scene(self) -> bool:
return self.actions["analyze_scene"].enabled

@property
def analysis_length(self) -> int:
return int(self.actions["analyze_scene"].config["analysis_length"].value)

@property
def cache_analysis(self) -> bool:
return self.actions["analyze_scene"].config["cache_analysis"].value

@property
def deep_analysis(self) -> bool:
return self.actions["analyze_scene"].config["deep_analysis"].value

@property
def deep_analysis_max_context_investigations(self) -> int:
return self.actions["analyze_scene"].config["deep_analysis_max_context_investigations"].value

return (
self.actions["analyze_scene"]
.config["deep_analysis_max_context_investigations"]
.value
)

@property
def analyze_scene_for_conversation(self) -> bool:
return self.actions["analyze_scene"].config["for_conversation"].value

@property
def analyze_scene_for_narration(self) -> bool:
return self.actions["analyze_scene"].config["for_narration"].value

# signal connect

def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.conversation.inject_instructions").connect(
self.on_inject_instructions
)
talemate.emit.async_signals.get(
"agent.conversation.inject_instructions"
).connect(self.on_inject_instructions)
talemate.emit.async_signals.get("agent.narrator.inject_instructions").connect(
self.on_inject_instructions
)
talemate.emit.async_signals.get("agent.summarization.rag_build_sub_instruction").connect(
self.on_rag_build_sub_instruction
)
talemate.emit.async_signals.get("agent.editor.revision-analysis.before").connect(
self.on_editor_revision_analysis_before
)

talemate.emit.async_signals.get(
"agent.summarization.rag_build_sub_instruction"
).connect(self.on_rag_build_sub_instruction)
talemate.emit.async_signals.get(
"agent.editor.revision-analysis.before"
).connect(self.on_editor_revision_analysis_before)

async def on_inject_instructions(
self,
emission:ConversationAgentEmission | NarratorAgentEmission,
self,
emission: ConversationAgentEmission | NarratorAgentEmission,
):
"""
Injects instructions into the conversation.
"""

if isinstance(emission, ConversationAgentEmission):
emission_type = "conversation"
elif isinstance(emission, NarratorAgentEmission):
emission_type = "narration"
else:
raise ValueError("Invalid emission type.")

if not self.analyze_scene:
return

analyze_scene_for_type = getattr(self, f"analyze_scene_for_{emission_type}")

if not analyze_scene_for_type:
return

analysis = None

# self.set_scene_states and self.get_scene_state to store
# cached analysis in scene states

if self.cache_analysis:
analysis = await self.get_cached_analysis(emission_type)
if analysis:
await talemate.emit.async_signals.get("agent.summarization.scene_analysis.cached").send(
await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.cached"
).send(
SceneAnalysisEmission(
agent=self,
analysis_type=emission_type,
response=analysis,
agent=self,
analysis_type=emission_type,
response=analysis,
template_vars={
"character": emission.character if hasattr(emission, "character") else None,
"character": emission.character
if hasattr(emission, "character")
else None,
},
dynamic_instructions=emission.dynamic_instructions
dynamic_instructions=emission.dynamic_instructions,
)
)

if not analysis and self.analyze_scene:
# analyze the scene for the next action
analysis = await self.analyze_scene_for_next_action(
emission_type,
emission.character if hasattr(emission, "character") else None,
self.analysis_length
self.analysis_length,
)

await self.set_cached_analysis(emission_type, analysis)

if not analysis:
return
emission.dynamic_instructions.append(
DynamicInstruction(
title="Scene Analysis",
content=analysis
)
DynamicInstruction(title="Scene Analysis", content=analysis)
)

async def on_rag_build_sub_instruction(self, emission:"RagBuildSubInstructionEmission"):

async def on_rag_build_sub_instruction(
self, emission: "RagBuildSubInstructionEmission"
):
"""
Injects the sub instruction into the analysis.
"""
sub_instruction = await self.analyze_scene_rag_build_sub_instruction()

if sub_instruction:
emission.sub_instruction = sub_instruction

async def on_editor_revision_analysis_before(self, emission: AgentTemplateEmission):
last_analysis = self.get_scene_state("scene_analysis")
if last_analysis:
emission.dynamic_instructions.append(DynamicInstruction(
title="Scene Analysis",
content=last_analysis
))

emission.dynamic_instructions.append(
DynamicInstruction(title="Scene Analysis", content=last_analysis)
)

# helpers

async def get_cached_analysis(self, typ:str) -> str | None:

async def get_cached_analysis(self, typ: str) -> str | None:
"""
Returns the cached analysis for the given type.
"""

cached_analysis = self.get_scene_state(f"cached_analysis_{typ}")

if not cached_analysis:
return None

fingerprint = self.context_fingerpint()

if cached_analysis.get("fp") == fingerprint:
return cached_analysis["guidance"]

return None

async def set_cached_analysis(self, typ:str, analysis:str):

async def set_cached_analysis(self, typ: str, analysis: str):
"""
Sets the cached analysis for the given type.
"""

fingerprint = self.context_fingerpint()

self.set_scene_states(
**{f"cached_analysis_{typ}": {
"fp": fingerprint,
"guidance": analysis,
}}
**{
f"cached_analysis_{typ}": {
"fp": fingerprint,
"guidance": analysis,
}
}
)
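
get_cached_analysis and set_cached_analysis store one {"fp", "guidance"} record per analysis type and only reuse it while the context fingerprint is unchanged. A reduced sketch of that cache contract, with a plain dict standing in for scene state:

def set_cached(states: dict, typ: str, fingerprint: str, analysis: str) -> None:
    states[f"cached_analysis_{typ}"] = {"fp": fingerprint, "guidance": analysis}


def get_cached(states: dict, typ: str, fingerprint: str) -> str | None:
    cached = states.get(f"cached_analysis_{typ}")
    if cached and cached.get("fp") == fingerprint:
        return cached["guidance"]
    return None


states: dict = {}
set_cached(states, "narration", "fp-1", "The scene is tense.")
print(get_cached(states, "narration", "fp-1"))  # cache hit
print(get_cached(states, "narration", "fp-2"))  # None, context changed
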
|
||||
|
||||
async def analyze_scene_sub_type(self, analysis_type:str) -> str:
|
||||
|
||||
async def analyze_scene_sub_type(self, analysis_type: str) -> str:
|
||||
"""
|
||||
Analyzes the active agent context to figure out the appropriate sub type
|
||||
"""
|
||||
|
||||
|
||||
fn = getattr(self, f"analyze_scene_{analysis_type}_sub_type", None)
|
||||
|
||||
|
||||
if fn:
|
||||
return await fn()
|
||||
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
async def analyze_scene_narration_sub_type(self) -> str:
|
||||
"""
|
||||
Analyzes the active agent context to figure out the appropriate sub type
|
||||
for narration analysis. (progress, query etc.)
|
||||
"""
|
||||
|
||||
|
||||
active_agent_context = active_agent.get()
|
||||
|
||||
|
||||
if not active_agent_context:
|
||||
return "progress"
|
||||
|
||||
|
||||
state = active_agent_context.state
|
||||
|
||||
|
||||
if state.get("narrator__query_narration"):
|
||||
return "query"
|
||||
|
||||
|
||||
if state.get("narrator__sensory_narration"):
|
||||
return "sensory"
|
||||
|
||||
@@ -306,70 +317,85 @@ class SceneAnalyzationMixin:
|
||||
if state.get("narrator__character"):
|
||||
return "visual-character"
|
||||
return "visual"
|
||||
|
||||
|
||||
if state.get("narrator__fn_narrate_character_entry"):
|
||||
return "progress-character-entry"
|
||||
|
||||
|
||||
if state.get("narrator__fn_narrate_character_exit"):
|
||||
return "progress-character-exit"
|
||||
|
||||
|
||||
return "progress"
|
||||
|
||||
|
||||
async def analyze_scene_rag_build_sub_instruction(self):
|
||||
"""
|
||||
Analyzes the active agent context to figure out the appropriate sub type
|
||||
for rag build sub instruction.
|
||||
"""
|
||||
|
||||
|
||||
active_agent_context = active_agent.get()
|
||||
|
||||
|
||||
if not active_agent_context:
|
||||
return ""
|
||||
|
||||
|
||||
state = active_agent_context.state
|
||||
|
||||
|
||||
if state.get("narrator__query_narration"):
|
||||
query = state["narrator__query"]
|
||||
if query.endswith("?"):
|
||||
return "Answer the following question: " + query
|
||||
else:
|
||||
return query
|
||||
|
||||
|
||||
narrative_direction = state.get("narrator__narrative_direction")
|
||||
|
||||
|
||||
if state.get("narrator__sensory_narration") and narrative_direction:
|
||||
return "Collect information that aids in describing the following sensory experience: " + narrative_direction
|
||||
|
||||
return (
|
||||
"Collect information that aids in describing the following sensory experience: "
|
||||
+ narrative_direction
|
||||
)
|
||||
|
||||
if state.get("narrator__visual_narration") and narrative_direction:
|
||||
return "Collect information that aids in describing the following visual experience: " + narrative_direction
|
||||
|
||||
return (
|
||||
"Collect information that aids in describing the following visual experience: "
|
||||
+ narrative_direction
|
||||
)
|
||||
|
||||
if state.get("narrator__fn_narrate_character_entry") and narrative_direction:
|
||||
return "Collect information that aids in describing the following character entry: " + narrative_direction
|
||||
|
||||
return (
|
||||
"Collect information that aids in describing the following character entry: "
|
||||
+ narrative_direction
|
||||
)
|
||||
|
||||
if state.get("narrator__fn_narrate_character_exit") and narrative_direction:
|
||||
return "Collect information that aids in describing the following character exit: " + narrative_direction
|
||||
|
||||
return (
|
||||
"Collect information that aids in describing the following character exit: "
|
||||
+ narrative_direction
|
||||
)
|
||||
|
||||
if state.get("narrator__fn_narrate_progress"):
|
||||
return "Collect information that aids in progressing the story: " + narrative_direction
|
||||
|
||||
return (
|
||||
"Collect information that aids in progressing the story: "
|
||||
+ narrative_direction
|
||||
)
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
# actions
|
||||
|
||||
@set_processing
|
||||
async def analyze_scene_for_next_action(self, typ:str, character:"Character"=None, length:int=1024) -> str:
|
||||
|
||||
async def analyze_scene_for_next_action(
|
||||
self, typ: str, character: "Character" = None, length: int = 1024
|
||||
) -> str:
|
||||
"""
|
||||
Analyzes the current scene progress and gives a suggestion for the next action.
|
||||
taken by the given actor.
|
||||
"""
|
||||
|
||||
|
||||
# deep analysis is only available if the scene has a layered history
|
||||
# and context investigation is enabled
|
||||
deep_analysis = (self.deep_analysis and self.context_investigation_available)
|
||||
deep_analysis = self.deep_analysis and self.context_investigation_available
|
||||
analysis_sub_type = await self.analyze_scene_sub_type(typ)
|
||||
|
||||
|
||||
template_vars = {
|
||||
"max_tokens": self.client.max_token_length,
|
||||
"scene": self.scene,
|
||||
@@ -381,58 +407,60 @@ class SceneAnalyzationMixin:
"analysis_type": typ,
"analysis_sub_type": analysis_sub_type,
}

emission = SceneAnalysisEmission(agent=self, template_vars=template_vars, analysis_type=typ)

await talemate.emit.async_signals.get("agent.summarization.scene_analysis.before").send(
emission

emission = SceneAnalysisEmission(
agent=self, template_vars=template_vars, analysis_type=typ
)

await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.before"
).send(emission)

template_vars["dynamic_instructions"] = emission.dynamic_instructions

response = await Prompt.request(
f"summarizer.analyze-scene-for-next-{typ}",
self.client,
f"investigate_{length}",
vars=template_vars,
)

response = strip_partial_sentences(response)

if not response.strip():
return response

if deep_analysis:

emission = SceneAnalysisDeepAnalysisEmission(
agent=self,
agent=self,
analysis=response,
analysis_type=typ,
analysis_sub_type=analysis_sub_type,
character=character,
max_content_investigations=self.deep_analysis_max_context_investigations
max_content_investigations=self.deep_analysis_max_context_investigations,
)

await talemate.emit.async_signals.get("agent.summarization.scene_analysis.before_deep_analysis").send(
emission
)

await talemate.emit.async_signals.get("agent.summarization.scene_analysis.after_deep_analysis").send(
emission
)

await talemate.emit.async_signals.get("agent.summarization.scene_analysis.after").send(

await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.before_deep_analysis"
).send(emission)

await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.after_deep_analysis"
).send(emission)

await talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.after"
).send(
SceneAnalysisEmission(
agent=self,
template_vars=template_vars,
response=response,
agent=self,
template_vars=template_vars,
response=response,
analysis_type=typ,
dynamic_instructions=emission.dynamic_instructions
dynamic_instructions=emission.dynamic_instructions,
)
)

self.set_context_states(scene_analysis=response)
self.set_scene_states(scene_analysis=response)

return response

return response

@@ -22,14 +22,12 @@ if TYPE_CHECKING:
log = structlog.get_logger()

class ContextInvestigationMixin:

"""
Summarizer agent mixin that provides functionality for context investigation
through the layered history of the scene.
"""

@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):
actions["context_investigation"] = AgentAction(
@@ -52,7 +50,7 @@ class ContextInvestigationMixin:
{"label": "Short (256)", "value": "256"},
{"label": "Medium (512)", "value": "512"},
{"label": "Long (1024)", "value": "1024"},
]
],
),
"update_method": AgentActionConfig(
type="text",
@@ -62,57 +60,56 @@ class ContextInvestigationMixin:
choices=[
{"label": "Replace", "value": "replace"},
{"label": "Smart Merge", "value": "merge"},
]
)
}
],
),
},
)

# config property helpers

@property
def context_investigation_enabled(self):
return self.actions["context_investigation"].enabled

@property
def context_investigation_available(self):
return (
self.context_investigation_enabled and
self.layered_history_available
)

return self.context_investigation_enabled and self.layered_history_available

@property
def context_investigation_answer_length(self) -> int:
return int(self.actions["context_investigation"].config["answer_length"].value)

@property
def context_investigation_update_method(self) -> str:
return self.actions["context_investigation"].config["update_method"].value

# signal connect

def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.conversation.inject_instructions").connect(
self.on_inject_context_investigation
)
talemate.emit.async_signals.get(
"agent.conversation.inject_instructions"
).connect(self.on_inject_context_investigation)
talemate.emit.async_signals.get("agent.narrator.inject_instructions").connect(
self.on_inject_context_investigation
)
talemate.emit.async_signals.get("agent.director.guide.inject_instructions").connect(
self.on_inject_context_investigation
)
talemate.emit.async_signals.get("agent.summarization.scene_analysis.before_deep_analysis").connect(
self.on_summarization_scene_analysis_before_deep_analysis
)

async def on_summarization_scene_analysis_before_deep_analysis(self, emission:SceneAnalysisDeepAnalysisEmission):
talemate.emit.async_signals.get(
"agent.director.guide.inject_instructions"
).connect(self.on_inject_context_investigation)
talemate.emit.async_signals.get(
"agent.summarization.scene_analysis.before_deep_analysis"
).connect(self.on_summarization_scene_analysis_before_deep_analysis)

async def on_summarization_scene_analysis_before_deep_analysis(
self, emission: SceneAnalysisDeepAnalysisEmission
):
"""
Handles context investigation for deep scene analysis.
"""

if not self.context_investigation_enabled:
return

suggested_investigations = await self.suggest_context_investigations(
emission.analysis,
emission.analysis_type,
@@ -120,67 +117,72 @@ class ContextInvestigationMixin:
max_calls=emission.max_content_investigations,
character=emission.character,
)

response = emission.analysis

ci_calls:list[focal.Call] = await self.request_context_investigations(
suggested_investigations,
max_calls=emission.max_content_investigations

ci_calls: list[focal.Call] = await self.request_context_investigations(
suggested_investigations, max_calls=emission.max_content_investigations
)

log.debug("analyze_scene_for_next_action", ci_calls=ci_calls)

# append call queries and answers to the response
ci_text = []
for ci_call in ci_calls:
try:
ci_text.append(f"{ci_call.arguments['query']}\n{ci_call.result}")
except KeyError as e:
log.error("analyze_scene_for_next_action", error="Missing key in call", ci_call=ci_call)

context_investigation="\n\n".join(ci_text if ci_text else [])
except KeyError:
log.error(
"analyze_scene_for_next_action",
error="Missing key in call",
ci_call=ci_call,
)

context_investigation = "\n\n".join(ci_text if ci_text else [])
current_context_investigation = self.get_scene_state("context_investigation")
if current_context_investigation and context_investigation:
if self.context_investigation_update_method == "merge":
context_investigation = await self.update_context_investigation(
current_context_investigation, context_investigation, response
)

self.set_scene_states(context_investigation=context_investigation)
self.set_context_states(context_investigation=context_investigation)

async def on_inject_context_investigation(self, emission:ConversationAgentEmission | NarratorAgentEmission):

async def on_inject_context_investigation(
self, emission: ConversationAgentEmission | NarratorAgentEmission
):
"""
Injects context investigation into the conversation.
"""

if not self.context_investigation_enabled:
return

context_investigation = self.get_scene_state("context_investigation")
log.debug("summarizer.on_inject_context_investigation", context_investigation=context_investigation, emission=emission)
log.debug(
"summarizer.on_inject_context_investigation",
context_investigation=context_investigation,
emission=emission,
)
if context_investigation:
emission.dynamic_instructions.append(
DynamicInstruction(
title="Context Investigation",
content=context_investigation
title="Context Investigation", content=context_investigation
)
)

# methods

@set_processing
async def suggest_context_investigations(
self,
analysis:str,
analysis_type:str,
analysis_sub_type:str="",
max_calls:int=3,
character:"Character"=None,
analysis: str,
analysis_type: str,
analysis_sub_type: str = "",
max_calls: int = 3,
character: "Character" = None,
) -> str:

template_vars = {
"max_tokens": self.client.max_token_length,
"scene": self.scene,
@@ -192,111 +194,119 @@ class ContextInvestigationMixin:
"analysis_type": analysis_type,
"analysis_sub_type": analysis_sub_type,
}

if not analysis_sub_type:
template = f"summarizer.suggest-context-investigations-for-{analysis_type}"
else:
template = f"summarizer.suggest-context-investigations-for-{analysis_type}-{analysis_sub_type}"

log.debug("summarizer.suggest_context_investigations", template=template, template_vars=template_vars)

log.debug(
"summarizer.suggest_context_investigations",
template=template,
template_vars=template_vars,
)

response = await Prompt.request(
template,
self.client,
"investigate_512",
vars=template_vars,
)

return response.strip()

@set_processing
async def investigate_context(
self,
layer:int,
index:int,
query:str,
analysis:str="",
max_calls:int=3,
pad_entries:int=5,
self,
layer: int,
index: int,
query: str,
analysis: str = "",
max_calls: int = 3,
pad_entries: int = 5,
) -> str:
"""
Processes a context investigation.

Arguments:

- layer: The layer to investigate
- index: The index in the layer to investigate
- query: The query to investigate
- analysis: Scene analysis text
- pad_entries: if > 0 will pad the entries with the given number of entries before and after the start and end index
"""

log.debug("summarizer.investigate_context", layer=layer, index=index, query=query)

log.debug(
"summarizer.investigate_context", layer=layer, index=index, query=query
)
entry = self.scene.layered_history[layer][index]

layer_to_investigate = layer - 1

start = max(entry["start"] - pad_entries, 0)
end = entry["end"] + pad_entries + 1

if layer_to_investigate == -1:
entries = self.scene.archived_history[start:end]
else:
entries = self.scene.layered_history[layer_to_investigate][start:end]

async def answer(query:str, instructions:str) -> str:
log.debug("Answering context investigation", query=query, instructions=answer)

async def answer(query: str, instructions: str) -> str:
log.debug(
"Answering context investigation", query=query, instructions=answer
)

world_state = get_agent("world_state")

return await world_state.analyze_history_and_follow_instructions(
entries,
f"{query}\n{instructions}",
analysis=analysis,
response_length=self.context_investigation_answer_length
response_length=self.context_investigation_answer_length,
)

async def investigate_context(chapter_number:str, query:str) -> str:
async def investigate_context(chapter_number: str, query: str) -> str:
# look for \d.\d in the chapter number, extract as layer and index
match = re.match(r"(\d+)\.(\d+)", chapter_number)
if not match:
log.error("summarizer.investigate_context", error="Invalid chapter number", chapter_number=chapter_number)
log.error(
"summarizer.investigate_context",
error="Invalid chapter number",
chapter_number=chapter_number,
)
return ""

layer = int(match.group(1))
index = int(match.group(2))

return await self.investigate_context(layer-1, index-1, query, analysis=analysis, max_calls=max_calls)

return await self.investigate_context(
layer - 1, index - 1, query, analysis=analysis, max_calls=max_calls
)

async def abort():
log.debug("Aborting context investigation")

focal_handler: focal.Focal = focal.Focal(
self.client,
callbacks=[
focal.Callback(
name="investigate_context",
arguments = [
arguments=[
focal.Argument(name="chapter_number", type="str"),
focal.Argument(name="query", type="str")
focal.Argument(name="query", type="str"),
],
fn=investigate_context
fn=investigate_context,
),
focal.Callback(
name="answer",
arguments = [
arguments=[
focal.Argument(name="instructions", type="str"),
focal.Argument(name="query", type="str")
focal.Argument(name="query", type="str"),
],
fn=answer
fn=answer,
),
focal.Callback(
name="abort",
fn=abort
)
focal.Callback(name="abort", fn=abort),
],
max_calls=max_calls,
scene=self.scene,
@@ -307,84 +317,86 @@ class ContextInvestigationMixin:
entries=entries,
analysis=analysis,
)

await focal_handler.request(
"summarizer.investigate-context",
)

log.debug("summarizer.investigate_context", calls=focal_handler.state.calls)

return focal_handler.state.calls

return focal_handler.state.calls

@set_processing
async def request_context_investigations(
self,
analysis:str,
max_calls:int=3,
self,
analysis: str,
max_calls: int = 3,
) -> list[focal.Call]:

"""
Requests context investigations for the given analysis.
"""

async def abort():
log.debug("Aborting context investigations")

async def investigate_context(chapter_number:str, query:str) -> str:

async def investigate_context(chapter_number: str, query: str) -> str:
# look for \d.\d in the chapter number, extract as layer and index
match = re.match(r"(\d+)\.(\d+)", chapter_number)
if not match:
log.error("summarizer.request_context_investigations.investigate_context", error="Invalid chapter number", chapter_number=chapter_number)
log.error(
"summarizer.request_context_investigations.investigate_context",
error="Invalid chapter number",
chapter_number=chapter_number,
)
return ""

layer = int(match.group(1))
index = int(match.group(2))

num_layers = len(self.scene.layered_history)

return await self.investigate_context(num_layers - layer, index-1, query, analysis, max_calls=max_calls)

return await self.investigate_context(
num_layers - layer, index - 1, query, analysis, max_calls=max_calls
)

focal_handler: focal.Focal = focal.Focal(
self.client,
callbacks=[
focal.Callback(
name="investigate_context",
arguments = [
arguments=[
focal.Argument(name="chapter_number", type="str"),
focal.Argument(name="query", type="str")
focal.Argument(name="query", type="str"),
],
fn=investigate_context
fn=investigate_context,
),
focal.Callback(
name="abort",
fn=abort
)
focal.Callback(name="abort", fn=abort),
],
max_calls=max_calls,
scene=self.scene,
text=analysis
text=analysis,
)

await focal_handler.request(
"summarizer.request-context-investigation",
)

log.debug("summarizer.request_context_investigations", calls=focal_handler.state.calls)

return focal.collect_calls(
focal_handler.state.calls,
nested=True,
filter=lambda c: c.name == "answer"

log.debug(
"summarizer.request_context_investigations", calls=focal_handler.state.calls
)

# return focal_handler.state.calls

return focal.collect_calls(
focal_handler.state.calls, nested=True, filter=lambda c: c.name == "answer"
)

# return focal_handler.state.calls

@set_processing
async def update_context_investigation(
self,
current_context_investigation:str,
new_context_investigation:str,
analysis:str,
current_context_investigation: str,
new_context_investigation: str,
analysis: str,
):
response = await Prompt.request(
"summarizer.update-context-investigation",
@@ -398,5 +410,5 @@ class ContextInvestigationMixin:
"max_tokens": self.client.max_token_length,
},
)

return response.strip()

return response.strip()

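Editor's note: the focal callback pattern above recurs in both context-investigation methods. A minimal hedged sketch of how the pieces fit together, inferred purely from the usage shown in this diff (the import path and the surrounding setup are assumptions, not verified against the library):

# Sketch only: focal.Focal / focal.Callback / focal.Argument mirror the diff above.
from talemate import focal  # assumed import path

async def investigate(chapter_number: str, query: str) -> str:
    # stand-in callback; the real one dives into the layered history
    return f"looked into chapter {chapter_number} for: {query}"

async def run(client, scene, analysis: str):
    focal_handler = focal.Focal(
        client,  # an LLM client, as in the mixin
        callbacks=[
            focal.Callback(
                name="investigate_context",
                arguments=[
                    focal.Argument(name="chapter_number", type="str"),
                    focal.Argument(name="query", type="str"),
                ],
                fn=investigate,
            ),
        ],
        max_calls=3,
        scene=scene,
        text=analysis,
    )
    # render the prompt template and let the model issue callback calls
    await focal_handler.request("summarizer.request-context-investigation")
    return focal.collect_calls(focal_handler.state.calls, nested=True)
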
@@ -1,5 +1,5 @@
import structlog
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING
from talemate.agents.base import (
set_processing,
AgentAction,
@@ -24,11 +24,12 @@ talemate.emit.async_signals.register(
"agent.summarization.layered_history.finalize",
)

@dataclasses.dataclass
class LayeredHistoryFinalizeEmission(AgentEmission):
entry: LayeredArchiveEntry | None = None
summarization_history: list[str] = dataclasses.field(default_factory=lambda: [])

@property
def response(self) -> str | None:
return self.entry.text if self.entry else None
@@ -38,22 +39,23 @@ class LayeredHistoryFinalizeEmission(AgentEmission):
if self.entry:
self.entry.text = value

class SummaryLongerThanOriginalError(ValueError):
def __init__(self, original_length:int, summarized_length:int):
def __init__(self, original_length: int, summarized_length: int):
self.original_length = original_length
self.summarized_length = summarized_length
super().__init__(f"Summarized text is longer than original text: {summarized_length} > {original_length}")
super().__init__(
f"Summarized text is longer than original text: {summarized_length} > {original_length}"
)

class LayeredHistoryMixin:

"""
Summarizer agent mixin that provides functionality for maintaining a layered history.
"""

@classmethod
def add_actions(cls, actions: dict[str, AgentAction]):

actions["layered_history"] = AgentAction(
enabled=True,
container=True,
@@ -80,7 +82,7 @@ class LayeredHistoryMixin:
max=5,
step=1,
value=3,
),
),
"max_process_tokens": AgentActionConfig(
type="number",
label="Maximum tokens to process",
@@ -116,69 +118,71 @@ class LayeredHistoryMixin:
{"label": "Medium (512)", "value": "512"},
{"label": "Long (1024)", "value": "1024"},
{"label": "Exhaustive (2048)", "value": "2048"},
]
],
),
},
)

# config property helpers

@property
def layered_history_enabled(self):
return self.actions["layered_history"].enabled

@property
def layered_history_threshold(self):
return self.actions["layered_history"].config["threshold"].value

@property
def layered_history_max_process_tokens(self):
return self.actions["layered_history"].config["max_process_tokens"].value

@property
def layered_history_max_layers(self):
return self.actions["layered_history"].config["max_layers"].value

@property
def layered_history_chunk_size(self) -> int:
return self.actions["layered_history"].config["chunk_size"].value

@property
def layered_history_analyze_chunks(self) -> bool:
return self.actions["layered_history"].config["analyze_chunks"].value

@property
def layered_history_response_length(self) -> int:
return int(self.actions["layered_history"].config["response_length"].value)

@property
def layered_history_available(self):
return self.layered_history_enabled and self.scene.layered_history and self.scene.layered_history[0]

return (
self.layered_history_enabled
and self.scene.layered_history
and self.scene.layered_history[0]
)

# signals

def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.summarization.after_build_archive").connect(
self.on_after_build_archive
)

async def on_after_build_archive(self, emission:"BuildArchiveEmission"):
talemate.emit.async_signals.get(
"agent.summarization.after_build_archive"
).connect(self.on_after_build_archive)

async def on_after_build_archive(self, emission: "BuildArchiveEmission"):
"""
After the archive has been built, we will update the layered history.
"""

if self.layered_history_enabled:
await self.summarize_to_layered_history(
generation_options=emission.generation_options
)

# helpers

async def _lh_split_and_summarize_chunks(
self,
self,
chunks: list[dict],
extra_context: str,
generation_options: GenerationOptions | None = None,
@@ -189,21 +193,29 @@ class LayeredHistoryMixin:
"""
summaries = []
current_chunk = chunks.copy()

while current_chunk:
partial_chunk = []
max_process_tokens = self.layered_history_max_process_tokens

# Build partial chunk up to max_process_tokens
while current_chunk and util.count_tokens("\n\n".join(chunk['text'] for chunk in partial_chunk)) < max_process_tokens:
while (
current_chunk
and util.count_tokens(
"\n\n".join(chunk["text"] for chunk in partial_chunk)
)
< max_process_tokens
):
partial_chunk.append(current_chunk.pop(0))

text_to_summarize = "\n\n".join(chunk['text'] for chunk in partial_chunk)

log.debug("_split_and_summarize_chunks",
tokens_in_chunk=util.count_tokens(text_to_summarize),
max_process_tokens=max_process_tokens)

text_to_summarize = "\n\n".join(chunk["text"] for chunk in partial_chunk)

log.debug(
"_split_and_summarize_chunks",
tokens_in_chunk=util.count_tokens(text_to_summarize),
max_process_tokens=max_process_tokens,
)

summary_text = await self.summarize_events(
text_to_summarize,
extra_context=extra_context + "\n\n".join(summaries),
@@ -213,9 +225,9 @@ class LayeredHistoryMixin:
chunk_size=self.layered_history_chunk_size,
)
summaries.append(summary_text)

return summaries

def _lh_validate_summary_length(self, summaries: list[str], original_length: int):
"""
Validates that the summarized text is not longer than the original.
@@ -224,17 +236,19 @@ class LayeredHistoryMixin:
summarized_length = util.count_tokens(summaries)
if summarized_length > original_length:
raise SummaryLongerThanOriginalError(original_length, summarized_length)

log.debug("_validate_summary_length",
original_length=original_length,
summarized_length=summarized_length)

log.debug(
"_validate_summary_length",
original_length=original_length,
summarized_length=summarized_length,
)

def _lh_build_extra_context(self, layer_index: int) -> str:
"""
Builds extra context from compiled layered history for the given layer.
"""
return "\n\n".join(self.compile_layered_history(layer_index))

def _lh_extract_timestamps(self, chunk: list[dict]) -> tuple[str, str, str]:
"""
Extracts timestamps from a chunk of entries.
@@ -242,144 +256,156 @@ class LayeredHistoryMixin:
"""
if not chunk:
return "PT1S", "PT1S", "PT1S"

ts = chunk[0].get('ts', 'PT1S')
ts_start = chunk[0].get('ts_start', ts)
ts_end = chunk[-1].get('ts_end', chunk[-1].get('ts', ts))

ts = chunk[0].get("ts", "PT1S")
ts_start = chunk[0].get("ts_start", ts)
ts_end = chunk[-1].get("ts_end", chunk[-1].get("ts", ts))

return ts, ts_start, ts_end

async def _lh_finalize_archive_entry(
self,
self,
entry: LayeredArchiveEntry,
summarization_history: list[str] | None = None,
) -> LayeredArchiveEntry:
"""
Finalizes an archive entry by summarizing it and adding it to the layered history.
"""

emission = LayeredHistoryFinalizeEmission(
agent=self,
entry=entry,
summarization_history=summarization_history,
)

await talemate.emit.async_signals.get("agent.summarization.layered_history.finalize").send(emission)

await talemate.emit.async_signals.get(
"agent.summarization.layered_history.finalize"
).send(emission)

return emission.entry

# methods

def compile_layered_history(
self,
for_layer_index:int = None,
as_objects:bool=False,
include_base_layer:bool=False,
max:int = None,
self,
for_layer_index: int = None,
as_objects: bool = False,
include_base_layer: bool = False,
max: int = None,
base_layer_end_id: str | None = None,
) -> list[str]:
"""
Starts at the last layer and compiles the layered history into a single
list of events.

We are iterating backwards, so the last layer will be the most granular.

Each preceeding layer starts from the end of the the next layer.
"""

layered_history = self.scene.layered_history
compiled = []
next_layer_start = None

len_layered_history = len(layered_history)

for i in range(len_layered_history - 1, -1, -1):

if for_layer_index is not None:
if i < for_layer_index:
break

log.debug("compilelayered history", i=i, next_layer_start=next_layer_start)

if not layered_history[i]:
continue

entry_num = 1

for layered_history_entry in layered_history[i][next_layer_start if next_layer_start is not None else 0:]:

for layered_history_entry in layered_history[i][
next_layer_start if next_layer_start is not None else 0 :
]:
if base_layer_end_id:
contained = entry_contained(self.scene, base_layer_end_id, HistoryEntry(
index=0,
layer=i+1,
**layered_history_entry)
contained = entry_contained(
self.scene,
base_layer_end_id,
HistoryEntry(index=0, layer=i + 1, **layered_history_entry),
)
if contained:
log.debug("compile_layered_history", contained=True, base_layer_end_id=base_layer_end_id)
log.debug(
"compile_layered_history",
contained=True,
base_layer_end_id=base_layer_end_id,
)
break

text = f"{layered_history_entry['text']}"

if for_layer_index == i and max is not None and max <= layered_history_entry["end"]:

if (
for_layer_index == i
and max is not None
and max <= layered_history_entry["end"]
):
break

if as_objects:
compiled.append({
"text": text,
"start": layered_history_entry["start"],
"end": layered_history_entry["end"],
"layer": i,
"layer_r": len_layered_history - i,
"ts_start": layered_history_entry["ts_start"],
"index": entry_num,
})
compiled.append(
{
"text": text,
"start": layered_history_entry["start"],
"end": layered_history_entry["end"],
"layer": i,
"layer_r": len_layered_history - i,
"ts_start": layered_history_entry["ts_start"],
"index": entry_num,
}
)
entry_num += 1
else:
compiled.append(text)

next_layer_start = layered_history_entry["end"] + 1

if i == 0 and include_base_layer:
# we are are at layered history layer zero and inclusion of base layer (archived history) is requested
# so we append the base layer to the compiled list, starting from
# index `next_layer_start`

entry_num = 1

for ah in self.scene.archived_history[next_layer_start or 0:]:

for ah in self.scene.archived_history[next_layer_start or 0 :]:
if base_layer_end_id and ah["id"] == base_layer_end_id:
break

text = f"{ah['text']}"
if as_objects:
compiled.append({
"text": text,
"start": ah["start"],
"end": ah["end"],
"layer": -1,
"layer_r": 1,
"ts": ah["ts"],
"index": entry_num,
})
compiled.append(
{
"text": text,
"start": ah["start"],
"end": ah["end"],
"layer": -1,
"layer_r": 1,
"ts": ah["ts"],
"index": entry_num,
}
)
entry_num += 1
else:
compiled.append(text)

return compiled

@set_processing
async def summarize_to_layered_history(self, generation_options: GenerationOptions | None = None):

async def summarize_to_layered_history(
self, generation_options: GenerationOptions | None = None
):
"""
The layered history is a summarized archive with dynamic layers that
will get less and less granular as the scene progresses.

The most granular is still self.scene.archived_history, which holds
all the base layer summarizations.

self.scene.layered_history = [
# first layer after archived_history
[
@@ -391,7 +417,7 @@ class LayeredHistoryMixin:
},
...
],

# second layer
[
{
@@ -402,29 +428,29 @@ class LayeredHistoryMixin:
},
...
],

# additional layers
...
]

The same token threshold as for the base layer will be used for the
layers.

The same summarization function will be used for the layers.

The next level layer will be generated automatically when the token
threshold is reached.
"""

if not self.scene.archived_history:
return # No base layer summaries to work with

token_threshold = self.layered_history_threshold
max_layers = self.layered_history_max_layers

if not hasattr(self.scene, 'layered_history'):
if not hasattr(self.scene, "layered_history"):
self.scene.layered_history = []

layered_history = self.scene.layered_history

async def summarize_layer(source_layer, next_layer_index, start_from) -> bool:
@@ -432,147 +458,192 @@ class LayeredHistoryMixin:
current_tokens = 0
start_index = start_from
noop = True

total_tokens_in_previous_layer = util.count_tokens([
entry['text'] for entry in source_layer
])

total_tokens_in_previous_layer = util.count_tokens(
[entry["text"] for entry in source_layer]
)
estimated_entries = total_tokens_in_previous_layer // token_threshold

for i in range(start_from, len(source_layer)):
entry = source_layer[i]
entry_tokens = util.count_tokens(entry['text'])

log.debug("summarize_to_layered_history", entry=entry["text"][:100]+"...", tokens=entry_tokens, current_layer=next_layer_index-1)

entry_tokens = util.count_tokens(entry["text"])

log.debug(
"summarize_to_layered_history",
entry=entry["text"][:100] + "...",
tokens=entry_tokens,
current_layer=next_layer_index - 1,
)

if current_tokens + entry_tokens > token_threshold:
if current_chunk:

try:
# check if the next layer exists
next_layer = layered_history[next_layer_index]
except IndexError:
# create the next layer
layered_history.append([])
log.debug("summarize_to_layered_history", created_layer=next_layer_index)
log.debug(
"summarize_to_layered_history",
created_layer=next_layer_index,
)
next_layer = layered_history[next_layer_index]

ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)

ts, ts_start, ts_end = self._lh_extract_timestamps(
current_chunk
)

extra_context = self._lh_build_extra_context(next_layer_index)

text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
text_length = util.count_tokens(
"\n\n".join(chunk["text"] for chunk in current_chunk)
)

num_entries_in_layer = len(layered_history[next_layer_index])

emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}", data={"cancellable": True})

emit(
"status",
status="busy",
message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}",
data={"cancellable": True},
)

summaries = await self._lh_split_and_summarize_chunks(
current_chunk,
extra_context,
generation_options=generation_options,
)
noop = False

# validate summary length
self._lh_validate_summary_length(summaries, text_length)

next_layer.append(LayeredArchiveEntry(**{
"start": start_index,
"end": i,
"ts": ts,
"ts_start": ts_start,
"ts_end": ts_end,
"text": "\n\n".join(summaries),
}).model_dump(exclude_none=True))

emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer+1} / {estimated_entries}")

next_layer.append(
LayeredArchiveEntry(
**{
"start": start_index,
"end": i,
"ts": ts,
"ts_start": ts_start,
"ts_end": ts_end,
"text": "\n\n".join(summaries),
}
).model_dump(exclude_none=True)
)

emit(
"status",
status="busy",
message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer + 1} / {estimated_entries}",
)

current_chunk = []
current_tokens = 0
start_index = i

current_chunk.append(entry)
current_tokens += entry_tokens

log.debug("summarize_to_layered_history", tokens=current_tokens, threshold=token_threshold, next_layer=next_layer_index)

log.debug(
"summarize_to_layered_history",
tokens=current_tokens,
threshold=token_threshold,
next_layer=next_layer_index,
)

return not noop

# First layer (always the base layer)
has_been_updated = False

try:

if not layered_history:
layered_history.append([])
log.debug("summarize_to_layered_history", layer="base", new_layer=True)
has_been_updated = await summarize_layer(self.scene.archived_history, 0, 0)
has_been_updated = await summarize_layer(
self.scene.archived_history, 0, 0
)
elif layered_history[0]:
# determine starting point by checking for `end` in the last entry
last_entry = layered_history[0][-1]
end = last_entry["end"]
log.debug("summarize_to_layered_history", layer="base", start=end)
has_been_updated = await summarize_layer(self.scene.archived_history, 0, end)
has_been_updated = await summarize_layer(
self.scene.archived_history, 0, end
)
else:
log.debug("summarize_to_layered_history", layer="base", empty=True)
has_been_updated = await summarize_layer(self.scene.archived_history, 0, 0)

has_been_updated = await summarize_layer(
self.scene.archived_history, 0, 0
)

except SummaryLongerThanOriginalError as exc:
log.error("summarize_to_layered_history", error=exc, layer="base")
emit("status", status="error", message="Layered history update failed.")
return
except GenerationCancelled as e:
log.info("Generation cancelled, stopping rebuild of historical layered history")
emit("status", message="Rebuilding of layered history cancelled", status="info")
log.info(
"Generation cancelled, stopping rebuild of historical layered history"
)
emit(
"status",
message="Rebuilding of layered history cancelled",
status="info",
)
handle_generation_cancelled(e)
return

# process layers
async def update_layers() -> bool:
noop = True
for index in range(0, len(layered_history)):

# check against max layers
if index + 1 > max_layers:
return False

try:
# check if the next layer exists
next_layer = layered_history[index + 1]
except IndexError:
next_layer = None

end = next_layer[-1]["end"] if next_layer else 0

log.debug("summarize_to_layered_history", layer=index, start=end)
summarized = await summarize_layer(layered_history[index], index + 1, end if end else 0)

summarized = await summarize_layer(
layered_history[index], index + 1, end if end else 0
)

if summarized:
noop = False

return not noop

try:
while await update_layers():
has_been_updated = True
if has_been_updated:
emit("status", status="success", message="Layered history updated.")

except SummaryLongerThanOriginalError as exc:
log.error("summarize_to_layered_history", error=exc, layer="subsequent")
emit("status", status="error", message="Layered history update failed.")
return
except GenerationCancelled as e:
log.info("Generation cancelled, stopping rebuild of historical layered history")
emit("status", message="Rebuilding of layered history cancelled", status="info")
log.info(
"Generation cancelled, stopping rebuild of historical layered history"
)
emit(
"status",
message="Rebuilding of layered history cancelled",
status="info",
)
handle_generation_cancelled(e)
return

async def summarize_entries_to_layered_history(
self,
entries: list[dict],
self,
entries: list[dict],
next_layer_index: int,
start_index: int,
end_index: int,
@@ -580,11 +651,11 @@ class LayeredHistoryMixin:
) -> list[LayeredArchiveEntry]:
"""
Summarizes a list of entries into layered history entries.

This method is used for regenerating specific history entries by processing
their source entries. It chunks the entries based on the token threshold and
summarizes each chunk into a LayeredArchiveEntry.

Args:
entries: List of dictionaries containing the text entries to summarize.
Each entry should have at least a 'text' field and optionally
@@ -597,12 +668,12 @@ class LayeredHistoryMixin:
correspond to.
generation_options: Optional generation options to pass to the summarization
process.

Returns:
List of LayeredArchiveEntry objects containing the summarized text along
with timestamp and index information. Currently returns a list with a
single entry, but the structure supports multiple entries if needed.

Notes:
- The method respects the layered_history_threshold for chunking
- Uses helper methods for timestamp extraction, context building, and
@@ -611,63 +682,73 @@ class LayeredHistoryMixin:
- The last entry is always included in the final chunk if it doesn't
exceed the token threshold
"""

token_threshold = self.layered_history_threshold

archive_entries = []
summaries = []
current_chunk = []
current_tokens = 0

ts = "PT1S"
ts_start = "PT1S"
ts_end = "PT1S"

for entry_index, entry in enumerate(entries):
is_last_entry = entry_index == len(entries) - 1
entry_tokens = util.count_tokens(entry['text'])

log.debug("summarize_entries_to_layered_history", entry=entry["text"][:100]+"...", entry_tokens=entry_tokens, current_layer=next_layer_index-1, current_tokens=current_tokens)

entry_tokens = util.count_tokens(entry["text"])

log.debug(
"summarize_entries_to_layered_history",
entry=entry["text"][:100] + "...",
entry_tokens=entry_tokens,
current_layer=next_layer_index - 1,
current_tokens=current_tokens,
)

if current_tokens + entry_tokens > token_threshold or is_last_entry:

if is_last_entry and current_tokens + entry_tokens <= token_threshold:
# if we are here because this is the last entry and adding it to
# the current chunk would not exceed the token threshold, we will
# add it to the current chunk
current_chunk.append(entry)

if current_chunk:
ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk)

extra_context = self._lh_build_extra_context(next_layer_index)

text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk))
text_length = util.count_tokens(
"\n\n".join(chunk["text"] for chunk in current_chunk)
)

summaries = await self._lh_split_and_summarize_chunks(
current_chunk,
extra_context,
generation_options=generation_options,
)

# validate summary length
self._lh_validate_summary_length(summaries, text_length)

archive_entry = LayeredArchiveEntry(**{
"start": start_index,
"end": end_index,
"ts": ts,
"ts_start": ts_start,
"ts_end": ts_end,
"text": "\n\n".join(summaries),
})

archive_entry = await self._lh_finalize_archive_entry(archive_entry, extra_context.split("\n\n"))

archive_entry = LayeredArchiveEntry(
**{
"start": start_index,
"end": end_index,
"ts": ts,
"ts_start": ts_start,
"ts_end": ts_end,
"text": "\n\n".join(summaries),
}
)

archive_entry = await self._lh_finalize_archive_entry(
archive_entry, extra_context.split("\n\n")
)

archive_entries.append(archive_entry)

current_chunk.append(entry)
current_tokens += entry_tokens

return archive_entries

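Editor's note: both layered-history methods above share one chunking rule: accumulate entries until the next one would push the running token count past the threshold, then flush. A standalone hedged sketch of that rule (token counting is stubbed with a whitespace split; the real code uses util.count_tokens):

# Sketch of the token-threshold chunking used by the layered history mixin.
def count_tokens(text: str) -> int:
    # stand-in for util.count_tokens
    return len(text.split())

def chunk_entries(entries: list[dict], token_threshold: int) -> list[list[dict]]:
    chunks: list[list[dict]] = []
    current: list[dict] = []
    current_tokens = 0
    for entry in entries:
        entry_tokens = count_tokens(entry["text"])
        if current_tokens + entry_tokens > token_threshold and current:
            # flush the chunk before it would exceed the threshold
            chunks.append(current)
            current, current_tokens = [], 0
        current.append(entry)
        current_tokens += entry_tokens
    if current:
        # the final partial chunk is always kept
        chunks.append(current)
    return chunks

# usage: two entries of 3 and 2 "tokens" against a threshold of 4 -> two chunks
chunks = chunk_entries([{"text": "a b c"}, {"text": "d e"}], token_threshold=4)
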
@@ -51,11 +51,10 @@ if not TTS:
|
||||
|
||||
|
||||
def parse_chunks(text: str) -> list[str]:
|
||||
|
||||
"""
|
||||
Takes a string and splits it into chunks based on punctuation.
|
||||
|
||||
In case of an error it will return the original text as a single chunk and
|
||||
|
||||
In case of an error it will return the original text as a single chunk and
|
||||
the error will be logged.
|
||||
"""
|
||||
|
||||
@@ -278,7 +277,6 @@ class TTSAgent(Agent):
|
||||
|
||||
@property
|
||||
def agent_details(self):
|
||||
|
||||
details = {
|
||||
"api": AgentDetail(
|
||||
icon="mdi-server-outline",
|
||||
@@ -645,7 +643,6 @@ class TTSAgent(Agent):
|
||||
# OPENAI
|
||||
|
||||
async def _generate_openai(self, text: str, chunk_size: int = 1024):
|
||||
|
||||
client = AsyncOpenAI(api_key=self.openai_api_key)
|
||||
|
||||
model = self.actions["openai"].config["model"].value
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
import structlog
|
||||
|
||||
import talemate.agents.visual.automatic1111
|
||||
import talemate.agents.visual.comfyui
|
||||
import talemate.agents.visual.openai_image
|
||||
import talemate.agents.visual.automatic1111 # noqa: F401
|
||||
import talemate.agents.visual.comfyui # noqa: F401
|
||||
import talemate.agents.visual.openai_image # noqa: F401
|
||||
from talemate.agents.base import (
|
||||
Agent,
|
||||
AgentAction,
|
||||
AgentActionConditional,
|
||||
AgentActionConfig,
|
||||
AgentDetail,
|
||||
set_processing,
|
||||
@@ -25,11 +23,11 @@ from talemate.prompts.base import Prompt
|
||||
from .commands import * # noqa
|
||||
from .context import VIS_TYPES, VisualContext, VisualContextState, visual_context
|
||||
from .handlers import HANDLERS
|
||||
from .schema import RESOLUTION_MAP, RenderSettings
|
||||
from .style import MAJOR_STYLES, STYLE_MAP, Style, combine_styles
|
||||
from .schema import RESOLUTION_MAP
|
||||
from .style import MAJOR_STYLES, STYLE_MAP, Style
|
||||
from .websocket_handler import VisualWebsocketHandler
|
||||
|
||||
import talemate.agents.visual.nodes
|
||||
import talemate.agents.visual.nodes # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"VisualAgent",
|
||||
@@ -114,7 +112,6 @@ class VisualBase(Agent):
|
||||
label="Process in Background",
|
||||
description="Process renders in the background",
|
||||
),
|
||||
|
||||
"_prompts": AgentAction(
|
||||
enabled=True,
|
||||
container=True,
|
||||
@@ -280,7 +277,6 @@ class VisualBase(Agent):
|
||||
await super().ready_check(task)
|
||||
|
||||
async def setup_check(self):
|
||||
|
||||
if not self.actions["automatic_setup"].enabled:
|
||||
return
|
||||
|
||||
@@ -289,7 +285,6 @@ class VisualBase(Agent):
|
||||
await getattr(self.client, f"visual_{backend.lower()}_setup")(self)
|
||||
|
||||
async def apply_config(self, *args, **kwargs):
|
||||
|
||||
try:
|
||||
backend = kwargs["actions"]["_config"]["config"]["backend"]["value"]
|
||||
except (KeyError, TypeError):
|
||||
@@ -312,7 +307,6 @@ class VisualBase(Agent):
|
||||
|
||||
backend_fn = getattr(self, f"{self.backend.lower()}_apply_config", None)
|
||||
if backend_fn:
|
||||
|
||||
if not backend_changed and was_disabled and self.enabled:
|
||||
# If the backend has not changed, but the agent was previously disabled
|
||||
# and is now enabled, we need to trigger the backend apply_config function
|
||||
@@ -336,7 +330,6 @@ class VisualBase(Agent):
|
||||
)
|
||||
|
||||
def prepare_prompt(self, prompt: str, styles: list[Style] = None) -> Style:
|
||||
|
||||
prompt_style = Style()
|
||||
prompt_style.load(prompt)
|
||||
|
||||
@@ -432,9 +425,8 @@ class VisualBase(Agent):
|
||||
async def generate(
|
||||
self, format: str = "portrait", prompt: str = None, automatic: bool = False
|
||||
):
|
||||
context: VisualContextState = visual_context.get()
|
||||
|
||||
context:VisualContextState = visual_context.get()
|
||||
|
||||
log.debug("visual generate", context=context)
|
||||
|
||||
if automatic and not self.allow_automatic_generation:
|
||||
@@ -466,7 +458,7 @@ class VisualBase(Agent):
|
||||
|
||||
thematic_style = self.default_style
|
||||
vis_type_styles = self.vis_type_styles(context.vis_type)
|
||||
prompt:Style = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])
|
||||
prompt: Style = self.prepare_prompt(prompt, [vis_type_styles, thematic_style])
|
||||
|
||||
if context.vis_type == VIS_TYPES.CHARACTER:
|
||||
prompt.keywords.append("character portrait")
|
||||
@@ -488,12 +480,16 @@ class VisualBase(Agent):
|
||||
format = "portrait"
|
||||
|
||||
context.format = format
|
||||
|
||||
|
||||
can_generate_image = self.enabled and self.backend_ready
|
||||
|
||||
|
||||
if not context.prompt_only and not can_generate_image:
|
||||
emit("status", "Visual agent is not ready for image generation, will output prompt instead.", status="warning")
|
||||
|
||||
emit(
|
||||
"status",
|
||||
"Visual agent is not ready for image generation, will output prompt instead.",
|
||||
status="warning",
|
||||
)
|
||||
|
||||
# if prompt_only, we don't need to generate an image
|
||||
# instead we emit a system message with the prompt
|
||||
if context.prompt_only or not can_generate_image:
|
||||
@@ -509,20 +505,24 @@ class VisualBase(Agent):
|
||||
"title": f"Visual Prompt - {context.title}",
|
||||
"display": "tonal",
|
||||
"as_markdown": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
if not can_generate_image:
|
||||
return
|
||||
|
||||
|
||||
# Call the backend specific generate function
|
||||
|
||||
backend = self.backend
|
||||
fn = f"{backend.lower()}_generate"
|
||||
|
||||
log.info(
|
||||
"visual generate", backend=backend, prompt=prompt, format=format, context=context
|
||||
"visual generate",
|
||||
backend=backend,
|
||||
prompt=prompt,
|
||||
format=format,
|
||||
context=context,
|
||||
)
|
||||
|
||||
if not hasattr(self, fn):
|
||||
@@ -538,7 +538,6 @@ class VisualBase(Agent):
|
||||
|
||||
@set_processing
|
||||
async def generate_environment_prompt(self, instructions: str = None):
|
||||
|
||||
with RevisionDisabled():
|
||||
response = await Prompt.request(
|
||||
"visual.generate-environment-prompt",
|
||||
@@ -556,7 +555,6 @@ class VisualBase(Agent):
|
||||
async def generate_character_prompt(
|
||||
self, character_name: str, instructions: str = None
|
||||
):
|
||||
|
||||
character = self.scene.get_character(character_name)
|
||||
|
||||
with RevisionDisabled():
|
||||
|
||||
@@ -1,22 +1,15 @@
import base64
import io

import httpx
import structlog
from PIL import Image

from talemate.agents.base import (
    Agent,
    AgentAction,
    AgentActionConditional,
    AgentActionConfig,
    AgentDetail,
    set_processing,
)

from .handlers import register
from .schema import RenderSettings, Resolution
from .style import STYLE_MAP, Style
from .schema import RenderSettings
from .style import Style

log = structlog.get_logger("talemate.agents.visual.automatic1111")

@@ -58,9 +51,9 @@ SAMPLING_SCHEDULES = [

SAMPLING_SCHEDULES = sorted(SAMPLING_SCHEDULES, key=lambda x: x["label"])


@register(backend_name="automatic1111", label="AUTOMATIC1111")
class Automatic1111Mixin:

    automatic1111_default_render_settings = RenderSettings()

    EXTEND_ACTIONS = {
@@ -139,14 +132,14 @@ class Automatic1111Mixin:
    @property
    def automatic1111_sampling_method(self):
        return self.actions["automatic1111"].config["sampling_method"].value


    @property
    def automatic1111_schedule_type(self):
        return self.actions["automatic1111"].config["schedule_type"].value


    @property
    def automatic1111_cfg(self):
        return self.actions["automatic1111"].config["cfg"].value
        return self.actions["automatic1111"].config["cfg"].value

    async def automatic1111_generate(self, prompt: Style, format: str):
        url = self.api_url
@@ -162,14 +155,16 @@ class Automatic1111Mixin:
            "height": resolution.height,
            "cfg_scale": self.automatic1111_cfg,
            "sampler_name": self.automatic1111_sampling_method,
            "scheduler": self.automatic1111_schedule_type
            "scheduler": self.automatic1111_schedule_type,
        }

        log.info("automatic1111_generate", payload=payload, url=url)

        async with httpx.AsyncClient() as client:
            response = await client.post(
                url=f"{url}/sdapi/v1/txt2img", json=payload, timeout=self.generate_timeout
                url=f"{url}/sdapi/v1/txt2img",
                json=payload,
                timeout=self.generate_timeout,
            )

        r = response.json()

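Nearly every hunk in this commit is the same mechanical rewrite: ruff-format wraps a call that overflows the line-length limit onto one argument per line and appends a trailing comma. A minimal sketch of the pattern, assuming ruff's default 88-column limit (the function and argument names here are illustrative, not from this repo):

    # before: a single long call that exceeds the line limit
    result = build_payload(width=1024, height=768, cfg_scale=7.5, sampler_name="euler", scheduler="karras")

    # after ruff-format: one argument per line, magic trailing comma added
    result = build_payload(
        width=1024,
        height=768,
        cfg_scale=7.5,
        sampler_name="euler",
        scheduler="karras",
    )

The trailing comma also pins the layout: once present, the formatter keeps the exploded form even if the call would later fit on a single line.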
@@ -1,6 +1,5 @@
import asyncio
import base64
import io
import json
import os
import random
@@ -10,13 +9,12 @@ import urllib.parse
import httpx
import pydantic
import structlog
from PIL import Image

from talemate.agents.base import AgentAction, AgentActionConditional, AgentActionConfig

from .handlers import register
from .schema import RenderSettings, Resolution
from .style import STYLE_MAP, Style
from .style import Style

log = structlog.get_logger("talemate.agents.visual.comfyui")

@@ -25,7 +23,6 @@ class Workflow(pydantic.BaseModel):
    nodes: dict

    def set_resolution(self, resolution: Resolution):

        # will collect all latent image nodes
        # if there is multiple will look for the one with the
        # title "Talemate Resolution"
@@ -55,7 +52,6 @@ class Workflow(pydantic.BaseModel):
            latent_image_node["inputs"]["height"] = resolution.height

    def set_prompt(self, prompt: str, negative_prompt: str = None):

        # will collect all CLIPTextEncode nodes

        # if there is multiple will look for the one with the
@@ -79,7 +75,6 @@ class Workflow(pydantic.BaseModel):
        negative_prompt_node = None

        for node_id, node in self.nodes.items():

            if node["class_type"] == "CLIPTextEncode":
                if not positive_prompt_node:
                    positive_prompt_node = node
@@ -102,7 +97,6 @@ class Workflow(pydantic.BaseModel):
            negative_prompt_node["inputs"]["text"] = negative_prompt

    def set_checkpoint(self, checkpoint: str):

        # will collect all CheckpointLoaderSimple nodes
        # if there is multiple will look for the one with the
        # title "Talemate Load Checkpoint"
@@ -139,7 +133,6 @@ class Workflow(pydantic.BaseModel):

@register(backend_name="comfyui", label="ComfyUI")
class ComfyUIMixin:

    comfyui_default_render_settings = RenderSettings()

    EXTEND_ACTIONS = {
@@ -287,7 +280,9 @@ class ComfyUIMixin:
        log.info("comfyui_generate", payload=payload, url=url)

        async with httpx.AsyncClient() as client:
            response = await client.post(url=f"{url}/prompt", json=payload, timeout=self.generate_timeout)
            response = await client.post(
                url=f"{url}/prompt", json=payload, timeout=self.generate_timeout
            )

        log.info("comfyui_generate", response=response.text)


@@ -30,7 +30,7 @@ class VisualContextState(pydantic.BaseModel):
    format: Union[str, None] = None
    replace: bool = False
    prompt_only: bool = False


    @property
    def title(self) -> str:
        if self.vis_type == VIS_TYPES.ENVIRONMENT:

@@ -7,7 +7,6 @@ HANDLERS = {}


class register:

    def __init__(self, backend_name: str, label: str):
        self.backend_name = backend_name
        self.label = label

@@ -6,23 +6,27 @@ from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode

log = structlog.get_logger("talemate.game.engine.nodes.agents.visual")


@register("agents/visual/Settings")
class VisualSettings(AgentSettingsNode):
    """
    Base node to render visual agent settings.
    """
    _agent_name:ClassVar[str] = "visual"


    _agent_name: ClassVar[str] = "visual"

    def __init__(self, title="Visual Settings", **kwargs):
        super().__init__(title=title, **kwargs)



@register("agents/visual/GenerateCharacterPortrait")
class GenerateCharacterPortrait(AgentNode):
    """
    Generates a portrait for a character
    """
    _agent_name:ClassVar[str] = "visual"


    _agent_name: ClassVar[str] = "visual"

    class Fields:
        instructions = PropertyField(
            name="instructions",
@@ -30,33 +34,34 @@ class GenerateCharacterPortrait(AgentNode):
            description="instructions for the portrait",
            default=UNRESOLVED,
        )


    def __init__(self, title="Generate Character Portrait", **kwargs):
        super().__init__(title=title, **kwargs)


    def setup(self):
        self.add_input("state")
        self.add_input("character", socket_type="character")
        self.add_input("instructions", socket_type="str", optional=True)


        self.set_property("instructions", UNRESOLVED)


        self.add_output("state")
        self.add_output("character", socket_type="character")
        self.add_output("portrait", socket_type="image")


    async def run(self, state: GraphState):
        character = self.get_input_value("character")
        instructions = self.normalized_input_value("instructions")


        portrait = await self.agent.generate_character_portrait(
            character_name=character.name,
            instructions=instructions,
        )

        self.set_output_values({
            "state": self.get_input_value("state"),
            "character": character,
            "portrait": portrait
        })


        self.set_output_values(
            {
                "state": self.get_input_value("state"),
                "character": character,
                "portrait": portrait,
            }
        )

@@ -1,31 +1,21 @@
import base64
import io
from urllib.parse import parse_qs, unquote, urlparse

import httpx
import structlog
from openai import AsyncOpenAI
from PIL import Image

from talemate.agents.base import (
    Agent,
    AgentAction,
    AgentActionConditional,
    AgentActionConfig,
    AgentDetail,
    set_processing,
)

from .handlers import register
from .schema import RenderSettings, Resolution
from .style import STYLE_MAP, Style
from .style import Style

log = structlog.get_logger("talemate.agents.visual.openai_image")


@register(backend_name="openai_image", label="OpenAI")
class OpenAIImageMixin:

    openai_image_default_render_settings = RenderSettings()

    EXTEND_ACTIONS = {

@@ -83,7 +83,9 @@ class Style(pydantic.BaseModel):
# Almost taken straight from some of the fooocus style presets, credit goes to the original author

STYLE_MAP["digital_art"] = Style(
    keywords="in the style of a digital artwork, masterpiece, best quality, high detail".split(", "),
    keywords="in the style of a digital artwork, masterpiece, best quality, high detail".split(
        ", "
    ),
    negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
)

@@ -95,12 +97,16 @@ STYLE_MAP["concept_art"] = Style(
)

STYLE_MAP["ink_illustration"] = Style(
    keywords="in the style of ink illustration, painting, masterpiece, best quality".split(", "),
    keywords="in the style of ink illustration, painting, masterpiece, best quality".split(
        ", "
    ),
    negative_keywords="text, watermark, low quality, blurry, photo".split(", "),
)

STYLE_MAP["anime"] = Style(
    keywords="in the style of anime, masterpiece, best quality, illustration".split(", "),
    keywords="in the style of anime, masterpiece, best quality, illustration".split(
        ", "
    ),
    negative_keywords="text, watermark, low quality, blurry, photo, 3d".split(", "),
)


@@ -56,7 +56,6 @@ class VisualWebsocketHandler(Plugin):
        scene = self.scene

        if context and context.character_name:

            character = scene.get_character(context.character_name)

            if not character:

@@ -21,7 +21,13 @@ from talemate.scene_message import (
from talemate.util.response import extract_list


from talemate.agents.base import Agent, AgentAction, AgentActionConfig, AgentEmission, set_processing
from talemate.agents.base import (
    Agent,
    AgentAction,
    AgentActionConfig,
    AgentEmission,
    set_processing,
)
from talemate.agents.registry import register


@@ -57,10 +63,7 @@ class TimePassageEmission(WorldStateAgentEmission):


@register()
class WorldStateAgent(
    CharacterProgressionMixin,
    Agent
):
class WorldStateAgent(CharacterProgressionMixin, Agent):
    """
    An agent that handles world state related tasks.
    """
@@ -91,7 +94,7 @@ class WorldStateAgent(
                    min=1,
                    max=100,
                    step=1,
                )
                ),
            },
        ),
        "update_reinforcements": AgentAction(
@@ -140,7 +143,7 @@ class WorldStateAgent(
    @property
    def experimental(self):
        return True


    @property
    def initial_update(self):
        return self.actions["update_world_state"].config["initial"].value
@@ -148,7 +151,9 @@ class WorldStateAgent(
    def connect(self, scene):
        super().connect(scene)
        talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
        talemate.emit.async_signals.get("scene_loop_init_after").connect(self.on_scene_loop_init_after)
        talemate.emit.async_signals.get("scene_loop_init_after").connect(
            self.on_scene_loop_init_after
        )

    async def advance_time(self, duration: str, narrative: str = None):
        """
@@ -183,13 +188,13 @@ class WorldStateAgent(

        if not self.initial_update:
            return


        if self.get_scene_state("inital_update_done"):
            return

        await self.scene.world_state.request_update()
        self.set_scene_states(inital_update_done=True)


    async def on_game_loop(self, emission: GameLoopEvent):
        """
        Called when a conversation is generated
@@ -339,13 +344,13 @@ class WorldStateAgent(

        context = await memory_agent.multi_query(queries, iterate=3)

        #log.debug(
        # log.debug(
        #     "analyze_text_and_extract_context_via_queries",
        #     goal=goal,
        #     text=text,
        #     queries=queries,
        #     context=context,
        #)
        # )

        return context

@@ -356,7 +361,6 @@ class WorldStateAgent(
        instruction: str,
        short: bool = False,
    ):

        kind = "analyze_freeform_short" if short else "analyze_freeform"

        response = await Prompt.request(
@@ -408,21 +412,20 @@ class WorldStateAgent(
        )

        return response


    @set_processing
    async def analyze_history_and_follow_instructions(
        self,
        entries: list[dict],
        instructions: str,
        analysis: str = "",
        response_length: int = 512
        response_length: int = 512,
    ) -> str:

        """
        Takes a list of archived_history or layered_history entries
        and follows the instructions to generate a response.
        """


        response = await Prompt.request(
            "world_state.analyze-history-and-follow-instructions",
            self.client,
@@ -436,7 +439,7 @@ class WorldStateAgent(
                "response_length": response_length,
            },
        )


        return response.strip()

    @set_processing
@@ -480,7 +483,7 @@ class WorldStateAgent(
        for line in response.split("\n"):
            if not line.strip():
                continue
            if not ":" in line:
            if ":" not in line:
                break
            name, value = line.split(":", 1)
            data[name.strip()] = value.strip()
@@ -542,24 +545,33 @@ class WorldStateAgent(
        """
        Queries a single re-inforcement
        """


        if isinstance(character, self.scene.Character):
            character = character.name


        message = None
        idx, reinforcement = await self.scene.world_state.find_reinforcement(
            question, character
        )

        if not reinforcement:
            log.warning(f"Reinforcement not found", question=question, character=character)
            log.warning(
                "Reinforcement not found", question=question, character=character
            )
            return

        message = ReinforcementMessage(message="")
        message.set_source("world_state", "update_reinforcement", question=question, character=character)

        message.set_source(
            "world_state",
            "update_reinforcement",
            question=question,
            character=character,
        )

        if reset and reinforcement.insert == "sequential":
            self.scene.pop_history(typ="reinforcement", meta_hash=message.meta_hash, all=True)
            self.scene.pop_history(
                typ="reinforcement", meta_hash=message.meta_hash, all=True
            )

        if reinforcement.insert == "sequential":
            kind = "analyze_freeform_medium_short"
@@ -594,7 +606,7 @@ class WorldStateAgent(

        reinforcement.answer = answer
        reinforcement.due = reinforcement.interval


        # remove any recent previous reinforcement message with same question
        # to avoid overloading the near history with reinforcement messages
        if not reset:
@@ -818,4 +830,4 @@ class WorldStateAgent(
                kwargs=kwargs,
                error=e,
            )
            raise
            raise

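The `":" not in line` change above is one of the few lint fixes in this file that is not pure formatting: ruff's E713 rule rewrites `not x in y` to the idiomatic `x not in y`. A minimal sketch of the two spellings (illustrative input only):

    line = "name: value"

    # flagged by ruff (E713); parses as `not (":" in line)`
    if not ":" in line:
        print("no separator")

    # idiomatic form, identical semantics
    if ":" not in line:
        print("no separator")

Both forms compile to the same membership test, so the rewrite changes readability only.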
@@ -1,16 +1,9 @@
from typing import TYPE_CHECKING
import structlog
import re
from talemate.agents.base import (
    set_processing,
    AgentAction,
    AgentActionConfig
)
from talemate.prompts import Prompt
from talemate.agents.base import set_processing, AgentAction, AgentActionConfig
from talemate.instance import get_agent
from talemate.events import GameLoopEvent
from talemate.status import set_loading
from talemate.emit import emit

import talemate.emit.async_signals
import talemate.game.focal as focal
@@ -23,8 +16,8 @@ if TYPE_CHECKING:

log = structlog.get_logger()


class CharacterProgressionMixin:

    """
    World-state manager agent mixin that handles tracking of character progression
    and proposal of updates to character profiles.
@@ -55,13 +48,13 @@ class CharacterProgressionMixin:
                type="bool",
                label="Propose as suggestions",
                description="Propose changes as suggestions that need to be manually accepted.",
                value=True
                value=True,
            ),
            "player_character": AgentActionConfig(
                type="bool",
                label="Player character",
                description="Track the player character's progression.",
                value=True
                value=True,
            ),
            "max_changes": AgentActionConfig(
                type="number",
@@ -70,16 +63,16 @@ class CharacterProgressionMixin:
                value=1,
                min=1,
                max=5,
            )
        }
            ),
        },
    )

    # config property helpers


    # config property helpers

    @property
    def character_progression_enabled(self) -> bool:
        return self.actions["character_progression"].enabled


    @property
    def character_progression_frequency(self) -> int:
        return self.actions["character_progression"].config["frequency"].value
@@ -91,7 +84,7 @@ class CharacterProgressionMixin:
    @property
    def character_progression_max_changes(self) -> int:
        return self.actions["character_progression"].config["max_changes"].value


    @property
    def character_progression_as_suggestions(self) -> bool:
        return self.actions["character_progression"].config["as_suggestions"].value
@@ -100,8 +93,9 @@ class CharacterProgressionMixin:

    def connect(self, scene):
        super().connect(scene)
        talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop_track_character_progression)

        talemate.emit.async_signals.get("game_loop").connect(
            self.on_game_loop_track_character_progression
        )

    async def on_game_loop_track_character_progression(self, emission: GameLoopEvent):
        """
@@ -110,58 +104,69 @@ class CharacterProgressionMixin:

        if not self.enabled or not self.character_progression_enabled:
            return


        log.debug("on_game_loop_track_character_progression", scene=self.scene)

        rounds_since_last_check = self.get_scene_state("rounds_since_last_character_progression_check", 0)


        rounds_since_last_check = self.get_scene_state(
            "rounds_since_last_character_progression_check", 0
        )

        if rounds_since_last_check < self.character_progression_frequency:
            rounds_since_last_check += 1
            self.set_scene_states(rounds_since_last_character_progression_check=rounds_since_last_check)
            self.set_scene_states(
                rounds_since_last_character_progression_check=rounds_since_last_check
            )
            return


        self.set_scene_states(rounds_since_last_character_progression_check=0)


        for character in self.scene.characters:

            if character.is_player and not self.character_progression_player_character:
                continue

            calls:list[focal.Call] = await self.determine_character_development(character)
            await self.character_progression_process_calls(
                character = character,
                calls = calls,
                as_suggestions = self.character_progression_as_suggestions,

            calls: list[focal.Call] = await self.determine_character_development(
                character
            )

            await self.character_progression_process_calls(
                character=character,
                calls=calls,
                as_suggestions=self.character_progression_as_suggestions,
            )

    # methods


    @set_processing
    async def character_progression_process_calls(self, character:"Character", calls:list[focal.Call], as_suggestions:bool=True):

        world_state_manager:WorldStateManager = self.scene.world_state_manager
    async def character_progression_process_calls(
        self,
        character: "Character",
        calls: list[focal.Call],
        as_suggestions: bool = True,
    ):
        world_state_manager: WorldStateManager = self.scene.world_state_manager
        if as_suggestions:
            await world_state_manager.add_suggestion(
                Suggestion(
                    name=character.name,
                    type="character",
                    id=f"character-{character.name}",
                    proposals=calls
                    proposals=calls,
                )
            )
        else:
            for call in calls:
                # changes will be applied directly to the character
                if call.name in ["add_attribute", "update_attribute"]:
                    await character.set_base_attribute(call.arguments["name"], call.result)
                    await character.set_base_attribute(
                        call.arguments["name"], call.result
                    )
                elif call.name == "remove_attribute":
                    await character.set_base_attribute(call.arguments["name"], None)
                elif call.name == "update_description":
                    await character.set_description(call.result)


    @set_processing
    async def determine_character_development(
        self,
        self,
        character: "Character",
        generation_options: world_state_templates.GenerationOptions | None = None,
        instructions: str = None,
@@ -169,95 +174,96 @@ class CharacterProgressionMixin:
        """
        Determine character development
        """

        log.debug("determine_character_development", character=character, generation_options=generation_options)


        log.debug(
            "determine_character_development",
            character=character,
            generation_options=generation_options,
        )

        creator = get_agent("creator")


        @set_loading("Generating character attribute", cancellable=True)
        async def add_attribute(name: str, instructions: str) -> str:
            return await creator.generate_character_attribute(
                character,
                attribute_name = name,
                instructions = instructions,
                generation_options = generation_options,
                attribute_name=name,
                instructions=instructions,
                generation_options=generation_options,
            )


        @set_loading("Generating character attribute", cancellable=True)
        async def update_attribute(name: str, instructions: str) -> str:
            return await creator.generate_character_attribute(
                character,
                attribute_name = name,
                instructions = instructions,
                original = character.base_attributes.get(name),
                generation_options = generation_options,
                attribute_name=name,
                instructions=instructions,
                original=character.base_attributes.get(name),
                generation_options=generation_options,
            )

        async def remove_attribute(name: str, reason:str) -> str:

        async def remove_attribute(name: str, reason: str) -> str:
            return None


        @set_loading("Generating character description", cancellable=True)
        async def update_description(instructions: str) -> str:
            return await creator.generate_character_detail(
                character,
                detail_name = "description",
                instructions = instructions,
                original = character.description,
                detail_name="description",
                instructions=instructions,
                original=character.description,
                length=1024,
                generation_options = generation_options,
                generation_options=generation_options,
            )


        focal_handler = focal.Focal(
            self.client,

            # callbacks
            callbacks = [
            callbacks=[
                focal.Callback(
                    name = "add_attribute",
                    arguments = [
                    name="add_attribute",
                    arguments=[
                        focal.Argument(name="name", type="str"),
                        focal.Argument(name="instructions", type="str"),
                    ],
                    fn = add_attribute
                    fn=add_attribute,
                ),
                focal.Callback(
                    name = "update_attribute",
                    arguments = [
                    name="update_attribute",
                    arguments=[
                        focal.Argument(name="name", type="str"),
                        focal.Argument(name="instructions", type="str"),
                    ],
                    fn = update_attribute
                    fn=update_attribute,
                ),
                focal.Callback(
                    name = "remove_attribute",
                    arguments = [
                    name="remove_attribute",
                    arguments=[
                        focal.Argument(name="name", type="str"),
                        focal.Argument(name="reason", type="str"),
                    ],
                    fn = remove_attribute
                    fn=remove_attribute,
                ),
                focal.Callback(
                    name = "update_description",
                    arguments = [
                    name="update_description",
                    arguments=[
                        focal.Argument(name="instructions", type="str"),
                    ],
                    fn = update_description,
                    multiple=False
                    fn=update_description,
                    multiple=False,
                ),
            ],

            max_calls = self.character_progression_max_changes,

            max_calls=self.character_progression_max_changes,
            # context
            character = character,
            scene = self.scene,
            instructions = instructions,
            character=character,
            scene=self.scene,
            instructions=instructions,
        )


        await focal_handler.request(
            "world_state.determine-character-development",
        )


        log.debug("determine_character_development", calls=focal_handler.state.calls)

        return focal_handler.state.calls

        return focal_handler.state.calls

@@ -1,7 +1,12 @@
import structlog
from typing import ClassVar, TYPE_CHECKING
from talemate.context import active_scene
from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES
from talemate.game.engine.nodes.core import (
    GraphState,
    PropertyField,
    UNRESOLVED,
    TYPE_CHOICES,
)
from talemate.game.engine.nodes.registry import register
from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode
from talemate.world_state import InsertionMode
@@ -10,140 +15,140 @@ from talemate.world_state.manager import WorldStateManager
if TYPE_CHECKING:
    from talemate.tale_mate import Scene, Character

TYPE_CHOICES.extend([
    "world_state/reinforcement",
])
TYPE_CHOICES.extend(
    [
        "world_state/reinforcement",
    ]
)

log = structlog.get_logger("talemate.game.engine.nodes.agents.world_state")


@register("agents/world_state/Settings")
class WorldstateSettings(AgentSettingsNode):
    """
    Base node to render world_state agent settings.
    """
    _agent_name:ClassVar[str] = "world_state"


    _agent_name: ClassVar[str] = "world_state"

    def __init__(self, title="Worldstate Settings", **kwargs):
        super().__init__(title=title, **kwargs)



@register("agents/world_state/ExtractCharacterSheet")
class ExtractCharacterSheet(AgentNode):
    """
    Attempts to extract an attribute based character sheet
    from a given context for a specific character.


    Additionally alteration instructions can be given to
    modify the character's existing sheet.


    Inputs:


    - state: The current state of the graph
    - character_name: The name of the character to extract the sheet for
    - context: The context to extract the sheet from
    - alteration_instructions: Instructions to alter the character's sheet


    Outputs:


    - character_sheet: The extracted character sheet (dict)
    """
    _agent_name:ClassVar[str] = "world_state"


    _agent_name: ClassVar[str] = "world_state"

    def __init__(self, title="Extract Character Sheet", **kwargs):
        super().__init__(title=title, **kwargs)


    def setup(self):
        self.add_input("state")
        self.add_input("character_name", socket_type="str")
        self.add_input("context", socket_type="str")
        self.add_input("alteration_instructions", socket_type="str", optional=True)


        self.add_output("character_sheet", socket_type="dict")


    async def run(self, state: GraphState):

        context = self.require_input("context")
        character_name = self.require_input("character_name")
        alteration_instructions = self.get_input_value("alteration_instructions")


        sheet = await self.agent.extract_character_sheet(
            name = character_name,
            text = context,
            alteration_instructions = alteration_instructions
            name=character_name,
            text=context,
            alteration_instructions=alteration_instructions,
        )

        self.set_output_values({
            "character_sheet": sheet
        })


        self.set_output_values({"character_sheet": sheet})


@register("agents/world_state/StateReinforcement")
class StateReinforcement(AgentNode):
    """
    Reinforces the a tracked state of a character or the world in general.


    Inputs:


    - state: The current state of the graph
    - query_or_detail: The query or instruction to reinforce
    - character: The character to reinforce the state for (optional)


    Properties


    - reset: If the state should be reset


    Outputs:


    - state: graph state
    - message: state reinforcement message
    """

    _agent_name:ClassVar[str] = "world_state"

    class Fields:

    _agent_name: ClassVar[str] = "world_state"

    class Fields:
        query_or_detail = PropertyField(
            name="query_or_detail",
            description="Query or detail to reinforce",
            type="str",
            default=UNRESOLVED
            default=UNRESOLVED,
        )


        instructions = PropertyField(
            name="instructions",
            description="Instructions for the reinforcement",
            type="text",
            default=""
            default="",
        )


        interval = PropertyField(
            name="interval",
            description="Interval for reinforcement",
            type="int",
            default=10,
            min=1,
            step=1
            step=1,
        )


        insert_method = PropertyField(
            name="insert_method",
            description="Method to insert reinforcement",
            type="str",
            default="sequential",
            choices=[
                mode.value for mode in InsertionMode
            ]
            choices=[mode.value for mode in InsertionMode],
        )


        reset = PropertyField(
            name="reset",
            description="If the state should be reset",
            type="bool",
            default=False
            default=False,
        )



    def __init__(self, title="State Reinforcement", **kwargs):
        super().__init__(title=title, **kwargs)


    def setup(self):
        self.add_input("state")
        self.add_input("query_or_detail", socket_type="str")
@@ -155,20 +160,20 @@ class StateReinforcement(AgentNode):
        self.set_property("interval", 10)
        self.set_property("insert_method", "sequential")
        self.set_property("reset", False)


        self.add_output("state")
        self.add_output("message", socket_type="message")
        self.add_output("reinforcement", socket_type="world_state/reinforcement")


    async def run(self, state: GraphState):
        scene:"Scene" = active_scene.get()
        scene: "Scene" = active_scene.get()
        query_or_detail = self.require_input("query_or_detail")
        character = self.normalized_input_value("character")
        reset = self.get_property("reset")
        interval = self.require_number_input("interval")
        instructions = self.normalized_input_value("instructions")
        insert_method = self.get_property("insert_method")


        await scene.world_state.add_reinforcement(
            question=query_or_detail,
            character=character.name if character else None,
@@ -176,142 +181,128 @@ class StateReinforcement(AgentNode):
            interval=interval,
            insert=insert_method,
        )


        message = await self.agent.update_reinforcement(
            question=query_or_detail,
            character=character,
            reset=reset
            question=query_or_detail, character=character, reset=reset
        )

        self.set_output_values({
            "state": state,
            "message": message
        })

        self.set_output_values({"state": state, "message": message})


@register("agents/world_state/DeactivateCharacter")
class DeactivateCharacter(AgentNode):
    """
    Deactivates a character from the world state.


    Inputs:


    - state: The current state of the graph
    - character: The character to deactivate


    Outputs:


    - state: The updated state
    """

    _agent_name:ClassVar[str] = "world_state"


    _agent_name: ClassVar[str] = "world_state"

    def __init__(self, title="Deactivate Character", **kwargs):
        super().__init__(title=title, **kwargs)


    def setup(self):
        self.add_input("state")
        self.add_input("character", socket_type="character")


        self.add_output("state")


    async def run(self, state: GraphState):
        character:"Character" = self.require_input("character")
        scene:"Scene" = active_scene.get()

        manager:WorldStateManager = scene.world_state_manager

        character: "Character" = self.require_input("character")
        scene: "Scene" = active_scene.get()

        manager: WorldStateManager = scene.world_state_manager

        manager.deactivate_character(character.name)

        self.set_output_values({
            "state": state
        })


        self.set_output_values({"state": state})


@register("agents/world_state/EvaluateQuery")
class EvaluateQuery(AgentNode):
    """
    Evaluates a query on the world state.


    Inputs:


    - state: The current state of the graph
    - query: The query to evaluate
    - context: The context to evaluate the query in


    Outputs:


    - state: The current state
    - result: The result of the query
    """

    _agent_name:ClassVar[str] = "world_state"


    _agent_name: ClassVar[str] = "world_state"

    class Fields:

        query = PropertyField(
            name="query",
            description="The query to evaluate",
            type="str",
            default=UNRESOLVED
            default=UNRESOLVED,
        )


    def __init__(self, title="Evaluate Query", **kwargs):
        super().__init__(title=title, **kwargs)


    def setup(self):
        self.add_input("state")
        self.add_input("query", socket_type="str")
        self.add_input("context", socket_type="str")


        self.set_property("query", UNRESOLVED)


        self.add_output("state")
        self.add_output("result", socket_type="bool")


    async def run(self, state: GraphState):
        query = self.require_input("query")
        context = self.require_input("context")

        result = await self.agent.answer_query_true_or_false(
            query=query,
            text=context
        )

        self.set_output_values({
            "state": state,
            "result": result
        })


        result = await self.agent.answer_query_true_or_false(query=query, text=context)

        self.set_output_values({"state": state, "result": result})


@register("agents/world_state/RequestWorldState")
class RequestWorldState(AgentNode):
    """
    Requests the current world state.


    Inputs:


    - state: The current state of the graph


    Outputs:


    - state: The current state
    - world_state: The current world state
    """

    _agent_name:ClassVar[str] = "world_state"


    _agent_name: ClassVar[str] = "world_state"

    def __init__(self, title="Request World State", **kwargs):
        super().__init__(title=title, **kwargs)


    def setup(self):
        self.add_input("state")


        self.add_output("state")
        self.add_output("world_state", socket_type="dict")


    async def run(self, state: GraphState):
        scene:"Scene" = active_scene.get()
        scene: "Scene" = active_scene.get()
        world_state = await scene.world_state.request_update()

        self.set_output_values({
            "state": state,
            "world_state": world_state
        })

        self.set_output_values({"state": state, "world_state": world_state})

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Union
from talemate.instance import get_agent

if TYPE_CHECKING:
    from talemate.tale_mate import Actor, Character, Scene
    from talemate.tale_mate import Character, Scene


__all__ = [
@@ -46,7 +46,6 @@ async def activate_character(scene: "Scene", character: Union[str, "Character"])
    if isinstance(character, str):
        character = scene.get_character(character)


    if character.name not in scene.inactive_characters:
        # already activated
        return False

@@ -1,19 +1,17 @@
import os

import talemate.client.runpod
from talemate.client.anthropic import AnthropicClient
from talemate.client.base import ClientBase, ClientDisabledError
from talemate.client.cohere import CohereClient
from talemate.client.deepseek import DeepSeekClient
from talemate.client.google import GoogleClient
from talemate.client.groq import GroqClient
from talemate.client.koboldcpp import KoboldCppClient
from talemate.client.lmstudio import LMStudioClient
from talemate.client.mistral import MistralAIClient
from talemate.client.ollama import OllamaClient
from talemate.client.openai import OpenAIClient
from talemate.client.openrouter import OpenRouterClient
from talemate.client.openai_compat import OpenAICompatibleClient
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
from talemate.client.tabbyapi import TabbyAPIClient
from talemate.client.textgenwebui import TextGeneratorWebuiClient
import talemate.client.runpod  # noqa: F401
from talemate.client.anthropic import AnthropicClient  # noqa: F401
from talemate.client.base import ClientBase, ClientDisabledError  # noqa: F401
from talemate.client.cohere import CohereClient  # noqa: F401
from talemate.client.deepseek import DeepSeekClient  # noqa: F401
from talemate.client.google import GoogleClient  # noqa: F401
from talemate.client.groq import GroqClient  # noqa: F401
from talemate.client.koboldcpp import KoboldCppClient  # noqa: F401
from talemate.client.lmstudio import LMStudioClient  # noqa: F401
from talemate.client.mistral import MistralAIClient  # noqa: F401
from talemate.client.ollama import OllamaClient  # noqa: F401
from talemate.client.openai import OpenAIClient  # noqa: F401
from talemate.client.openrouter import OpenRouterClient  # noqa: F401
from talemate.client.openai_compat import OpenAICompatibleClient  # noqa: F401
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register  # noqa: F401
from talemate.client.tabbyapi import TabbyAPIClient  # noqa: F401
from talemate.client.textgenwebui import TextGeneratorWebuiClient  # noqa: F401

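These imports exist purely to re-export the client classes from the package root, so ruff would otherwise report each of them as unused (F401). The `# noqa: F401` comment suppresses exactly that rule on exactly that line. A minimal sketch of the pattern, using a hypothetical package name:

    # mypackage/__init__.py
    # re-export so callers can write `from mypackage import Thing`;
    # the noqa marker silences the unused-import warning for this line only
    from mypackage.thing import Thing  # noqa: F401

An alternative ruff also recognizes is declaring `__all__`, or the redundant-alias form `from mypackage.thing import Thing as Thing`, either of which marks the re-export as intentional without a suppression comment.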
@@ -39,6 +39,7 @@ class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
    model: str = "claude-3-5-sonnet-latest"
    double_coercion: str = None


class ClientConfig(EndpointOverride, BaseClientConfig):
    pass

@@ -118,13 +119,13 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):

        self.current_status = status

        data={
        data = {
            "error_action": error_action.model_dump() if error_action else None,
            "double_coercion": self.double_coercion,
            "meta": self.Meta().model_dump(),
            "enabled": self.enabled,
        }
        data.update(self._common_status_data())
        data.update(self._common_status_data())
        emit(
            "client_status",
            message=self.client_type,
@@ -135,7 +136,10 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
        )

    def set_client(self, max_token_length: int = None):
        if not self.anthropic_api_key and not self.endpoint_override_base_url_configured:
        if (
            not self.anthropic_api_key
            and not self.endpoint_override_base_url_configured
        ):
            self.client = AsyncAnthropic(api_key="sk-1111")
            log.error("No anthropic API key set")
            if self.api_key_status:
@@ -175,10 +179,10 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):

        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])


        if "double_coercion" in kwargs:
            self.double_coercion = kwargs["double_coercion"]


        self._reconfigure_common_parameters(**kwargs)
        self._reconfigure_endpoint_override(**kwargs)

@@ -208,17 +212,18 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
        Generates text from the given prompt and parameters.
        """

        if not self.anthropic_api_key and not self.endpoint_override_base_url_configured:
        if (
            not self.anthropic_api_key
            and not self.endpoint_override_base_url_configured
        ):
            raise Exception("No anthropic API key set")


        prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)


        system_message = self.get_system_message(kind)

        messages = [
            {"role": "user", "content": prompt.strip()}
        ]


        messages = [{"role": "user", "content": prompt.strip()}]

        if coercion_prompt:
            messages.append({"role": "assistant", "content": coercion_prompt.strip()})

@@ -228,7 +233,7 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
            parameters=parameters,
            system_message=system_message,
        )


        completion_tokens = 0
        prompt_tokens = 0

@@ -240,22 +245,20 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
                stream=True,
                **parameters,
            )


            response = ""


            async for event in stream:

                if event.type == "content_block_delta":
                    content = event.delta.text
                    response += content
                    self.update_request_tokens(self.count_tokens(content))


                elif event.type == "message_start":
                    prompt_tokens = event.message.usage.input_tokens


                elif event.type == "message_delta":
                    completion_tokens += event.usage.output_tokens


            self._returned_prompt_tokens = prompt_tokens
            self._returned_response_tokens = completion_tokens
@@ -267,5 +270,5 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase):
            self.log.error("generate error", e=e)
            emit("status", message="anthropic API: Permission Denied", status="error")
            return ""
        except Exception as e:
        except Exception:
            raise

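The exception-handler change above (`except Exception as e:` to `except Exception:`) is ruff's F841 fix: the bound name `e` was never used before the bare `raise`, which re-raises the active exception on its own. A minimal sketch of the pattern (function names are illustrative):

    def fetch():
        try:
            return do_request()
        except PermissionError as e:
            log(e)          # the name is used here, so binding it is fine
            return ""
        except Exception:   # exception object unused, so no `as e`
            raise           # re-raises the current exception unchanged

The related `except:` to `except Exception:` rewrite seen later in this commit is E722; a bare `except` would also swallow `KeyboardInterrupt` and `SystemExit`.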
@@ -42,6 +42,7 @@ STOPPING_STRINGS = ["<|im_end|>", "</s>"]
# disable smart quotes until text rendering is refactored
REPLACE_SMART_QUOTES = True


class ClientDisabledError(OSError):
    def __init__(self, client: "ClientBase"):
        self.client = client
@@ -76,6 +77,7 @@ class CommonDefaults(pydantic.BaseModel):
    data_format: Literal["yaml", "json"] | None = None
    preset_group: str | None = None


class Defaults(CommonDefaults, pydantic.BaseModel):
    api_url: str = "http://localhost:5000"
    max_token_length: int = 8192
@@ -88,6 +90,7 @@ class FieldGroup(pydantic.BaseModel):
    description: str
    icon: str = "mdi-cog"


class ExtraField(pydantic.BaseModel):
    name: str
    type: str
@@ -97,6 +100,7 @@ class ExtraField(pydantic.BaseModel):
    group: FieldGroup | None = None
    note: ux_schema.Note | None = None


class ParameterReroute(pydantic.BaseModel):
    talemate_parameter: str
    client_parameter: str
@@ -117,23 +121,23 @@ class RequestInformation(pydantic.BaseModel):
    start_time: float = pydantic.Field(default_factory=time.time)
    end_time: float | None = None
    tokens: int = 0


    @pydantic.computed_field(description="Duration")
    @property
    def duration(self) -> float:
        end_time = self.end_time or time.time()
        return end_time - self.start_time


    @pydantic.computed_field(description="Tokens per second")
    @property
    def rate(self) -> float:
        try:
            end_time = self.end_time or time.time()
            return self.tokens / (end_time - self.start_time)
        except:
        except Exception:
            pass
        return 0


    @pydantic.computed_field(description="Status")
    @property
    def status(self) -> str:
@@ -145,7 +149,7 @@ class RequestInformation(pydantic.BaseModel):
            return "in progress"
        else:
            return "pending"


    @pydantic.computed_field(description="Age")
    @property
    def age(self) -> float:
@@ -159,10 +163,12 @@ class ClientEmbeddingsStatus:
    client: "ClientBase | None" = None
    embedding_name: str | None = None


async_signals.register(
    "client.embeddings_available",
)


class ClientBase:
    api_url: str
    model_name: str
@@ -183,10 +189,10 @@ class ClientBase:
    rate_limit: int | None = None
    client_type = "base"
    request_information: RequestInformation | None = None

    status_request_timeout:int = 2

    system_prompts:SystemPrompts = SystemPrompts()

    status_request_timeout: int = 2

    system_prompts: SystemPrompts = SystemPrompts()
    preset_group: str | None = ""

    rate_limit_counter: CounterRateLimiter = None
@@ -216,7 +222,7 @@ class ClientBase:
        self.max_token_length = (
            int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192
        )


        self.set_client(max_token_length=self.max_token_length)

    def __str__(self):
@@ -252,23 +258,23 @@ class ClientBase:
        "temperature",
        "max_tokens",
    ]


    @property
    def supports_embeddings(self) -> bool:
        return False


    @property
    def embeddings_function(self):
        return None


    @property
    def embeddings_status(self) -> bool:
        return getattr(self, "_embeddings_status", False)


    @property
    def embeddings_model_name(self) -> str | None:
        return getattr(self, "_embeddings_model_name", None)


    @property
    def embeddings_url(self) -> str:
        return None
@@ -277,16 +283,16 @@ class ClientBase:
    def embeddings_identifier(self) -> str:
        return f"client-api/{self.name}/{self.embeddings_model_name}"

    async def destroy(self, config:dict):
    async def destroy(self, config: dict):
        """
        This is called before the client is removed from talemate.instance.clients


        Use this to perform any cleanup that is necessary.


        If a subclass overrides this method, it should call super().destroy(config) in the
        end of the method.
        """


        if self.supports_embeddings:
            self.remove_embeddings(config)

@@ -296,25 +302,28 @@ class ClientBase:

    def set_client(self, **kwargs):
        self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")


    def set_embeddings(self):

        log.debug("setting embeddings", client=self.name, supports_embeddings=self.supports_embeddings, embeddings_status=self.embeddings_status)

        log.debug(
            "setting embeddings",
            client=self.name,
            supports_embeddings=self.supports_embeddings,
            embeddings_status=self.embeddings_status,
        )

        if not self.supports_embeddings or not self.embeddings_status:
            return


        config = load_config(as_model=True)


        key = self.embeddings_identifier


        if key in config.presets.embeddings:
            log.debug("embeddings already set", client=self.name, key=key)
            return config.presets.embeddings[key]



        log.debug("setting embeddings", client=self.name, key=key)


        config.presets.embeddings[key] = EmbeddingFunctionPreset(
            embeddings="client-api",
            client=self.name,
@@ -324,10 +333,10 @@ class ClientBase:
            local=False,
            custom=True,
        )


        save_config(config)

    def remove_embeddings(self, config:dict | None = None):

    def remove_embeddings(self, config: dict | None = None):
        # remove all embeddings for this client
        for key, value in list(config["presets"]["embeddings"].items()):
            if value["client"] == self.name and value["embeddings"] == "client-api":
@@ -338,7 +347,9 @@ class ClientBase:
        if isinstance(system_prompts, dict):
            self.system_prompts = SystemPrompts(**system_prompts)
        elif not isinstance(system_prompts, SystemPrompts):
            raise ValueError("system_prompts must be a `dict` or `SystemPrompts` instance")
            raise ValueError(
                "system_prompts must be a `dict` or `SystemPrompts` instance"
            )
        else:
            self.system_prompts = system_prompts

@@ -376,10 +387,10 @@ class ClientBase:
        """
        if "<|BOT|>" in prompt:
            _, right = prompt.split("<|BOT|>", 1)


            if self.double_coercion:
                right = f"{self.double_coercion}\n\n{right}"


            return prompt, right
        return prompt, None

@@ -407,7 +418,7 @@ class ClientBase:

        if "double_coercion" in kwargs:
            self.double_coercion = kwargs["double_coercion"]


        self._reconfigure_common_parameters(**kwargs)

    def _reconfigure_common_parameters(self, **kwargs):
@@ -415,12 +426,14 @@ class ClientBase:
            self.rate_limit = kwargs["rate_limit"]
            if self.rate_limit:
                if not self.rate_limit_counter:
                    self.rate_limit_counter = CounterRateLimiter(rate_per_minute=self.rate_limit)
                    self.rate_limit_counter = CounterRateLimiter(
                        rate_per_minute=self.rate_limit
                    )
                else:
                    self.rate_limit_counter.update_rate_limit(self.rate_limit)
            else:
                self.rate_limit_counter = None


        if "data_format" in kwargs:
            self.data_format = kwargs["data_format"]

@@ -478,19 +491,21 @@ class ClientBase:

        - kind: the kind of generation
        """

        app_config_system_prompts = client_context_attribute("app_config_system_prompts")


        app_config_system_prompts = client_context_attribute(
            "app_config_system_prompts"
        )

        if app_config_system_prompts:
            self.system_prompts.parent = SystemPrompts(**app_config_system_prompts)


        return self.system_prompts.get(kind, self.decensor_enabled)

    def emit_status(self, processing: bool = None):
        """
        Sets and emits the client status.
        """


        if processing is not None:
            self.processing = processing

@@ -516,7 +531,6 @@ class ClientBase:
|
||||
)
|
||||
|
||||
if not has_prompt_template and self.auto_determine_prompt_template:
|
||||
|
||||
# only attempt to determine the prompt template once per model and
|
||||
# only if the model does not already have a prompt template
|
||||
|
||||
@@ -581,15 +595,17 @@ class ClientBase:
|
||||
"supports_embeddings": self.supports_embeddings,
|
||||
"embeddings_status": self.embeddings_status,
|
||||
"embeddings_model_name": self.embeddings_model_name,
|
||||
"request_information": self.request_information.model_dump() if self.request_information else None,
|
||||
"request_information": self.request_information.model_dump()
|
||||
if self.request_information
|
||||
else None,
|
||||
}
|
||||
|
||||
|
||||
extra_fields = getattr(self.Meta(), "extra_fields", {})
|
||||
for field_name in extra_fields.keys():
|
||||
common_data[field_name] = getattr(self, field_name, None)
|
||||
|
||||
|
||||
return common_data
|
||||
|
||||
|
||||
def populate_extra_fields(self, data: dict):
|
||||
"""
|
||||
Updates data with the extra fields from the client's Meta
|
||||
@@ -665,7 +681,7 @@ class ClientBase:
|
||||
agent_context.agent.inject_prompt_paramters(
|
||||
parameters, kind, agent_context.action
|
||||
)
|
||||
|
||||
|
||||
if client_context_attribute(
|
||||
"nuke_repetition"
|
||||
) > 0.0 and self.jiggle_enabled_for(kind):
|
||||
@@ -707,7 +723,6 @@ class ClientBase:
|
||||
del parameters[key]
|
||||
|
||||
def finalize(self, parameters: dict, prompt: str):
|
||||
|
||||
prompt = util.replace_special_tokens(prompt)
|
||||
|
||||
for finalizer in self.finalizers:
|
||||
@@ -740,22 +755,20 @@ class ClientBase:
|
||||
"status", message="Error during generation (check logs)", status="error"
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
def _generate_task(self, prompt: str, parameters: dict, kind: str):
|
||||
"""
|
||||
Creates an asyncio task to generate text from the given prompt and parameters.
|
||||
"""
|
||||
|
||||
|
||||
return asyncio.create_task(self.generate(prompt, parameters, kind))
|
||||
|
||||
|
||||
def _poll_interrupt(self):
|
||||
"""
|
||||
Creatates a task that continiously checks active_scene.cancel_requested and
|
||||
will complete the task if it is requested.
|
||||
"""
|
||||
|
||||
|
||||
async def poll():
|
||||
while True:
|
||||
scene = active_scene.get()
|
||||
@@ -763,58 +776,58 @@ class ClientBase:
|
||||
break
|
||||
await asyncio.sleep(0.3)
|
||||
return GenerationCancelled("Generation cancelled")
|
||||
|
||||
|
||||
return asyncio.create_task(poll())
|
||||
|
||||
async def _cancelable_generate(self, prompt: str, parameters: dict, kind: str) -> str | GenerationCancelled:
|
||||
|
||||
|
||||
async def _cancelable_generate(
|
||||
self, prompt: str, parameters: dict, kind: str
|
||||
) -> str | GenerationCancelled:
|
||||
"""
|
||||
Queues the generation task and the poll task to be run concurrently.
|
||||
|
||||
|
||||
If the poll task completes before the generation task, the generation task
|
||||
will be cancelled.
|
||||
|
||||
|
||||
If the generation task completes before the poll task, the poll task will
|
||||
be cancelled.
|
||||
"""
|
||||
|
||||
|
||||
task_poll = self._poll_interrupt()
|
||||
task_generate = self._generate_task(prompt, parameters, kind)
|
||||
|
||||
|
||||
done, pending = await asyncio.wait(
|
||||
[task_poll, task_generate],
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
[task_poll, task_generate], return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
|
||||
# cancel the remaining task
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
|
||||
# return the result of the completed task
|
||||
return done.pop().result()

    async def abort_generation(self):
        """
        This function can be overridden to trigger an abort at the other
        side of the client.

        When a generation is cancelled here, this can be used to cancel the
        generation at the other side of the client as well.
        """
        pass

    def new_request(self):
        """
        Creates a new request information object.
        """
        self.request_information = RequestInformation()

    def end_request(self):
        """
        Ends the request information object.
        """
        self.request_information.end_time = time.time()

    def update_request_tokens(self, tokens: int, replace: bool = False):
        """
        Updates the request information object with the number of tokens received.
@@ -824,7 +837,7 @@ class ClientBase:
            self.request_information.tokens = tokens
        else:
            self.request_information.tokens += tokens

    async def send_prompt(
        self,
        prompt: str,
@@ -837,13 +850,13 @@ class ClientBase:
        :param prompt: The text prompt to send.
        :return: The AI's response text.
        """

        try:
            return await self._send_prompt(prompt, kind, finalize, retries)
        except GenerationCancelled:
            await self.abort_generation()
            raise

    async def _send_prompt(
        self,
        prompt: str,
@@ -856,47 +869,49 @@ class ClientBase:
        :param prompt: The text prompt to send.
        :return: The AI's response text.
        """

        try:
            if self.rate_limit_counter:
                aborted:bool = False
                aborted: bool = False
                while not self.rate_limit_counter.increment():
                    log.warn("Rate limit exceeded", client=self.name)
                    emit(
                        "rate_limited",
                        message="Rate limit exceeded",
                        status="error",
                        message="Rate limit exceeded",
                        status="error",
                        websocket_passthrough=True,
                        data={
                            "client": self.name,
                            "rate_limit": self.rate_limit,
                            "reset_time": self.rate_limit_counter.reset_time(),
                        }
                        },
                    )

                    scene = active_scene.get()
                    if not scene or not scene.active or scene.cancel_requested:
                        log.info("Rate limit exceeded, generation cancelled", client=self.name)
                        log.info(
                            "Rate limit exceeded, generation cancelled",
                            client=self.name,
                        )
                        aborted = True
                        break

                    await asyncio.sleep(1)

                emit(
                    "rate_limit_reset",
                    message="Rate limit reset",
                    status="info",
                    websocket_passthrough=True,
                    data={"client": self.name}
                    data={"client": self.name},
                )

            if aborted:
                raise GenerationCancelled("Generation cancelled")
        except GenerationCancelled:
            raise
        except Exception as e:
        except Exception:
            log.error("Error during rate limit check", e=traceback.format_exc())

        if not active_scene.get():
            log.error("SceneInactiveError", scene=active_scene.get())
@@ -940,25 +955,25 @@ class ClientBase:
                parameters=prompt_param,
            )
            prompt_sent = self.repetition_adjustment(finalized_prompt)

            self.new_request()

            response = await self._cancelable_generate(prompt_sent, prompt_param, kind)

            self.end_request()

            if isinstance(response, GenerationCancelled):
                # generation was cancelled
                raise response

            #response = await self.generate(prompt_sent, prompt_param, kind)

            # response = await self.generate(prompt_sent, prompt_param, kind)

            response, finalized_prompt = await self.auto_break_repetition(
                finalized_prompt, prompt_param, response, kind, retries
            )

            if REPLACE_SMART_QUOTES:
                response = response.replace('“', '"').replace('”', '"')
                response = response.replace("“", '"').replace("”", '"')

            time_end = time.time()

@@ -992,9 +1007,9 @@ class ClientBase:
            )

            return response
        except GenerationCancelled as e:
        except GenerationCancelled:
            raise
        except Exception as e:
        except Exception:
            self.log.error("send_prompt error", e=traceback.format_exc())
            emit(
                "status", message="Error during generation (check logs)", status="error"
@@ -1004,7 +1019,7 @@ class ClientBase:
            self.emit_status(processing=False)
            self._returned_prompt_tokens = None
            self._returned_response_tokens = None

            if self.rate_limit_counter:
                self.rate_limit_counter.increment()
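
Together, `send_prompt`, `_send_prompt` and `abort_generation` form a cooperative cancellation chain. A condensed sketch of that call shape, with hypothetical placeholder bodies:

class GenerationCancelled(Exception):
    """Raised when the active scene requests that generation stop."""

class SketchClient:
    async def _send_prompt(self, prompt: str) -> str:
        # rate limiting, generation and cleanup happen here;
        # a cancel request surfaces as GenerationCancelled
        raise GenerationCancelled("Generation cancelled")

    async def abort_generation(self):
        # hook for telling the backend to stop generating
        pass

    async def send_prompt(self, prompt: str) -> str:
        try:
            return await self._send_prompt(prompt)
        except GenerationCancelled:
            # notify the remote side, then let callers see the cancellation
            await self.abort_generation()
            raise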

@@ -2,7 +2,13 @@ import pydantic
import structlog
from cohere import AsyncClientV2

from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults, ExtraField
from talemate.client.base import (
    ClientBase,
    ErrorAction,
    ParameterReroute,
    CommonDefaults,
    ExtraField,
)
from talemate.client.registry import register
from talemate.client.remote import (
    EndpointOverride,
@@ -51,7 +57,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
    auto_break_repetition_enabled = False
    decensor_enabled = True
    config_cls = ClientConfig

    class Meta(ClientBase.Meta):
        name_prefix: str = "Cohere"
        title: str = "Cohere"
@@ -115,12 +121,12 @@ class CohereClient(EndpointOverrideMixin, ClientBase):

        self.current_status = status

        data={
        data = {
            "error_action": error_action.model_dump() if error_action else None,
            "meta": self.Meta().model_dump(),
            "enabled": self.enabled,
        }
        data.update(self._common_status_data())
        data.update(self._common_status_data())
        emit(
            "client_status",
            message=self.client_type,
@@ -171,7 +177,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):

        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

        self._reconfigure_common_parameters(**kwargs)
        self._reconfigure_endpoint_override(**kwargs)

@@ -200,7 +206,6 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
        return prompt

    def clean_prompt_parameters(self, parameters: dict):
        super().clean_prompt_parameters(parameters)

        # if temperature is set, it needs to be clamped between 0 and 1.0
@@ -240,7 +245,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
            parameters=parameters,
            system_message=system_message,
        )

        messages = [
            {
                "role": "system",
@@ -249,7 +254,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
            {
                "role": "user",
                "content": human_message,
            }
            },
        ]

        try:
@@ -290,5 +295,5 @@ class CohereClient(EndpointOverrideMixin, ClientBase):
        #     self.log.error("generate error", e=e)
        #     emit("status", message="cohere API: Permission Denied", status="error")
        #     return ""
        except Exception as e:
        except Exception:
            raise

@@ -23,7 +23,10 @@ def model_to_dict_without_defaults(model_instance):
        if field.default == model_dict.get(field_name):
            del model_dict[field_name]
        # special case for conversation context, don't copy if talking_character is None
        if field_name == "conversation" and model_dict.get(field_name).get("talking_character") is None:
        if (
            field_name == "conversation"
            and model_dict.get(field_name).get("talking_character") is None
        ):
            del model_dict[field_name]
    return model_dict
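
The manual default-stripping above has a direct counterpart in pydantic v2; for comparison, a small sketch using `model_dump(exclude_defaults=True)` (model names hypothetical):

import pydantic

class Conversation(pydantic.BaseModel):
    talking_character: str | None = None

class Context(pydantic.BaseModel):
    length: int = 150
    conversation: Conversation = Conversation()

ctx = Context(length=300)
# exclude_defaults drops any field whose value equals its declared default
print(ctx.model_dump(exclude_defaults=True))  # {'length': 300}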

@@ -102,7 +105,7 @@ class ClientContext:

        # Update the context data
        self.token = context_data.set(data)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Reset the context variable `context_data` to its previous values when exiting the context.

@@ -1,8 +1,5 @@
import json

import pydantic
import structlog
import tiktoken
from openai import AsyncOpenAI, PermissionDeniedError

from talemate.client.base import ClientBase, ErrorAction, CommonDefaults
@@ -103,12 +100,12 @@ class DeepSeekClient(ClientBase):

        self.current_status = status

        data={
        data = {
            "error_action": error_action.model_dump() if error_action else None,
            "meta": self.Meta().model_dump(),
            "enabled": self.enabled,
        }
        data.update(self._common_status_data())
        data.update(self._common_status_data())
        emit(
            "client_status",
            message=self.client_type,
@@ -273,5 +270,5 @@ class DeepSeekClient(ClientBase):
            self.log.error("generate error", e=e)
            emit("status", message="DeepSeek API: Permission Denied", status="error")
            return ""
        except Exception as e:
        except Exception:
            raise

@@ -7,7 +7,13 @@ from google import genai
import google.genai.types as genai_types
from google.genai.errors import APIError

from talemate.client.base import ClientBase, ErrorAction, ExtraField, ParameterReroute, CommonDefaults
from talemate.client.base import (
    ClientBase,
    ErrorAction,
    ExtraField,
    ParameterReroute,
    CommonDefaults,
)
from talemate.client.registry import register
from talemate.client.remote import (
    RemoteServiceMixin,
@@ -41,12 +47,14 @@ SUPPORTED_MODELS = [
    "gemini-2.5-pro-preview-06-05",
]

class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
    max_token_length: int = 16384
    model: str = "gemini-2.0-flash"
    disable_safety_settings: bool = False
    double_coercion: str = None

class ClientConfig(EndpointOverride, BaseClientConfig):
    disable_safety_settings: bool = False

@@ -80,7 +88,6 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
        ),
    }
    extra_fields.update(endpoint_override_extra_fields())

    def __init__(self, model="gemini-2.0-flash", **kwargs):
        self.model_name = model
@@ -114,24 +121,28 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
    @property
    def google_location(self):
        return self.config.get("google").get("gcloud_location")

    @property
    def google_api_key(self):
        return self.config.get("google").get("api_key")

    @property
    def vertexai_ready(self) -> bool:
        return all([
            self.google_credentials_path,
            self.google_location,
        ])
        return all(
            [
                self.google_credentials_path,
                self.google_location,
            ]
        )

    @property
    def developer_api_ready(self) -> bool:
        return all([
            self.google_api_key,
        ])
        return all(
            [
                self.google_api_key,
            ]
        )

    @property
    def using(self) -> str:
        if self.developer_api_ready:
@@ -143,7 +154,11 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
    @property
    def ready(self):
        # all google settings must be set
        return self.vertexai_ready or self.developer_api_ready or self.endpoint_override_base_url_configured
        return (
            self.vertexai_ready
            or self.developer_api_ready
            or self.endpoint_override_base_url_configured
        )

    @property
    def safety_settings(self):
@@ -179,10 +194,8 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
    def http_options(self) -> genai_types.HttpOptions | None:
        if not self.endpoint_override_base_url_configured:
            return None

        return genai_types.HttpOptions(
            base_url=self.base_url
        )

        return genai_types.HttpOptions(base_url=self.base_url)

    @property
    def supported_parameters(self):
@@ -230,7 +243,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
            "meta": self.Meta().model_dump(),
            "enabled": self.enabled,
        }
        data.update(self._common_status_data())
        data.update(self._common_status_data())
        self.populate_extra_fields(data)

        if self.using == "VertexAI":
@@ -252,7 +265,9 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
            try:
                self.client.http_options.base_url = base_url
            except Exception as e:
                log.error("Error setting client base URL", error=e, client=self.client_type)
                log.error(
                    "Error setting client base URL", error=e, client=self.client_type
                )

    def set_client(self, max_token_length: int = None, **kwargs):
        if not self.ready:
@@ -283,7 +298,9 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
                location=self.google_location,
            )
        else:
            self.client = genai.Client(api_key=self.api_key or None, http_options=self.http_options)
            self.client = genai.Client(
                api_key=self.api_key or None, http_options=self.http_options
            )

        log.info(
            "google set client",
@@ -292,7 +309,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
            model=model,
        )

    def response_tokens(self, response:str):
    def response_tokens(self, response: str):
        """Return token count for a response which may be a string or SDK object."""
        return count_tokens(response)

@@ -309,10 +326,10 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):

        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

        if "double_coercion" in kwargs:
            self.double_coercion = kwargs["double_coercion"]

        self._reconfigure_common_parameters(**kwargs)

    def clean_prompt_parameters(self, parameters: dict):
@@ -329,7 +346,6 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
        """
        return prompt

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
@@ -342,7 +358,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):

        human_message = prompt.strip()
        system_message = self.get_system_message(kind)

        contents = [
            genai_types.Content(
                role="user",
@@ -350,10 +366,10 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
                    genai_types.Part.from_text(
                        text=human_message,
                    )
                ]
                ],
            )
        ]

        if coercion_prompt:
            contents.append(
                genai_types.Content(
@@ -362,7 +378,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
                        genai_types.Part.from_text(
                            text=coercion_prompt,
                        )
                    ]
                    ],
                )
            )

@@ -378,24 +394,23 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):

        try:
            # Use streaming so we can update_request_tokens incrementally
            #stream = await chat.send_message_async(
            # stream = await chat.send_message_async(
            #     human_message,
            #     safety_settings=self.safety_settings,
            #     generation_config=parameters,
            #     stream=True
            #)
            # )

            stream = await self.client.aio.models.generate_content_stream(
                model=self.model_name,
                contents=contents,
                config=genai_types.GenerateContentConfig(
                    safety_settings=self.safety_settings,
                    http_options=self.http_options,
                    **parameters
                    **parameters,
                ),
            )

            response = ""

            async for chunk in stream:
@@ -425,5 +440,5 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase):
            self.log.error("generate error", e=e)
            emit("status", message="google API: API Error", status="error")
            return ""
        except Exception as e:
        except Exception:
            raise

@@ -260,5 +260,5 @@ class GroqClient(EndpointOverrideMixin, ClientBase):
            self.log.error("generate error", e=e)
            emit("status", message="OpenAI API: Permission Denied", status="error")
            return ""
        except Exception as e:
        except Exception:
            raise

@@ -17,7 +17,7 @@ from talemate.client.base import (
    ClientBase,
    Defaults,
    ParameterReroute,
    ClientEmbeddingsStatus
    ClientEmbeddingsStatus,
)
from talemate.client.registry import register
import talemate.emit.async_signals as async_signals
@@ -46,9 +46,13 @@ class KoboldEmbeddingFunction(EmbeddingFunction):
        """
        Embed a list of input texts using the KoboldCPP embeddings endpoint.
        """
        log.debug("KoboldCppEmbeddingFunction", api_url=self.api_url, model_name=self.model_name)

        log.debug(
            "KoboldCppEmbeddingFunction",
            api_url=self.api_url,
            model_name=self.model_name,
        )

        # Prepare the request payload for KoboldCPP. Include model name if required.
        payload = {"input": texts}
        if self.model_name is not None:
@@ -65,6 +69,8 @@ class KoboldEmbeddingFunction(EmbeddingFunction):
        embeddings = [item["embedding"] for item in embedding_results]

        return embeddings
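
The function above only needs a tiny contract from the server: POST `{"input": [...]}` and read one `embedding` per item back. A standalone sketch of that round trip, assuming an OpenAI-style `data` envelope and a hypothetical local URL:

import httpx

def embed(texts: list[str], api_url: str = "http://localhost:5001/api/extra/embeddings"):
    # payload mirrors the one built above; "model" may be added when known
    response = httpx.post(api_url, json={"input": texts}, timeout=30)
    response.raise_for_status()
    # each result item carries its vector under "embedding"
    return [item["embedding"] for item in response.json()["data"]]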

@register()
class KoboldCppClient(ClientBase):
    auto_determine_prompt_template: bool = True
@@ -147,7 +153,6 @@ class KoboldCppClient(ClientBase):
                talemate_parameter="stopping_strings",
                client_parameter="stop_sequence",
            ),
            "xtc_threshold",
            "xtc_probability",
            "dry_multiplier",
@@ -155,7 +160,6 @@ class KoboldCppClient(ClientBase):
            "dry_allowed_length",
            "dry_sequence_breakers",
            "smoothing_factor",
            "temperature",
        ]

@@ -172,18 +176,18 @@ class KoboldCppClient(ClientBase):
    @property
    def supports_embeddings(self) -> bool:
        return True

    @property
    def embeddings_url(self) -> str:
        if self.is_openai:
            return urljoin(self.api_url, "embeddings")
        else:
            return urljoin(self.api_url, "api/extra/embeddings")

    @property
    def embeddings_function(self):
        return KoboldEmbeddingFunction(self.embeddings_url, self.embeddings_model_name)

    def api_endpoint_specified(self, url: str) -> bool:
        return "/v1" in self.api_url

@@ -208,10 +212,10 @@ class KoboldCppClient(ClientBase):
        # if self._embeddings_model_name is set, return it
        if self.embeddings_model_name:
            return self.embeddings_model_name

        # otherwise, get the model name by doing a request to
        # the embeddings endpoint with a single character

        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.embeddings_url,
@@ -219,37 +223,40 @@ class KoboldCppClient(ClientBase):
                timeout=2,
                headers=self.request_headers,
            )

            response_data = response.json()
            self._embeddings_model_name = response_data.get("model")
            return self._embeddings_model_name

    async def get_embeddings_status(self):
        url_version = urljoin(self.api_url, "api/extra/version")
        async with httpx.AsyncClient() as client:
            response = await client.get(url_version, timeout=2)
            response_data = response.json()
            self._embeddings_status = response_data.get("embeddings", False)

        if not self.embeddings_status or self.embeddings_model_name:
            return

        await self.get_embeddings_model_name()

        log.debug("KoboldCpp embeddings are enabled, suggesting embeddings", model_name=self.embeddings_model_name)

        log.debug(
            "KoboldCpp embeddings are enabled, suggesting embeddings",
            model_name=self.embeddings_model_name,
        )

        self.set_embeddings()

        await async_signals.get("client.embeddings_available").send(
            ClientEmbeddingsStatus(
                client=self,
                embedding_name=self.embeddings_model_name,
            )
        )

    async def get_model_name(self):
        self.ensure_api_endpoint_specified()

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
@@ -275,7 +282,7 @@ class KoboldCppClient(ClientBase):
        # split by "/" and take last
        if model_name:
            model_name = model_name.split("/")[-1]

        await self.get_embeddings_status()

        return model_name
@@ -309,7 +316,6 @@ class KoboldCppClient(ClientBase):
            tokencount = len(response.json().get("ids", []))
            return tokencount

    async def abort_generation(self):
        """
        Trigger the stop generation endpoint
@@ -317,7 +323,7 @@ class KoboldCppClient(ClientBase):
        if self.is_openai:
            # openai api endpoint doesn't support abort
            return

        parts = urlparse(self.api_url)
        url_abort = f"{parts.scheme}://{parts.netloc}/api/extra/abort"
        async with httpx.AsyncClient() as client:
@@ -325,7 +331,7 @@ class KoboldCppClient(ClientBase):
                url_abort,
                headers=self.request_headers,
            )

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
@@ -334,14 +340,16 @@ class KoboldCppClient(ClientBase):
            return await self._generate_openai(prompt, parameters, kind)
        else:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(None, self._generate_kcpp_stream, prompt, parameters, kind)

            return await loop.run_in_executor(
                None, self._generate_kcpp_stream, prompt, parameters, kind
            )

    def _generate_kcpp_stream(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
        """
        parameters["prompt"] = prompt.strip(" ")

        response = ""
        parameters["stream"] = True
        stream_response = requests.post(
@@ -352,15 +360,15 @@ class KoboldCppClient(ClientBase):
            stream=True,
        )
        stream_response.raise_for_status()

        sse = sseclient.SSEClient(stream_response)

        for event in sse.events():
            payload = json.loads(event.data)
            chunk = payload['token']
            chunk = payload["token"]
            response += chunk
            self.update_request_tokens(self.count_tokens(chunk))

        return response
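
`_generate_kcpp_stream` uses the standard sseclient-py idiom: hand a `stream=True` requests response to `SSEClient` and iterate `events()`. A minimal sketch against a hypothetical endpoint that emits JSON events with a `token` field:

import json

import requests
import sseclient

def stream_tokens(url: str, payload: dict) -> str:
    # stream=True keeps the HTTP body open so events arrive incrementally
    raw = requests.post(url, json=payload, stream=True)
    raw.raise_for_status()

    text = ""
    for event in sseclient.SSEClient(raw).events():
        text += json.loads(event.data)["token"]
    return text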

    async def _generate_openai(self, prompt: str, parameters: dict, kind: str):
@@ -447,7 +455,6 @@ class KoboldCppClient(ClientBase):
        sd_models_url = urljoin(self.url, "/sdapi/v1/sd-models")

        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(url=sd_models_url, timeout=2)
            except Exception as exc:

@@ -37,13 +37,13 @@ class LMStudioClient(ClientBase):

    def reconfigure(self, **kwargs):
        super().reconfigure(**kwargs)

        if self.client and self.client.base_url != self.api_url:
            self.set_client()

    async def get_model_name(self):
        model_name = await super().get_model_name()

        # model name comes back as a file path, so we need to extract the model name
        # the path could be windows or linux so it needs to handle both backslash and forward slash

@@ -1,10 +1,14 @@
import pydantic
import structlog
from typing import Literal
from mistralai import Mistral
from mistralai.models.sdkerror import SDKError

from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults, ExtraField
from talemate.client.base import (
    ClientBase,
    ErrorAction,
    CommonDefaults,
    ExtraField,
)
from talemate.client.registry import register
from talemate.client.remote import (
    EndpointOverride,
@@ -44,9 +48,11 @@ class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
    max_token_length: int = 16384
    model: str = "open-mixtral-8x22b"

class ClientConfig(EndpointOverride, BaseClientConfig):
    pass

@register()
class MistralAIClient(EndpointOverrideMixin, ClientBase):
    """
@@ -68,7 +74,7 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
        requires_prompt_template: bool = False
        defaults: Defaults = Defaults()
        extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()

    def __init__(self, model="open-mixtral-8x22b", **kwargs):
        self.model_name = model
        self.api_key_status = None
@@ -115,7 +121,7 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
            model_name = "No model loaded"

        self.current_status = status
        data={
        data = {
            "error_action": error_action.model_dump() if error_action else None,
            "meta": self.Meta().model_dump(),
            "enabled": self.enabled,
@@ -167,10 +173,10 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
    def reconfigure(self, **kwargs):
        if "enabled" in kwargs:
            self.enabled = bool(kwargs["enabled"])

        self._reconfigure_common_parameters(**kwargs)
        self._reconfigure_endpoint_override(**kwargs)

        if kwargs.get("model"):
            self.model_name = kwargs["model"]
            self.set_client(kwargs.get("max_token_length"))
@@ -248,14 +254,16 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
        )

        response = ""

        completion_tokens = 0
        prompt_tokens = 0

        async for event in event_stream:
            if event.data.choices:
                response += event.data.choices[0].delta.content
                self.update_request_tokens(self.count_tokens(event.data.choices[0].delta.content))
                self.update_request_tokens(
                    self.count_tokens(event.data.choices[0].delta.content)
                )
            if event.data.usage:
                completion_tokens += event.data.usage.completion_tokens
                prompt_tokens += event.data.usage.prompt_tokens
@@ -263,7 +271,7 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
        self._returned_prompt_tokens = prompt_tokens
        self._returned_response_tokens = completion_tokens

        #response = response.choices[0].message.content
        # response = response.choices[0].message.content

        # older models don't support json_object response coercion
        # and often like to return the response wrapped in ```json
@@ -282,12 +290,12 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase):
            return response
        except SDKError as e:
            self.log.error("generate error", e=e)
            if hasattr(e, 'status_code') and e.status_code in [403, 401]:
            if hasattr(e, "status_code") and e.status_code in [403, 401]:
                emit(
                    "status",
                    message="mistral.ai API: Permission Denied",
                    status="error",
                )
                return ""
        except Exception as e:
        except Exception:
            raise

@@ -117,7 +117,6 @@ class ModelPrompt:
            prompt = f"{prompt}<|BOT|>"

        if "<|BOT|>" in prompt:
            response_str = f"{double_coercion}{response_str}"

            if "\n<|BOT|>" in prompt:
@@ -264,7 +263,11 @@ class Mistralv7TekkenIdentifier(TemplateIdentifier):
    template_str = "MistralV7Tekken"

    def __call__(self, content: str):
        return "[SYSTEM_PROMPT]" in content and "[INST]" in content and "[/INST]" in content
        return (
            "[SYSTEM_PROMPT]" in content
            and "[INST]" in content
            and "[/INST]" in content
        )

@register_template_identifier

@@ -1,11 +1,16 @@
import asyncio
import structlog
import httpx
import ollama
import time
from typing import Union

from talemate.client.base import STOPPING_STRINGS, ClientBase, CommonDefaults, ErrorAction, ParameterReroute, ExtraField
from talemate.client.base import (
    STOPPING_STRINGS,
    ClientBase,
    CommonDefaults,
    ErrorAction,
    ParameterReroute,
    ExtraField,
)
from talemate.client.registry import register
from talemate.config import Client as BaseClientConfig

@@ -14,27 +19,30 @@ log = structlog.get_logger("talemate.client.ollama")

FETCH_MODELS_INTERVAL = 15

class OllamaClientDefaults(CommonDefaults):
    api_url: str = "http://localhost:11434"  # Default Ollama URL
    model: str = ""  # Allow empty default, will fetch from Ollama
    api_handles_prompt_template: bool = False
    allow_thinking: bool = False

class ClientConfig(BaseClientConfig):
    api_handles_prompt_template: bool = False
    allow_thinking: bool = False

@register()
class OllamaClient(ClientBase):
    """
    Ollama client for generating text using locally hosted models.
    """

    auto_determine_prompt_template: bool = True
    client_type = "ollama"
    conversation_retries = 0
    config_cls = ClientConfig

    class Meta(ClientBase.Meta):
        name_prefix: str = "Ollama"
        title: str = "Ollama"
@@ -56,27 +64,26 @@ class OllamaClient(ClientBase):
            label="Allow thinking",
            required=False,
            description="Allow the model to think before responding. Talemate does not have a good way to deal with this yet, so it's recommended to leave this off.",
        )
        ),
    }

    @property
    def supported_parameters(self):
        # Parameters supported by Ollama's generate endpoint
        # Based on the API documentation
        return [
            "temperature",
            "top_p",
            "top_p",
            "top_k",
            "min_p",
            "frequency_penalty",
            "presence_penalty",
            ParameterReroute(
                talemate_parameter="repetition_penalty",
                client_parameter="repeat_penalty"
                client_parameter="repeat_penalty",
            ),
            ParameterReroute(
                talemate_parameter="max_tokens",
                client_parameter="num_predict"
                talemate_parameter="max_tokens", client_parameter="num_predict"
            ),
            "stopping_strings",
            # internal parameters that will be removed before sending
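
A `ParameterReroute` entry simply renames a Talemate parameter to whatever the backend expects before the request goes out. A simplified sketch of that translation (the real class lives in talemate.client.base; this stand-in only mirrors the behavior visible here):

from dataclasses import dataclass

@dataclass
class ParameterReroute:
    talemate_parameter: str
    client_parameter: str

    def apply(self, parameters: dict) -> None:
        # move the value over to the name the backend expects
        if self.talemate_parameter in parameters:
            parameters[self.client_parameter] = parameters.pop(self.talemate_parameter)

reroute = ParameterReroute("max_tokens", "num_predict")
params = {"max_tokens": 192, "temperature": 0.7}
reroute.apply(params)
print(params)  # {'temperature': 0.7, 'num_predict': 192}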
@@ -98,7 +105,13 @@ class OllamaClient(ClientBase):
        """
        return self.allow_thinking

    def __init__(self, model=None, api_handles_prompt_template=False, allow_thinking=False, **kwargs):
    def __init__(
        self,
        model=None,
        api_handles_prompt_template=False,
        allow_thinking=False,
        **kwargs,
    ):
        self.model_name = model
        self.api_handles_prompt_template = api_handles_prompt_template
        self.allow_thinking = allow_thinking
@@ -114,16 +127,14 @@ class OllamaClient(ClientBase):
        # Update model if provided
        if kwargs.get("model"):
            self.model_name = kwargs["model"]

        # Create async client with the configured API URL
        # Ollama's AsyncClient expects just the base URL without any path
        self.client = ollama.AsyncClient(host=self.api_url)
        self.api_handles_prompt_template = kwargs.get(
            "api_handles_prompt_template", self.api_handles_prompt_template
        )
        self.allow_thinking = kwargs.get(
            "allow_thinking", self.allow_thinking
        )
        self.allow_thinking = kwargs.get("allow_thinking", self.allow_thinking)

    async def status(self):
        """
@@ -131,7 +142,7 @@ class OllamaClient(ClientBase):
        Raises an error if no model name is returned.
        :return: None
        """

        if self.processing:
            self.emit_status()
            return
@@ -140,7 +151,7 @@ class OllamaClient(ClientBase):
            self.connected = False
            self.emit_status()
            return

        try:
            # instead of using the client (which apparently cannot set a timeout per endpoint)
            # we use httpx to check {api_url}/api/version to see if the server is running
@@ -148,7 +159,7 @@ class OllamaClient(ClientBase):
            async with httpx.AsyncClient() as client:
                response = await client.get(f"{self.api_url}/api/version", timeout=2)
                response.raise_for_status()

            # if the server is running, fetch the available models
            await self.fetch_available_models()
        except Exception as e:
@@ -156,7 +167,7 @@ class OllamaClient(ClientBase):
            self.connected = False
            self.emit_status()
            return

        await super().status()

    async def fetch_available_models(self):
@@ -165,7 +176,7 @@ class OllamaClient(ClientBase):
        """
        if time.time() - self._models_last_fetched < FETCH_MODELS_INTERVAL:
            return self._available_models

        response = await self.client.list()
        models = response.get("models", [])
        model_names = [model.model for model in models]
@@ -182,7 +193,7 @@ class OllamaClient(ClientBase):

    async def get_model_name(self):
        return self.model_name

    def prompt_template(self, system_message: str, prompt: str):
        if not self.api_handles_prompt_template:
            return super().prompt_template(system_message, prompt)
@@ -201,10 +212,12 @@ class OllamaClient(ClientBase):
        Tune parameters for Ollama's generate endpoint.
        """
        super().tune_prompt_parameters(parameters, kind)

        # Build stopping strings list
        parameters["stop"] = STOPPING_STRINGS + parameters.get("extra_stopping_strings", [])

        parameters["stop"] = STOPPING_STRINGS + parameters.get(
            "extra_stopping_strings", []
        )

        # Ollama uses num_predict instead of max_tokens
        if "max_tokens" in parameters:
            parameters["num_predict"] = parameters["max_tokens"]
@@ -215,7 +228,7 @@ class OllamaClient(ClientBase):
        """
        # First let parent class handle parameter reroutes and cleanup
        super().clean_prompt_parameters(parameters)

        # Remove our internal parameters
        if "extra_stopping_strings" in parameters:
            del parameters["extra_stopping_strings"]
@@ -223,7 +236,7 @@ class OllamaClient(ClientBase):
            del parameters["stopping_strings"]
        if "stream" in parameters:
            del parameters["stream"]

        # Remove max_tokens as we've already converted it to num_predict
        if "max_tokens" in parameters:
            del parameters["max_tokens"]
@@ -237,12 +250,12 @@ class OllamaClient(ClientBase):
        await self.get_model_name()
        if not self.model_name:
            raise Exception("No model specified or available in Ollama")

        # Prepare options for Ollama
        options = parameters

        options["num_ctx"] = self.max_token_length

        try:
            # Use generate endpoint for completion
            stream = await self.client.generate(
@@ -253,22 +266,21 @@ class OllamaClient(ClientBase):
                think=self.can_think,
                stream=True,
            )

            response = ""

            async for part in stream:
                content = part.response
                response += content
                self.update_request_tokens(self.count_tokens(content))

            # Extract the response text
            return response

        except Exception as e:
            log.error("Ollama generation error", error=str(e), model=self.model_name)
            raise ErrorAction(
                message=f"Ollama generation failed: {str(e)}",
                title="Generation Error"
                message=f"Ollama generation failed: {str(e)}", title="Generation Error"
            )

    async def abort_generation(self):
@@ -284,7 +296,7 @@ class OllamaClient(ClientBase):
        Adjusts temperature and repetition_penalty by random values.
        """
        import random

        temp = prompt_config["temperature"]
        rep_pen = prompt_config.get("repetition_penalty", 1.0)

@@ -302,12 +314,12 @@ class OllamaClient(ClientBase):
        # Handle model update
        if kwargs.get("model"):
            self.model_name = kwargs["model"]

        super().reconfigure(**kwargs)

        # Re-initialize client if API URL changed or model changed
        if "api_url" in kwargs or "model" in kwargs:
            self.set_client(**kwargs)

        if "api_handles_prompt_template" in kwargs:
            self.api_handles_prompt_template = kwargs["api_handles_prompt_template"]

@@ -108,9 +108,11 @@ class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel):
    max_token_length: int = 16384
    model: str = "gpt-4o"

class ClientConfig(EndpointOverride, BaseClientConfig):
    pass

@register()
class OpenAIClient(EndpointOverrideMixin, ClientBase):
    """
@@ -123,7 +125,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
    # TODO: make this configurable?
    decensor_enabled = False
    config_cls = ClientConfig

    class Meta(ClientBase.Meta):
        name_prefix: str = "OpenAI"
        title: str = "OpenAI"
@@ -132,6 +134,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
        requires_prompt_template: bool = False
        defaults: Defaults = Defaults()
        extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields()

    def __init__(self, model="gpt-4o", **kwargs):
        self.model_name = model
        self.api_key_status = None
@@ -181,12 +184,12 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):

        self.current_status = status

        data={
        data = {
            "error_action": error_action.model_dump() if error_action else None,
            "meta": self.Meta().model_dump(),
            "enabled": self.enabled,
        }
        data.update(self._common_status_data())
        data.update(self._common_status_data())

        emit(
            "client_status",
@@ -305,32 +308,34 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):

        human_message = {"role": "user", "content": prompt.strip()}
        system_message = {"role": "system", "content": self.get_system_message(kind)}

        # o1 and o3 models don't support system_message
        if "o1" in self.model_name or "o3" in self.model_name:
            messages=[human_message]
            messages = [human_message]
            # parameters need to be munged
            # `max_tokens` becomes `max_completion_tokens`
            if "max_tokens" in parameters:
                parameters["max_completion_tokens"] = parameters.pop("max_tokens")

            # temperature forced to 1
            if "temperature" in parameters:
                log.debug(f"{self.model_name} does not support temperature, forcing to 1")
                log.debug(
                    f"{self.model_name} does not support temperature, forcing to 1"
                )
                parameters["temperature"] = 1

            unsupported_params = [
                "presence_penalty",
                "top_p",
            ]

            for param in unsupported_params:
                if param in parameters:
                    log.debug(f"{self.model_name} does not support {param}, removing")
                    parameters.pop(param)

        else:
            messages=[system_message, human_message]
            messages = [system_message, human_message]
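
The o1/o3 branch is pure request munging. The same normalization as a standalone helper, inferred from the branch above (a sketch, not the client's actual API):

def munge_for_reasoning_models(model_name: str, parameters: dict) -> dict:
    # mirrors the branch above: rename max_tokens, pin temperature, drop unsupported knobs
    if "o1" in model_name or "o3" in model_name:
        if "max_tokens" in parameters:
            parameters["max_completion_tokens"] = parameters.pop("max_tokens")
        if "temperature" in parameters:
            parameters["temperature"] = 1
        for param in ("presence_penalty", "top_p"):
            parameters.pop(param, None)
    return parameters

print(munge_for_reasoning_models("o3-mini", {"max_tokens": 256, "top_p": 0.9}))
# {'max_completion_tokens': 256}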

        self.log.debug(
            "generate",
@@ -346,7 +351,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
                stream=True,
                **parameters,
            )

            response = ""

            # Iterate over streamed chunks
@@ -359,9 +364,9 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
                    response += content_piece
                    # Incrementally track token usage
                    self.update_request_tokens(self.count_tokens(content_piece))

            #self._returned_prompt_tokens = self.prompt_tokens(prompt)
            #self._returned_response_tokens = self.response_tokens(response)

            # self._returned_prompt_tokens = self.prompt_tokens(prompt)
            # self._returned_response_tokens = self.response_tokens(response)

            # older models don't support json_object response coercion
            # and often like to return the response wrapped in ```json
@@ -382,5 +387,5 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase):
            self.log.error("generate error", e=e)
            emit("status", message="OpenAI API: Permission Denied", status="error")
            return ""
        except Exception as e:
        except Exception:
            raise

@@ -1,9 +1,8 @@
import random
import urllib

import pydantic
import structlog
from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError
from openai import AsyncOpenAI, PermissionDeniedError

from talemate.client.base import ClientBase, ExtraField
from talemate.client.registry import register
@@ -93,7 +92,6 @@ class OpenAICompatibleClient(ClientBase):
        )

    def prompt_template(self, system_message: str, prompt: str):
        log.debug(
            "IS API HANDLING PROMPT TEMPLATE",
            api_handles_prompt_template=self.api_handles_prompt_template,
@@ -130,7 +128,10 @@ class OpenAICompatibleClient(ClientBase):
            )
            human_message = {"role": "user", "content": prompt.strip()}
            response = await self.client.chat.completions.create(
                model=self.model_name, messages=[human_message], stream=False, **parameters
                model=self.model_name,
                messages=[human_message],
                stream=False,
                **parameters,
            )
            response = response.choices[0].message.content
            return self.process_response_for_indirect_coercion(prompt, response)
@@ -177,7 +178,7 @@ class OpenAICompatibleClient(ClientBase):

        if "double_coercion" in kwargs:
            self.double_coercion = kwargs["double_coercion"]

        if "rate_limit" in kwargs:
            self.rate_limit = kwargs["rate_limit"]

@@ -21,25 +21,25 @@ AVAILABLE_MODELS = []
DEFAULT_MODEL = ""
MODELS_FETCHED = False

async def fetch_available_models(api_key: str = None):
    """Fetch available models from OpenRouter API"""
    global AVAILABLE_MODELS, DEFAULT_MODEL, MODELS_FETCHED

    if not api_key:
        return []

    if MODELS_FETCHED:
        return AVAILABLE_MODELS

    # Only fetch if we haven't already or if explicitly requested
    if AVAILABLE_MODELS and not api_key:
        return AVAILABLE_MODELS

    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://openrouter.ai/api/v1/models",
                timeout=10.0
                "https://openrouter.ai/api/v1/models", timeout=10.0
            )
            if response.status_code == 200:
                data = response.json()
@@ -51,23 +51,26 @@ async def fetch_available_models(api_key: str = None):
                AVAILABLE_MODELS = sorted(models)
                log.debug(f"Fetched {len(AVAILABLE_MODELS)} models from OpenRouter")
            else:
                log.warning(f"Failed to fetch models from OpenRouter: {response.status_code}")
                log.warning(
                    f"Failed to fetch models from OpenRouter: {response.status_code}"
                )
    except Exception as e:
        log.error(f"Error fetching models from OpenRouter: {e}")

    MODELS_FETCHED = True
    return AVAILABLE_MODELS

def fetch_models_sync(event):
    api_key = event.data.get("openrouter", {}).get("api_key")
    loop = asyncio.get_event_loop()
    loop.run_until_complete(fetch_available_models(api_key))

handlers["config_saved"].connect(fetch_models_sync)
handlers["talemate_started"].connect(fetch_models_sync)

class Defaults(CommonDefaults, pydantic.BaseModel):
    max_token_length: int = 16384
    model: str = DEFAULT_MODEL
@@ -89,7 +92,9 @@ class OpenRouterClient(ClientBase):
        name_prefix: str = "OpenRouter"
        title: str = "OpenRouter"
        manual_model: bool = True
        manual_model_choices: list[str] = pydantic.Field(default_factory=lambda: AVAILABLE_MODELS)
        manual_model_choices: list[str] = pydantic.Field(
            default_factory=lambda: AVAILABLE_MODELS
        )
        requires_prompt_template: bool = False
        defaults: Defaults = Defaults()

@@ -169,7 +174,7 @@ class OpenRouterClient(ClientBase):
    def set_client(self, max_token_length: int = None):
        # Unlike other clients, we don't need to set up a client instance
        # We'll use httpx directly in the generate method

        if not self.openrouter_api_key:
            log.error("No OpenRouter API key set")
            if self.api_key_status:
@@ -183,7 +188,7 @@ class OpenRouterClient(ClientBase):

        if max_token_length and not isinstance(max_token_length, int):
            max_token_length = int(max_token_length)

        # Set max token length (default to 16k if not specified)
        self.max_token_length = max_token_length or 16384

@@ -221,7 +226,7 @@ class OpenRouterClient(ClientBase):
            self._models_fetched = True
            # Update the Meta class with new model choices
            self.Meta.manual_model_choices = AVAILABLE_MODELS

        self.emit_status()

    def prompt_template(self, system_message: str, prompt: str):
@@ -240,13 +245,13 @@ class OpenRouterClient(ClientBase):
            raise Exception("No OpenRouter API key set")

        prompt, coercion_prompt = self.split_prompt_for_coercion(prompt)

        # Prepare messages for chat completion
        messages = [
            {"role": "system", "content": self.get_system_message(kind)},
            {"role": "user", "content": prompt.strip()}
            {"role": "user", "content": prompt.strip()},
        ]

        if coercion_prompt:
            messages.append({"role": "assistant", "content": coercion_prompt.strip()})

@@ -255,7 +260,7 @@ class OpenRouterClient(ClientBase):
            "model": self.model_name,
            "messages": messages,
            "stream": True,
            **parameters
            **parameters,
        }

        self.log.debug(
@@ -264,7 +269,7 @@ class OpenRouterClient(ClientBase):
            parameters=parameters,
            model=self.model_name,
        )

        response_text = ""
        buffer = ""
        completion_tokens = 0
@@ -279,46 +284,52 @@ class OpenRouterClient(ClientBase):
                    "Content-Type": "application/json",
                },
                json=payload,
                timeout=120.0 # 2 minute timeout for generation
                timeout=120.0,  # 2 minute timeout for generation
            ) as response:
                async for chunk in response.aiter_text():
                    buffer += chunk

                    while True:
                        # Find the next complete SSE line
                        line_end = buffer.find('\n')
                        line_end = buffer.find("\n")
                        if line_end == -1:
                            break

                        line = buffer[:line_end].strip()
                        buffer = buffer[line_end + 1:]

                        if line.startswith('data: '):
                        buffer = buffer[line_end + 1 :]

                        if line.startswith("data: "):
                            data = line[6:]
                            if data == '[DONE]':
                            if data == "[DONE]":
                                break

                            try:
                                data_obj = json.loads(data)
                                content = data_obj["choices"][0]["delta"].get("content")
                                content = data_obj["choices"][0]["delta"].get(
                                    "content"
                                )
                                usage = data_obj.get("usage", {})
                                completion_tokens += usage.get("completion_tokens", 0)
                                completion_tokens += usage.get(
                                    "completion_tokens", 0
                                )
                                prompt_tokens += usage.get("prompt_tokens", 0)
                                if content:
                                    response_text += content
                                    # Update tokens as content streams in
                                    self.update_request_tokens(self.count_tokens(content))

                                    self.update_request_tokens(
                                        self.count_tokens(content)
                                    )

                            except json.JSONDecodeError:
                                pass

            # Extract the response content
            response_content = response_text
            self._returned_prompt_tokens = prompt_tokens
            self._returned_response_tokens = completion_tokens

            return response_content

        except httpx.ConnectTimeout:
            self.log.error("OpenRouter API timeout")
            emit("status", message="OpenRouter API: Request timed out", status="error")
@@ -326,4 +337,4 @@ class OpenRouterClient(ClientBase):
        except Exception as e:
            self.log.error("generate error", e=e)
            emit("status", message=f"OpenRouter API Error: {str(e)}", status="error")
            raise
            raise

@@ -16,11 +16,6 @@ __all__ = [
    "preset_for_kind",
    "make_kind",
    "max_tokens_for_kind",
    "PRESET_TALEMATE_CONVERSATION",
    "PRESET_TALEMATE_CREATOR",
    "PRESET_LLAMA_PRECISE",
    "PRESET_DIVINE_INTELLECT",
    "PRESET_SIMPLE_1",
]

log = structlog.get_logger("talemate.client.presets")
@@ -42,27 +37,31 @@ def sync_config(event):
    )
    CONFIG["inference_groups"] = {
        group: InferencePresetGroup(**data)
        for group, data in event.data.get("presets", {}).get("inference_groups", {}).items()
        for group, data in event.data.get("presets", {})
        .get("inference_groups", {})
        .items()
    }

handlers["config_saved"].connect(sync_config)

def get_inference_parameters(preset_name: str, group:str|None = None) -> dict:
def get_inference_parameters(preset_name: str, group: str | None = None) -> dict:
    """
    Returns the inference parameters for the given preset name.
    """

    presets = CONFIG["inference"].model_dump()

    if group:
        try:
            group_presets = CONFIG["inference_groups"].get(group).model_dump()
            presets.update(group_presets["presets"])
        except AttributeError:
            log.warning(f"Invalid preset group referenced: {group}. Falling back to defaults.")

            log.warning(
                f"Invalid preset group referenced: {group}. Falling back to defaults."
            )

    if preset_name in presets:
        return presets[preset_name]

@@ -145,7 +144,7 @@ def preset_for_kind(kind: str, client: "ClientBase") -> dict:
            presets=CONFIG["inference"],
        )
        preset_name = "scene_direction"

    set_client_context_attribute("inference_preset", preset_name)

    return get_inference_parameters(preset_name, client.preset_group)
@@ -197,29 +196,26 @@ def max_tokens_for_kind(kind: str, total_budget: int) -> int:
            return value
        if token_value is not None:
            return token_value

    # finally check if splitting last item off of _ is a number, and then just
    # return that number
    kind_split = kind.split("_")[-1]
    if kind_split.isdigit():
        return int(kind_split)

    return 150  # Default value if none of the kinds match

def make_kind(action_type: str, length: int, expect_json:bool=False) -> str:
def make_kind(action_type: str, length: int, expect_json: bool = False) -> str:
    """
    Creates a kind string based on the preset_arch_type and length.
    """
    if action_type == "analyze" and not expect_json:
        kind = f"investigate"
        kind = "investigate"
    else:
        kind = action_type

    kind = f"{kind}_{length}"

    return kind

    return kind
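
`make_kind` therefore composes strings that `max_tokens_for_kind` can pick apart again; expected behavior, inferred from the code above:

make_kind("analyze", 512)                    # -> "investigate_512"
make_kind("analyze", 512, expect_json=True)  # -> "analyze_512"
make_kind("create", 192)                     # -> "create_192"
# max_tokens_for_kind("create_192", ...) splits on "_" and returns 192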

@@ -1,30 +1,31 @@
from limits.strategies import MovingWindowRateLimiter
from limits.storage import MemoryStorage
from limits import parse, RateLimitItemPerMinute
from limits import RateLimitItemPerMinute
import time

__all__ = ["CounterRateLimiter"]

class CounterRateLimiter:
    def __init__(self, rate_per_minute:int=99999, identifier:str="ratelimit"):
    def __init__(self, rate_per_minute: int = 99999, identifier: str = "ratelimit"):
        self.storage = MemoryStorage()
        self.limiter = MovingWindowRateLimiter(self.storage)
        self.rate = RateLimitItemPerMinute(rate_per_minute, 1)
        self.identifier = identifier

    def update_rate_limit(self, rate_per_minute:int):
    def update_rate_limit(self, rate_per_minute: int):
        """Update the rate limit with a new value"""
        self.rate = RateLimitItemPerMinute(rate_per_minute, 1)

    def increment(self) -> bool:
        limiter = self.limiter
        rate = self.rate
        return limiter.hit(rate, self.identifier)

    def reset_time(self) -> float:
        """
        Returns the time in seconds until the rate limit is reset
        """

        window = self.limiter.get_window_stats(self.rate, self.identifier)
        return window.reset_time - time.time()
        return window.reset_time - time.time()
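
`CounterRateLimiter` is a thin wrapper over the `limits` library's moving-window strategy. A sketch of how a caller would use it, mirroring the wait loop in `ClientBase._send_prompt` above (the import path is assumed):

import time

from talemate.client.ratelimit import CounterRateLimiter  # assumed module path

limiter = CounterRateLimiter(rate_per_minute=30, identifier="my-client")

while not limiter.increment():
    # hit() returned False: the moving window is full, so back off
    wait = max(limiter.reset_time(), 0)
    print(f"rate limited, retrying in ~{wait:.1f}s")
    time.sleep(1)
# a True return has already counted this request against the window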
|
||||
|
||||
@@ -18,56 +18,68 @@ __all__ = [
    "EndpointOverrideAPIKeyField",
]


def endpoint_override_extra_fields():
    return {
        "override_base_url": EndpointOverrideBaseURLField(),
        "override_api_key": EndpointOverrideAPIKeyField(),
    }


class EndpointOverride(pydantic.BaseModel):
    override_base_url: str | None = None
    override_api_key: str | None = None


class EndpointOverrideGroup(FieldGroup):
    name: str = "endpoint_override"
    label: str = "Endpoint Override"
    description: str = ("Override the default base URL used by this client to access the {client_type} service API.\n\n"
        "IMPORTANT: Provide an override only if you fully trust the endpoint. When set, the {client_type} API key defined in the global application settings is deliberately ignored to avoid accidental credential leakage. "
        "If the override endpoint requires an API key, enter it below.")
    description: str = (
        "Override the default base URL used by this client to access the {client_type} service API.\n\n"
        "IMPORTANT: Provide an override only if you fully trust the endpoint. When set, the {client_type} API key defined in the global application settings is deliberately ignored to avoid accidental credential leakage. "
        "If the override endpoint requires an API key, enter it below."
    )
    icon: str = "mdi-api"


class EndpointOverrideField(ExtraField):
    group: EndpointOverrideGroup = pydantic.Field(default_factory=EndpointOverrideGroup)


class EndpointOverrideBaseURLField(EndpointOverrideField):
    name: str = "override_base_url"
    type: str = "text"
    label: str = "Base URL"
    required: bool = False
    description: str = "Override the base URL for the remote service"


class EndpointOverrideAPIKeyField(EndpointOverrideField):
    name: str = "override_api_key"
    type: str = "password"
    label: str = "API Key"
    required: bool = False
    description: str = "Override the API key for the remote service"
    note: ux_schema.Note = pydantic.Field(default_factory=lambda: ux_schema.Note(
        text="This is NOT the API key for the official {client_type} API. It is only used when overriding the base URL. The official {client_type} API key can be configured in the application settings.",
        color="warning",
    ))
    note: ux_schema.Note = pydantic.Field(
        default_factory=lambda: ux_schema.Note(
            text="This is NOT the API key for the official {client_type} API. It is only used when overriding the base URL. The official {client_type} API key can be configured in the application settings.",
            color="warning",
        )
    )


class EndpointOverrideMixin:
    override_base_url: str | None = None
    override_api_key: str | None = None

    def set_client_api_key(self, api_key: str | None):
        if getattr(self, "client", None):
            try:
                self.client.api_key = api_key
            except Exception as e:
                log.error("Error setting client API key", error=e, client=self.client_type)
                log.error(
                    "Error setting client API key", error=e, client=self.client_type
                )

    @property
    def api_key(self) -> str | None:
@@ -84,14 +96,17 @@ class EndpointOverrideMixin:
    @property
    def endpoint_override_base_url_configured(self) -> bool:
        return self.override_base_url and self.override_base_url.strip()

    @property
    def endpoint_override_api_key_configured(self) -> bool:
        return self.override_api_key and self.override_api_key.strip()

    @property
    def endpoint_override_fully_configured(self) -> bool:
        return self.endpoint_override_base_url_configured and self.endpoint_override_api_key_configured
        return (
            self.endpoint_override_base_url_configured
            and self.endpoint_override_api_key_configured
        )

    def _reconfigure_endpoint_override(self, **kwargs):
        if "override_base_url" in kwargs:
@@ -100,13 +115,13 @@ class EndpointOverrideMixin:
            if getattr(self, "client", None) and orig != self.override_base_url:
                log.info("Reconfiguring client base URL", new=self.override_base_url)
                self.set_client(kwargs.get("max_token_length"))

        if "override_api_key" in kwargs:
            self.override_api_key = kwargs["override_api_key"]
            self.set_client_api_key(self.override_api_key)


class RemoteServiceMixin:

class RemoteServiceMixin:
    def prompt_template(self, system_message: str, prompt: str):
        if "<|BOT|>" in prompt:
            _, right = prompt.split("<|BOT|>", 1)
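The api_key property body is elided by the hunk above, but the field descriptions pin down the intended behavior: once a base-URL override is active, the globally configured key must never be sent to the overridden endpoint. A hypothetical standalone sketch of that resolution rule, with the class and attribute names invented for illustration:

class OverrideKeyResolution:
    override_base_url: str | None = None
    override_api_key: str | None = None
    global_api_key: str | None = None  # stand-in for the app-settings key

    @property
    def api_key(self) -> str | None:
        # never leak the globally configured key to a user-supplied endpoint
        if self.override_base_url and self.override_base_url.strip():
            return self.override_api_key or None
        return self.global_api_key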
@@ -4,8 +4,6 @@ connection for the pod. This is a simple wrapper around the runpod module.
"""

import asyncio
import json
import os

import dotenv
import runpod
@@ -27,7 +27,6 @@ PROMPT_TEMPLATE_MAP = {
    "world_state": "world_state.system-analyst-no-decensor",
    "summarize": "summarizer.system-no-decensor",
    "visualize": "visual.system-no-decensor",

    # contains some minor attempts at keeping the LLM from generating
    # refusals to generate certain types of content
    "roleplay_decensor": "conversation.system",
@@ -42,32 +41,39 @@ PROMPT_TEMPLATE_MAP = {
    "visualize_decensor": "visual.system",
}


def cache_all() -> dict:
    for key in PROMPT_TEMPLATE_MAP:
        render_prompt(key)
    return RENDER_CACHE.copy()


def render_prompt(kind:str, decensor:bool=False):
def render_prompt(kind: str, decensor: bool = False):
    # work around circular import issue
    # TODO: refactor to avoid circular import
    from talemate.prompts import Prompt

    if kind not in PROMPT_TEMPLATE_MAP:
        log.warning(f"Invalid prompt system prompt identifier: {kind} - decensor: {decensor}")
        log.warning(
            f"Invalid prompt system prompt identifier: {kind} - decensor: {decensor}"
        )
        return ""

    if decensor:
        key = f"{kind}_decensor"
    else:
        key = kind

    if key not in PROMPT_TEMPLATE_MAP:
        log.warning(f"Invalid prompt system prompt identifier: {kind} - decensor: {decensor}", key=key)
        log.warning(
            f"Invalid prompt system prompt identifier: {kind} - decensor: {decensor}",
            key=key,
        )
        return ""

    if key in RENDER_CACHE:
        return RENDER_CACHE[key]

    prompt = str(Prompt.get(PROMPT_TEMPLATE_MAP[key]))
    RENDER_CACHE[key] = prompt
@@ -77,17 +83,17 @@ def render_prompt(kind:str, decensor:bool=False):
class SystemPrompts(pydantic.BaseModel):
    """
    System prompts and a normalized way to access them.

    Allows specification of a parent "SystemPrompts" instance that will be
    used as a fallback, and if not so specified, will default to the
    system prompts in the globals via lambda functions that render
    the templates.

    The globals that exist now will be deprecated in favor of this later.
    """

    parent: "SystemPrompts | None" = pydantic.Field(default=None, exclude=True)

    roleplay: str | None = None
    narrator: str | None = None
    creator: str | None = None
@@ -98,7 +104,7 @@ class SystemPrompts(pydantic.BaseModel):
    world_state: str | None = None
    summarize: str | None = None
    visualize: str | None = None

    roleplay_decensor: str | None = None
    narrator_decensor: str | None = None
    creator_decensor: str | None = None
@@ -109,64 +115,61 @@ class SystemPrompts(pydantic.BaseModel):
    world_state_decensor: str | None = None
    summarize_decensor: str | None = None
    visualize_decensor: str | None = None

    class Config:
        exclude_none = True
        exclude_unset = True

    @property
    def defaults(self) -> dict:
        return RENDER_CACHE.copy()

    def alias(self, alias:str) -> str:
    def alias(self, alias: str) -> str:
        if alias in PROMPT_TEMPLATE_MAP:
            return alias

        if "narrate" in alias:
            return "narrator"

        if "direction" in alias or "director" in alias:
            return "director"

        if "create" in alias or "creative" in alias:
            return "creator"

        if "conversation" in alias or "roleplay" in alias:
            return "roleplay"

        if "basic" in alias:
            return "basic"

        if "edit" in alias:
            return "editor"

        if "world_state" in alias:
            return "world_state"

        if "analyze_freeform" in alias or "investigate" in alias:
            return "analyst_freeform"

        if "analyze" in alias or "analyst" in alias or "analytical" in alias:
            return "analyst"

        if "summarize" in alias or "summarization" in alias:
            return "summarize"

        if "visual" in alias:
            return "visualize"

        return alias

    def get(self, kind:str, decensor:bool=False) -> str:
    def get(self, kind: str, decensor: bool = False) -> str:
        kind = self.alias(kind)

        key = f"{kind}_decensor" if decensor else kind

        if getattr(self, key):
            return getattr(self, key)
        if self.parent is not None:
            return self.parent.get(kind, decensor)
        return render_prompt(kind, decensor)
        return render_prompt(kind, decensor)
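For context, how the fallback chain in SystemPrompts.get() resolves in practice. A usage sketch with invented prompt values, assuming the class exactly as shown above:

defaults_level = SystemPrompts(roleplay="fallback roleplay prompt")
overrides = SystemPrompts(parent=defaults_level)

overrides.get("conversation")  # alias() maps this to "roleplay"; unset here,
                               # so the parent's value is returned
overrides.get("narrate")       # maps to "narrator"; unset on both levels,
                               # so render_prompt() supplies the default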
@@ -1,5 +1,4 @@
import random
from typing import Literal
import json
import httpx
import pydantic
@@ -103,7 +102,6 @@ class TabbyAPIClient(ClientBase):
    )

    def prompt_template(self, system_message: str, prompt: str):

        log.debug(
            "IS API HANDLING PROMPT TEMPLATE",
            api_handles_prompt_template=self.api_handles_prompt_template,
@@ -196,22 +194,18 @@ class TabbyAPIClient(ClientBase):

        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                url,
                headers=headers,
                json=payload,
                timeout=120.0
                "POST", url, headers=headers, json=payload, timeout=120.0
            ) as response:
                async for chunk in response.aiter_text():
                    buffer += chunk

                    while True:
                        line_end = buffer.find('\n')
                        line_end = buffer.find("\n")
                        if line_end == -1:
                            break

                        line = buffer[:line_end].strip()
                        buffer = buffer[line_end + 1:]
                        buffer = buffer[line_end + 1 :]

                        if not line:
                            continue
@@ -226,7 +220,7 @@ class TabbyAPIClient(ClientBase):

                        choice = data_obj.get("choices", [{}])[0]

                        # Chat completions use delta -> content.
                        # Chat completions use delta -> content.
                        delta = choice.get("delta", {})
                        content = (
                            delta.get("content")
@@ -235,12 +229,16 @@ class TabbyAPIClient(ClientBase):
                        )

                        usage = data_obj.get("usage", {})
                        completion_tokens = usage.get("completion_tokens", 0)
                        completion_tokens = usage.get(
                            "completion_tokens", 0
                        )
                        prompt_tokens = usage.get("prompt_tokens", 0)

                        if content:
                            response_text += content
                            self.update_request_tokens(self.count_tokens(content))
                            self.update_request_tokens(
                                self.count_tokens(content)
                            )
                    except json.JSONDecodeError:
                        # ignore malformed json chunks
                        pass
@@ -251,7 +249,9 @@ class TabbyAPIClient(ClientBase):

        if is_chat:
            # Process indirect coercion
            response_text = self.process_response_for_indirect_coercion(prompt, response_text)
            response_text = self.process_response_for_indirect_coercion(
                prompt, response_text
            )

        return response_text

@@ -265,7 +265,9 @@ class TabbyAPIClient(ClientBase):
            return ""
        except Exception as e:
            self.log.error("generate error", e=e)
            emit("status", message="Error during generation (check logs)", status="error")
            emit(
                "status", message="Error during generation (check logs)", status="error"
            )
            return ""

    def reconfigure(self, **kwargs):
@@ -287,7 +289,7 @@ class TabbyAPIClient(ClientBase):
            self.double_coercion = kwargs["double_coercion"]

        self._reconfigure_common_parameters(**kwargs)

        self.set_client(**kwargs)

    def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
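The streaming hunk above hand-rolls line framing over httpx's aiter_text(), since network chunk boundaries rarely align with lines. A reduced, synchronous sketch of that buffer discipline, where the function name and the "data:" framing are assumptions about the wire format rather than taken from the diff:

import json

def drain_buffer(buffer: str) -> tuple[list[dict], str]:
    """Split an accumulating stream buffer into parsed JSON events
    and the unfinished remainder."""
    events = []
    while True:
        line_end = buffer.find("\n")
        if line_end == -1:
            break  # incomplete line stays buffered until more data arrives
        line = buffer[:line_end].strip()
        buffer = buffer[line_end + 1 :]
        if not line:
            continue
        if line.startswith("data:"):
            line = line[len("data:") :].strip()
        if line == "[DONE]":
            continue
        try:
            events.append(json.loads(line))
        except json.JSONDecodeError:
            pass  # ignore malformed json chunks, as the client above does
    return events, buffer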
@@ -8,7 +8,7 @@ import httpx
import structlog
from openai import AsyncOpenAI

from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults, ExtraField
from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults
from talemate.client.registry import register

log = structlog.get_logger("talemate.client.textgenwebui")
@@ -103,7 +103,6 @@ class TextGeneratorWebuiClient(ClientBase):
        self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")

    def finalize_llama3(self, parameters: dict, prompt: str) -> tuple[str, bool]:

        if "<|eot_id|>" not in prompt:
            return prompt, False

@@ -122,7 +121,6 @@ class TextGeneratorWebuiClient(ClientBase):
        return prompt, True

    def finalize_YI(self, parameters: dict, prompt: str) -> tuple[str, bool]:

        if not self.model_name:
            return prompt, False

@@ -141,7 +139,6 @@ class TextGeneratorWebuiClient(ClientBase):
        return prompt, True

    async def get_model_name(self):

        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.api_url}/v1/internal/model/info",
@@ -170,14 +167,16 @@ class TextGeneratorWebuiClient(ClientBase):

    async def generate(self, prompt: str, parameters: dict, kind: str):
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._generate, prompt, parameters, kind)
        return await loop.run_in_executor(
            None, self._generate, prompt, parameters, kind
        )

    def _generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
        """
        parameters["prompt"] = prompt.strip(" ")

        response = ""
        parameters["stream"] = True
        stream_response = requests.post(
@@ -188,17 +187,16 @@ class TextGeneratorWebuiClient(ClientBase):
            stream=True,
        )
        stream_response.raise_for_status()

        sse = sseclient.SSEClient(stream_response)

        for event in sse.events():
            payload = json.loads(event.data)
            chunk = payload['choices'][0]['text']
            chunk = payload["choices"][0]["text"]
            response += chunk
            self.update_request_tokens(self.count_tokens(chunk))

        return response
        return response

    def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
        """
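Because requests and sseclient are blocking, the generate() wrapper above hops the work onto a thread via run_in_executor so the asyncio event loop stays responsive while text streams in. A minimal runnable sketch of that handoff, where the URL and the blocking function are placeholders:

import asyncio

import requests

def fetch_blocking(url: str) -> int:
    # stands in for the requests + sseclient streaming loop above
    return len(requests.get(url, timeout=30).text)

async def main() -> None:
    loop = asyncio.get_event_loop()
    # the default (None) executor runs the call in a worker thread
    size = await loop.run_in_executor(None, fetch_blocking, "http://localhost:5000/")
    print(size)

asyncio.run(main())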
@@ -1,5 +1,3 @@
import copy
import random
from urllib.parse import urljoin as _urljoin

__all__ = ["urljoin"]
@@ -1,14 +1,27 @@
from .base import TalemateCommand
from .cmd_characters import *
from .cmd_debug_tools import *
from .cmd_rebuild_archive import CmdRebuildArchive
from .cmd_rename import CmdRename
from .cmd_regenerate import *
from .cmd_reset import CmdReset
from .cmd_save import CmdSave
from .cmd_save_as import CmdSaveAs
from .cmd_setenv import CmdSetEnvironmentToCreative, CmdSetEnvironmentToScene
from .cmd_time_util import *
from .cmd_tts import *
from .cmd_world_state import *
from .manager import Manager
from .base import TalemateCommand  # noqa: F401
from .cmd_characters import CmdActivateCharacter, CmdDeactivateCharacter  # noqa: F401
from .cmd_debug_tools import (
    CmdPromptChangeSectioning,  # noqa: F401
    CmdSummarizerUpdateLayeredHistory,  # noqa: F401
    CmdSummarizerResetLayeredHistory,  # noqa: F401
    CmdSummarizerContextInvestigation,  # noqa: F401
)
from .cmd_rebuild_archive import CmdRebuildArchive  # noqa: F401
from .cmd_rename import CmdRename  # noqa: F401
from .cmd_regenerate import CmdRegenerate  # noqa: F401
from .cmd_reset import CmdReset  # noqa: F401
from .cmd_save import CmdSave  # noqa: F401
from .cmd_save_as import CmdSaveAs  # noqa: F401
from .cmd_setenv import CmdSetEnvironmentToCreative, CmdSetEnvironmentToScene  # noqa: F401
from .cmd_time_util import CmdAdvanceTime  # noqa: F401
from .cmd_tts import CmdTestTTS  # noqa: F401
from .cmd_world_state import (
    CmdAddReinforcement,  # noqa: F401
    CmdApplyWorldStateTemplate,  # noqa: F401
    CmdCheckPinConditions,  # noqa: F401
    CmdDetermineCharacterDevelopment,  # noqa: F401
    CmdRemoveReinforcement,  # noqa: F401
    CmdSummarizeAndPin,  # noqa: F401
    CmdUpdateReinforcements,  # noqa: F401
)
from .manager import Manager  # noqa: F401
@@ -1,6 +1,4 @@
import asyncio
import json
import logging

import structlog

@@ -19,7 +17,6 @@ __all__ = [
log = structlog.get_logger("talemate.commands.cmd_debug_tools")


@register
class CmdPromptChangeSectioning(TalemateCommand):
    """
@@ -95,7 +92,8 @@ class CmdSummarizerUpdateLayeredHistory(TalemateCommand):
        summarizer = get_agent("summarizer")

        await summarizer.summarize_to_layered_history()


@register
class CmdSummarizerResetLayeredHistory(TalemateCommand):
    """
@@ -108,16 +106,17 @@ class CmdSummarizerResetLayeredHistory(TalemateCommand):

    async def run(self):
        summarizer = get_agent("summarizer")

        # if arg is provided remove the last n layers
        if self.args:
            n = int(self.args[0])
            self.scene.layered_history = self.scene.layered_history[:-n]
        else:
            self.scene.layered_history = []

        await summarizer.summarize_to_layered_history()


@register
class CmdSummarizerContextInvestigation(TalemateCommand):
    """
@@ -135,11 +134,10 @@ class CmdSummarizerContextInvestigation(TalemateCommand):
        if not self.args:
            self.emit("system", "You must specify a query")
            return

        await summarizer.request_context_investigations(self.args[0], max_calls=1)


@register
class CmdMemoryCompareStrings(TalemateCommand):
    """
@@ -156,13 +154,14 @@ class CmdMemoryCompareStrings(TalemateCommand):
        if not self.args:
            self.emit("system", "You must specify two strings to compare")
            return

        string1 = self.args[0]
        string2 = self.args[1]

        result = await memory.compare_strings(string1, string2)
        self.emit("system", f"The strings are {result['cosine_similarity']} similar")


@register
class CmdRunEditorRevision(TalemateCommand):
    """
@@ -172,14 +171,13 @@ class CmdRunEditorRevision(TalemateCommand):
    name = "run_editor_revision"
    description = "Run the editor revision"
    aliases = ["run_revision"]

    async def run(self):
        editor = get_agent("editor")
        scene = self.scene

        last_message = scene.history[-1]

        result = await editor.revision_detect_bad_prose(str(last_message))

        self.emit("system", f"Result: {result}")
@@ -1,5 +1,3 @@
import asyncio

from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import emit
@@ -29,7 +27,7 @@ class CmdRebuildArchive(TalemateCommand):
        ]

        self.scene.ts = "PT0S"

        memory.delete({"typ": "history"})

        entries = 0
@@ -42,9 +40,9 @@ class CmdRebuildArchive(TalemateCommand):
            )
            more = await summarizer.agent.build_archive(self.scene)
            self.scene.sync_time()

            entries += 1

            if not more:
                break
@@ -1,6 +1,6 @@
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.emit import emit, wait_for_input, wait_for_input_yesno
from talemate.emit import wait_for_input_yesno
from talemate.exceptions import ResetScene
@@ -1,5 +1,3 @@
import asyncio

from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register

@@ -22,10 +22,13 @@ class CmdSetEnvironmentToScene(TalemateCommand):
        player_character = self.scene.get_player_character()

        if not player_character:
            self.system_message("No characters found - cannot switch to gameplay mode.", meta={
                "icon": "mdi-alert",
                "color": "warning",
            })
            self.system_message(
                "No characters found - cannot switch to gameplay mode.",
                meta={
                    "icon": "mdi-alert",
                    "color": "warning",
                },
            )
            return True

        self.scene.set_environment("scene")
@@ -2,11 +2,6 @@
Commands to manage scene timescale
"""

import asyncio
import logging

import isodate

import talemate.instance as instance
from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
@@ -1,10 +1,6 @@
import asyncio
import logging

from talemate.commands.base import TalemateCommand
from talemate.commands.manager import register
from talemate.instance import get_agent
from talemate.prompts.base import set_default_sectioning_handler

__all__ = [
    "CmdTestTTS",
@@ -109,8 +109,6 @@ class CmdUpdateReinforcements(TalemateCommand):
    aliases = ["ws_ur"]

    async def run(self):
        scene = self.scene

        world_state = get_agent("world_state")

        await world_state.update_reinforcements(force=True)
@@ -234,18 +232,16 @@ class CmdDetermineCharacterDevelopment(TalemateCommand):
        scene = self.scene

        world_state = get_agent("world_state")
        creator = get_agent("creator")

        if not len(self.args):
            raise ValueError("No character name provided.")

        character_name = self.args[0]

        character = scene.get_character(character_name)

        if not character:
            raise ValueError(f"Character {character_name} not found.")

        instructions = await world_state.determine_character_development(character)

        # updates = await creator.update_character_sheet(character, instructions)
        await world_state.determine_character_development(character)
        # updates = await creator.update_character_sheet(character, instructions)
@@ -36,7 +36,7 @@ class Manager(Emitter):
            aliases[alias] = name.replace("cmd_", "")
        return aliases

    async def execute(self, cmd, emit_on_unknown:bool = True, state:dict = None):
    async def execute(self, cmd, emit_on_unknown: bool = True, state: dict = None):
        # commands start with ! and are followed by a command name
        cmd = cmd.strip()
        cmd_args = ""
@@ -56,7 +56,6 @@ class Manager(Emitter):

        for command_cls in self.command_classes:
            if command_cls.is_command(cmd_name):

                if command_cls.argument_cls:
                    cmd_kwargs = json.loads(cmd_args_unsplit)
                    cmd_args = []
@@ -1,4 +1,3 @@
import copy
import datetime
import os
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, TypeVar, Union, Literal
@@ -6,8 +5,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, TypeVar, Union,
import pydantic
import structlog
import yaml
from enum import Enum
from pydantic import BaseModel, Field
from pydantic import BaseModel
from typing_extensions import Annotated

from talemate.agents.registry import get_agent_class
@@ -43,7 +41,7 @@ class Client(BaseModel):
    rate_limit: Union[int, None] = None
    data_format: Literal["json", "yaml"] | None = None
    enabled: bool = True

    system_prompts: SystemPrompts = SystemPrompts()
    preset_group: str | None = None

@@ -165,6 +163,7 @@ class DeepSeekConfig(BaseModel):
class OpenRouterConfig(BaseModel):
    api_key: Union[str, None] = None


class RunPodConfig(BaseModel):
    api_key: Union[str, None] = None

@@ -202,6 +201,7 @@ class RecentScene(BaseModel):
    date: str
    cover_image: Union[Asset, None] = None


class EmbeddingFunctionPreset(BaseModel):
    embeddings: str = "sentence-transformer"
    model: str = "all-MiniLM-L6-v2"
@@ -217,12 +217,11 @@ class EmbeddingFunctionPreset(BaseModel):
    client: str | None = None


def generate_chromadb_presets() -> dict[str, EmbeddingFunctionPreset]:
    """
    Returns a dict of default embedding presets
    """

    return {
        "default": EmbeddingFunctionPreset(),
        "Alibaba-NLP/gte-base-en-v1.5": EmbeddingFunctionPreset(
@@ -274,25 +273,24 @@ class InferenceParameters(BaseModel):
    frequency_penalty: float | None = 0.05
    repetition_penalty: float | None = 1.0
    repetition_penalty_range: int | None = 1024

    xtc_threshold: float | None = 0.1
    xtc_probability: float | None = 0.0

    dry_multiplier: float | None = 0.0
    dry_base: float | None = 1.75
    dry_allowed_length: int | None = 2
    dry_sequence_breakers: str | None = '"\\n", ":", "\\"", "*"'

    smoothing_factor: float | None = 0.0
    smoothing_curve: float | None = 1.0

    # this determines whether or not it should be persisted
    # to the config file
    changed: bool = False


class InferencePresets(BaseModel):
    analytical: InferenceParameters = InferenceParameters(
        temperature=0.7,
        presence_penalty=0,
@@ -325,19 +323,26 @@ class InferencePresets(BaseModel):
        presence_penalty=0.0,
    )


class InferencePresetGroup(BaseModel):
    name: str
    presets: InferencePresets


class Presets(BaseModel):
    inference_defaults: InferencePresets = InferencePresets()
    inference: InferencePresets = InferencePresets()

    inference_groups: dict[str, InferencePresetGroup] = pydantic.Field(default_factory=dict)

    embeddings_defaults: dict[str, EmbeddingFunctionPreset] = pydantic.Field(default_factory=generate_chromadb_presets)
    embeddings: dict[str, EmbeddingFunctionPreset] = pydantic.Field(default_factory=generate_chromadb_presets)

    inference_groups: dict[str, InferencePresetGroup] = pydantic.Field(
        default_factory=dict
    )

    embeddings_defaults: dict[str, EmbeddingFunctionPreset] = pydantic.Field(
        default_factory=generate_chromadb_presets
    )
    embeddings: dict[str, EmbeddingFunctionPreset] = pydantic.Field(
        default_factory=generate_chromadb_presets
    )


def gnerate_intro_scenes():
@@ -471,28 +476,28 @@ AnnotatedClient = Annotated[
class HistoryMessageStyle(BaseModel):
    italic: bool = False
    bold: bool = False

    # Leave None for default color
    color: str | None = None
    color: str | None = None


class HidableHistoryMessageStyle(HistoryMessageStyle):
    # certain messages can be hidden, but all messages are shown by default
    show: bool = True


class SceneAppearance(BaseModel):
    narrator_messages: HistoryMessageStyle = HistoryMessageStyle(italic=True)
    character_messages: HistoryMessageStyle = HistoryMessageStyle()
    director_messages: HidableHistoryMessageStyle = HidableHistoryMessageStyle()
    time_messages: HistoryMessageStyle = HistoryMessageStyle()
    context_investigation_messages: HidableHistoryMessageStyle = HidableHistoryMessageStyle()

    context_investigation_messages: HidableHistoryMessageStyle = (
        HidableHistoryMessageStyle()
    )


class Appearance(BaseModel):
    scene: SceneAppearance = SceneAppearance()


class Config(BaseModel):
@@ -505,7 +510,7 @@ class Config(BaseModel):
    creator: CreatorConfig = CreatorConfig()

    openai: OpenAIConfig = OpenAIConfig()

    deepseek: DeepSeekConfig = DeepSeekConfig()

    mistralai: MistralAIConfig = MistralAIConfig()
@@ -531,11 +536,11 @@ class Config(BaseModel):
    recent_scenes: RecentScenes = RecentScenes()

    presets: Presets = Presets()

    appearance: Appearance = Appearance()

    system_prompts: SystemPrompts = SystemPrompts()

    class Config:
        extra = "ignore"

@@ -595,18 +600,18 @@ def save_config(config, file_path: str = "./config.yaml"):
    # we dont want to persist the following, so we drop them:
    # - presets.inference_defaults
    # - presets.embeddings_defaults

    if "inference_defaults" in config["presets"]:
        config["presets"].pop("inference_defaults")

    if "embeddings_defaults" in config["presets"]:
        config["presets"].pop("embeddings_defaults")

    # for normal presets we only want to persist if they have changed
    for preset_name, preset in list(config["presets"]["inference"].items()):
        if not preset.get("changed"):
            config["presets"]["inference"].pop(preset_name)

    # in inference groups also only keep if changed
    for group_name, group in list(config["presets"]["inference_groups"].items()):
        for preset_name, preset in list(group["presets"].items()):
@@ -616,7 +621,7 @@ def save_config(config, file_path: str = "./config.yaml"):
    # if presets is empty, remove it
    if not config["presets"]["inference"]:
        config["presets"].pop("inference")

    # if system_prompts is empty, remove it
    if not config["system_prompts"]:
        config.pop("system_prompts")
@@ -624,12 +629,13 @@ def save_config(config, file_path: str = "./config.yaml"):
    # set any client preset_group to "" if it references an
    # entry that no longer exists in inference_groups
    for client in config["clients"].values():

        if not client.get("preset_group"):
            continue

        if client["preset_group"] not in config["presets"].get("inference_groups", {}):
            log.warning(f"Client {client['name']} references non-existent preset group {client['preset_group']}, setting to default")
            log.warning(
                f"Client {client['name']} references non-existent preset group {client['preset_group']}, setting to default"
            )
            client["preset_group"] = ""

    with open(file_path, "w") as file:
@@ -639,7 +645,6 @@ def save_config(config, file_path: str = "./config.yaml"):


def cleanup() -> Config:
    log.info("cleaning up config")

    config = load_config(as_model=True)
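save_config() above encodes a persistence rule: factory defaults are always dropped, and regular presets survive only when flagged changed. A toy sketch of just that rule, on invented data, to make the pruning order explicit:

config = {
    "presets": {
        "inference_defaults": {"analytical": {"temperature": 0.7}},
        "inference": {
            "analytical": {"temperature": 0.7, "changed": False},
            "creative": {"temperature": 1.1, "changed": True},
        },
    }
}

# defaults never persist
config["presets"].pop("inference_defaults", None)

# regular presets persist only when explicitly changed
for name, preset in list(config["presets"]["inference"].items()):
    if not preset.get("changed"):
        config["presets"]["inference"].pop(name)

assert list(config["presets"]["inference"]) == ["creative"]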
@@ -35,11 +35,12 @@ regeneration_context = ContextVar("regeneration_context", default=None)
active_scene = ContextVar("active_scene", default=None)
interaction = ContextVar("interaction", default=InteractionState())


def handle_generation_cancelled(exc: GenerationCancelled):
    # set cancel_requested to False on the active_scene

    scene = active_scene.get()

    if scene:
        scene.cancel_requested = False

@@ -102,6 +103,6 @@ class Interaction:
def assert_active_scene(scene: object):
    if not active_scene.get():
        raise SceneInactiveError("Scene is not active")

    if active_scene.get() != scene:
        raise SceneInactiveError("Scene has changed")
        raise SceneInactiveError("Scene has changed")
@@ -1,12 +1,12 @@
import talemate.emit.signals as signals
import talemate.emit.signals as signals  # noqa: F401

from .base import (
    AbortCommand,
    Emission,
    Emitter,
    Receiver,
    abort_wait_for_input,
    emit,
    wait_for_input,
    wait_for_input_yesno,
)
    AbortCommand,  # noqa: F401
    Emission,  # noqa: F401
    Emitter,  # noqa: F401
    Receiver,  # noqa: F401
    abort_wait_for_input,  # noqa: F401
    emit,  # noqa: F401
    wait_for_input,  # noqa: F401
    wait_for_input_yesno,  # noqa: F401
)
@@ -6,6 +6,7 @@ __all__ = [

handlers = {}


class AsyncSignal:
    def __init__(self, name):
        self.receivers = []
@@ -21,11 +22,11 @@ class AsyncSignal:
            self.receivers.remove(handler)
        except ValueError:
            pass

    async def send(self, emission):
        for receiver in self.receivers:
            await receiver(emission)


def _register(name: str):
    """
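A short usage sketch for AsyncSignal as defined above. The signal name and handler are invented, and the receiver is appended to the receivers list directly, since only that list is visible in this hunk:

import asyncio

signal = AsyncSignal("game_loop")

async def on_emission(emission) -> None:
    print("received:", emission)

signal.receivers.append(on_emission)

# send() awaits each connected receiver in registration order
asyncio.run(signal.send({"turn": 1}))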
@@ -2,7 +2,7 @@ from __future__ import annotations

import asyncio
import dataclasses
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Callable

import structlog

@@ -123,16 +123,16 @@ async def wait_for_input(

    while input_received["message"] is None:
        await asyncio.sleep(sleep_time)

        interaction_state = interaction.get()

        if abort_condition and (await abort_condition()):
            raise AbortWaitForInput()

        if interaction_state.reset_requested:
            interaction_state.reset_requested = False
            raise RestartSceneLoop()

        if interaction_state.input:
            input_received["message"] = interaction_state.input
            input_received["interaction"] = interaction_state
@@ -140,7 +140,7 @@ async def wait_for_input(
            interaction_state.input = None
            interaction_state.from_choice = None
            break

    handlers["receive_input"].disconnect(input_receiver)

    if input_received["message"] == "!abort":
@@ -194,6 +194,6 @@ class Emitter:

    def player_message(self, message: str, character: Character):
        self.emit("player", message, character=character)

    def context_investigation_message(self, message: str):
        self.emit("context_investigation", message)
@@ -31,6 +31,7 @@ class ArchiveEvent(Event):
    memory_id: str
    ts: str = None


@dataclass
class CharacterStateEvent(Event):
    state: str
@@ -62,11 +63,13 @@ class GameLoopActorIterEvent(GameLoopBase):
    actor: Actor
    game_loop: GameLoopEvent


@dataclass
class GameLoopCharacterIterEvent(GameLoopBase):
    character: Character
    game_loop: GameLoopEvent


@dataclass
class GameLoopNewMessageEvent(GameLoopBase):
    message: SceneMessage
@@ -76,6 +79,7 @@ class GameLoopNewMessageEvent(GameLoopBase):
class PlayerTurnStartEvent(Event):
    pass


@dataclass
class RegenerateGeneration(Event):
    message: "SceneMessage"
@@ -89,4 +93,4 @@ async_signals.register(
    "regenerate.msg.context_investigation",
    "game_loop_player_character_iter",
    "game_loop_ai_character_iter",
)
)