diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 015f0f74..5075c93e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,6 +42,11 @@ jobs: source .venv/bin/activate uv pip install -e ".[dev]" + - name: Run linting + run: | + source .venv/bin/activate + uv run pre-commit run --all-files + - name: Setup configuration file run: | cp config.example.yaml config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..3030a46f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,16 @@ +fail_fast: false +exclude: | + (?x)^( + tests/data/.* + |install-utils/.* + )$ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.12.1 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format \ No newline at end of file diff --git a/docs/cleanup.py b/docs/cleanup.py index 233427a9..25af3167 100644 --- a/docs/cleanup.py +++ b/docs/cleanup.py @@ -1,60 +1,63 @@ import os import re import subprocess -from pathlib import Path import argparse + def find_image_references(md_file): """Find all image references in a markdown file.""" - with open(md_file, 'r', encoding='utf-8') as f: + with open(md_file, "r", encoding="utf-8") as f: content = f.read() - - pattern = r'!\[.*?\]\((.*?)\)' + + pattern = r"!\[.*?\]\((.*?)\)" matches = re.findall(pattern, content) - + cleaned_paths = [] for match in matches: - path = match.lstrip('/') - if 'img/' in path: - path = path[path.index('img/') + 4:] + path = match.lstrip("/") + if "img/" in path: + path = path[path.index("img/") + 4 :] # Only keep references to versioned images parts = os.path.normpath(path).split(os.sep) - if len(parts) >= 2 and parts[0].replace('.', '').isdigit(): + if len(parts) >= 2 and parts[0].replace(".", "").isdigit(): cleaned_paths.append(path) - + return cleaned_paths + def scan_markdown_files(docs_dir): """Recursively scan all markdown files in the docs directory.""" md_files = [] for root, _, files in os.walk(docs_dir): for file in files: - if file.endswith('.md'): + if file.endswith(".md"): md_files.append(os.path.join(root, file)) return md_files + def find_all_images(img_dir): """Find all image files in version subdirectories.""" image_files = [] for root, _, files in os.walk(img_dir): # Get the relative path from img_dir to current directory rel_dir = os.path.relpath(root, img_dir) - + # Skip if we're in the root img directory - if rel_dir == '.': + if rel_dir == ".": continue - + # Check if the immediate parent directory is a version number parent_dir = rel_dir.split(os.sep)[0] - if not parent_dir.replace('.', '').isdigit(): + if not parent_dir.replace(".", "").isdigit(): continue - + for file in files: - if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.svg')): + if file.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".svg")): rel_path = os.path.relpath(os.path.join(root, file), img_dir) image_files.append(rel_path) return image_files + def grep_check_image(docs_dir, image_path): """ Check if versioned image is referenced anywhere using grep. @@ -65,33 +68,46 @@ def grep_check_image(docs_dir, image_path): parts = os.path.normpath(image_path).split(os.sep) version = parts[0] # e.g., "0.29.0" filename = parts[-1] # e.g., "world-state-suggestions-2.png" - + # For versioned images, require both version and filename to match version_pattern = f"{version}.*{filename}" try: result = subprocess.run( - ['grep', '-r', '-l', version_pattern, docs_dir], + ["grep", "-r", "-l", version_pattern, docs_dir], capture_output=True, - text=True + text=True, ) if result.stdout.strip(): - print(f"Found reference to {image_path} with version pattern: {version_pattern}") + print( + f"Found reference to {image_path} with version pattern: {version_pattern}" + ) return True except subprocess.CalledProcessError: pass - + except Exception as e: print(f"Error during grep check for {image_path}: {e}") - + return False + def main(): - parser = argparse.ArgumentParser(description='Find and optionally delete unused versioned images in MkDocs project') - parser.add_argument('--docs-dir', type=str, required=True, help='Path to the docs directory') - parser.add_argument('--img-dir', type=str, required=True, help='Path to the images directory') - parser.add_argument('--delete', action='store_true', help='Delete unused images') - parser.add_argument('--verbose', action='store_true', help='Show all found references and files') - parser.add_argument('--skip-grep', action='store_true', help='Skip the additional grep validation') + parser = argparse.ArgumentParser( + description="Find and optionally delete unused versioned images in MkDocs project" + ) + parser.add_argument( + "--docs-dir", type=str, required=True, help="Path to the docs directory" + ) + parser.add_argument( + "--img-dir", type=str, required=True, help="Path to the images directory" + ) + parser.add_argument("--delete", action="store_true", help="Delete unused images") + parser.add_argument( + "--verbose", action="store_true", help="Show all found references and files" + ) + parser.add_argument( + "--skip-grep", action="store_true", help="Skip the additional grep validation" + ) args = parser.parse_args() # Convert paths to absolute paths @@ -118,7 +134,7 @@ def main(): print("\nAll versioned image references found in markdown:") for img in sorted(used_images): print(f"- {img}") - + print("\nAll versioned images in directory:") for img in sorted(all_images): print(f"- {img}") @@ -133,9 +149,11 @@ def main(): for img in unused_images: if not grep_check_image(docs_dir, img): actually_unused.add(img) - + if len(actually_unused) != len(unused_images): - print(f"\nGrep validation found {len(unused_images) - len(actually_unused)} additional image references!") + print( + f"\nGrep validation found {len(unused_images) - len(actually_unused)} additional image references!" + ) unused_images = actually_unused # Report findings @@ -148,7 +166,7 @@ def main(): print("\nUnused versioned images:") for img in sorted(unused_images): print(f"- {img}") - + if args.delete: print("\nDeleting unused versioned images...") for img in unused_images: @@ -162,5 +180,6 @@ def main(): else: print("\nNo unused versioned images found!") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/docs/dev/agents/example/test/__init__.py b/docs/dev/agents/example/test/__init__.py index f0ef7e68..b3224e74 100644 --- a/docs/dev/agents/example/test/__init__.py +++ b/docs/dev/agents/example/test/__init__.py @@ -4,12 +4,12 @@ from talemate.events import GameLoopEvent import talemate.emit.async_signals from talemate.emit import emit + @register() class TestAgent(Agent): - agent_type = "test" verbose_name = "Test" - + def __init__(self, client): self.client = client self.is_enabled = True @@ -20,7 +20,7 @@ class TestAgent(Agent): description="Test", ), } - + @property def enabled(self): return self.is_enabled @@ -36,7 +36,7 @@ class TestAgent(Agent): def connect(self, scene): super().connect(scene) talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop) - + async def on_game_loop(self, emission: GameLoopEvent): """ Called on the beginning of every game loop @@ -45,4 +45,8 @@ class TestAgent(Agent): if not self.enabled: return - emit("status", status="info", message="Annoying you with a test message every game loop.") \ No newline at end of file + emit( + "status", + status="info", + message="Annoying you with a test message every game loop.", + ) diff --git a/docs/dev/client/example/runpod_vllm/__init__.py b/docs/dev/client/example/runpod_vllm/__init__.py index 3aaca612..0ce63813 100644 --- a/docs/dev/client/example/runpod_vllm/__init__.py +++ b/docs/dev/client/example/runpod_vllm/__init__.py @@ -19,14 +19,17 @@ from talemate.config import Client as BaseClientConfig log = structlog.get_logger("talemate.client.runpod_vllm") + class Defaults(pydantic.BaseModel): max_token_length: int = 4096 model: str = "" runpod_id: str = "" - + + class ClientConfig(BaseClientConfig): runpod_id: str = "" + @register() class RunPodVLLMClient(ClientBase): client_type = "runpod_vllm" @@ -49,7 +52,6 @@ class RunPodVLLMClient(ClientBase): ) } - def __init__(self, model=None, runpod_id=None, **kwargs): self.model_name = model self.runpod_id = runpod_id @@ -59,12 +61,10 @@ class RunPodVLLMClient(ClientBase): def experimental(self): return False - def set_client(self, **kwargs): log.debug("set_client", kwargs=kwargs, runpod_id=self.runpod_id) self.runpod_id = kwargs.get("runpod_id", self.runpod_id) - - + def tune_prompt_parameters(self, parameters: dict, kind: str): super().tune_prompt_parameters(parameters, kind) @@ -88,32 +88,37 @@ class RunPodVLLMClient(ClientBase): self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters) try: - async with aiohttp.ClientSession() as session: endpoint = runpod.AsyncioEndpoint(self.runpod_id, session) - - run_request = await endpoint.run({ - "input": { - "prompt": prompt, + + run_request = await endpoint.run( + { + "input": { + "prompt": prompt, + } + # "parameters": parameters } - #"parameters": parameters - }) - - while (await run_request.status()) not in ["COMPLETED", "FAILED", "CANCELLED"]: + ) + + while (await run_request.status()) not in [ + "COMPLETED", + "FAILED", + "CANCELLED", + ]: status = await run_request.status() log.debug("generate", status=status) await asyncio.sleep(0.1) - + status = await run_request.status() - + log.debug("generate", status=status) - + response = await run_request.output() - + log.debug("generate", response=response) - + return response["choices"][0]["tokens"][0] - + except Exception as e: self.log.error("generate error", e=e) emit( diff --git a/docs/dev/client/example/test/__init__.py b/docs/dev/client/example/test/__init__.py index 449e9eec..30bb80bf 100644 --- a/docs/dev/client/example/test/__init__.py +++ b/docs/dev/client/example/test/__init__.py @@ -9,6 +9,7 @@ class Defaults(pydantic.BaseModel): api_url: str = "http://localhost:1234" max_token_length: int = 4096 + @register() class TestClient(ClientBase): client_type = "test" @@ -22,14 +23,13 @@ class TestClient(ClientBase): self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111") def tune_prompt_parameters(self, parameters: dict, kind: str): - """ Talemate adds a bunch of parameters to the prompt, but not all of them are valid for all clients. - + This method is called before the prompt is sent to the client, and it allows the client to remove any parameters that it doesn't support. """ - + super().tune_prompt_parameters(parameters, kind) keys = list(parameters.keys()) @@ -41,11 +41,10 @@ class TestClient(ClientBase): del parameters[key] async def get_model_name(self): - """ This should return the name of the model that is being used. """ - + return "Mock test model" async def generate(self, prompt: str, parameters: dict, kind: str): diff --git a/frontend_wsgi.py b/frontend_wsgi.py index 6a0e873f..21768b5c 100644 --- a/frontend_wsgi.py +++ b/frontend_wsgi.py @@ -1,6 +1,6 @@ import os -from fastapi import FastAPI, Request -from fastapi.responses import HTMLResponse, FileResponse +from fastapi import FastAPI +from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from starlette.exceptions import HTTPException @@ -14,6 +14,7 @@ app = FastAPI() # Serve static files, but exclude the root path app.mount("/", StaticFiles(directory=dist_dir, html=True), name="static") + @app.get("/", response_class=HTMLResponse) async def serve_root(): index_path = os.path.join(dist_dir, "index.html") @@ -24,5 +25,6 @@ async def serve_root(): else: raise HTTPException(status_code=404, detail="index.html not found") + # This is the ASGI application -application = app \ No newline at end of file +application = app diff --git a/pyproject.toml b/pyproject.toml index b7d5afbe..be67b0a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ dev = [ "mkdocs-material>=9.5.27", "mkdocs-awesome-pages-plugin>=2.9.2", "mkdocs-glightbox>=0.4.0", + "pre-commit>=2.13", ] [project.scripts] @@ -103,4 +104,4 @@ include_trailing_comma = true force_grid_wrap = 0 use_parentheses = true ensure_newline_before_comments = true -line_length = 88 \ No newline at end of file +line_length = 88 diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 00000000..c2c07179 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,5 @@ +[lint] +# Disable automatic fix for unused imports (`F401`). We check these manually. +unfixable = ["F401"] +# Ignore E402 +extend-ignore = ["E402"] diff --git a/scenes/simulation-suite/game.py b/scenes/simulation-suite/game.py index 0352456e..aa2628df 100644 --- a/scenes/simulation-suite/game.py +++ b/scenes/simulation-suite/game.py @@ -1,111 +1,112 @@ - def game(TM): - MSG_PROCESSED_INSTRUCTIONS = "Simulation suite processed instructions" - - MSG_HELP = "Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating \"Computer,\" followed by an instruction. For example ... \"Computer, i want to experience being on a derelict spaceship.\"" - + + MSG_HELP = 'Instructions to the simulation computer are only processed if the computer is directly addressed at the beginning of the instruction. Please state your commands by addressing the computer by stating "Computer," followed by an instruction. For example ... "Computer, i want to experience being on a derelict spaceship."' + PROMPT_NARRATE_ROUND = "Narrate the simulation and reveal some new details to the player in one paragraph. YOU MUST NOT ADDRESS THE COMPUTER OR THE SIMULATION." - + PROMPT_STARTUP = "Narrate the computer asking the user to state the nature of their desired simulation in a synthetic and soft sounding voice. Remind the user that this is an old version of the simulation suite and they should check out version two for a more advanced experience." - - CTX_PIN_UNAWARE = "Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION." - + + CTX_PIN_UNAWARE = ( + "Characters in the simulation ARE NOT AWARE OF THE COMPUTER OR THE SIMULATION." + ) + AUTO_NARRATE_INTERVAL = 10 - - def parse_sim_call_arguments(call:str) -> str: + + def parse_sim_call_arguments(call: str) -> str: """ Returns the value between the parentheses of a simulation call - + Example: - + call = 'change_environment("a house")' - + parse_sim_call_arguments(call) -> "a house" """ - + try: return call.split("(", 1)[1].split(")")[0] except Exception: return "" - + class SimulationSuite: def __init__(self): - """ This is initialized at the beginning of each round of the simulation suite """ - + # do we update the world state at the end of the round self.update_world_state = False self.simulation_reset = False - + # will keep track of any npcs added during the current round self.added_npcs = [] - + TM.log.debug("SIMULATION SUITE INIT!", scene=TM.scene) self.player_message = TM.scene.last_player_message - self.last_processed_call = TM.game_state.get_var("instr.lastprocessed_call", -1) - + self.last_processed_call = TM.game_state.get_var( + "instr.lastprocessed_call", -1 + ) + # determine whether the player / user input is an instruction # to the simulation computer - # + # # we do this by checking if the message starts with "Computer," self.player_message_is_instruction = ( - self.player_message and - self.player_message.raw.lower().startswith("computer") and - not self.player_message.hidden and - not self.last_processed_call > self.player_message.id + self.player_message + and self.player_message.raw.lower().startswith("computer") + and not self.player_message.hidden + and not self.last_processed_call > self.player_message.id ) - + def run(self): """ Main entry point for the simulation suite """ - + if not TM.game_state.has_var("instr.simulation_stopped"): # simulation is still running self.simulation() - + self.finalize_round() - + def simulation(self): """ Simulation suite logic """ - + if not TM.game_state.has_var("instr.simulation_started"): self.startup() else: self.simulation_calls() - + if self.update_world_state: self.run_update_world_state(force=True) - + def startup(self): - """ Scene startup logic """ - + # we are at the beginning of the simulation - TM.signals.status("busy", "Simulation suite powering up.", as_scene_message=True) + TM.signals.status( + "busy", "Simulation suite powering up.", as_scene_message=True + ) TM.game_state.set_var("instr.simulation_started", "yes", commit=False) - + # add narration for the introduction TM.agents.narrator.action_to_narration( action_name="progress_story", narrative_direction=PROMPT_STARTUP, - emit_message=False + emit_message=False, ) - + # add narration for the instructions on how to interact with the simulation # this is a passthrough since we don't want the AI to paraphrase this TM.agents.narrator.action_to_narration( - action_name="passthrough", - narration=MSG_HELP + action_name="passthrough", narration=MSG_HELP ) - + # create a world state entry letting the AI know that characters # interacting in the simulation are not aware of the computer or the simulation TM.agents.world_state.save_world_entry( @@ -113,37 +114,43 @@ def game(TM): text=CTX_PIN_UNAWARE, meta={}, # this should always be pinned - pin=True + pin=True, ) - + # set flag that we have started the simulation TM.game_state.set_var("instr.simulation_started", "yes", commit=False) - + # signal to the UX that the simulation suite is ready - TM.signals.status("success", "Simulation suite ready", as_scene_message=True) - + TM.signals.status( + "success", "Simulation suite ready", as_scene_message=True + ) + # we want to update the world state at the end of the round self.update_world_state = True - + def simulation_calls(self): """ Calls the simulation suite main prompt to determine the appropriate simulation calls """ - + # we only process instructions that are not hidden and are not the last processed call - if not self.player_message_is_instruction or self.player_message.id == self.last_processed_call: + if ( + not self.player_message_is_instruction + or self.player_message.id == self.last_processed_call + ): return - + # First instruction? if not TM.game_state.has_var("instr.has_issued_instructions"): - # determine the context of the simulation - context_context = TM.agents.creator.determine_content_context_for_description( - description=self.player_message.raw, + context_context = ( + TM.agents.creator.determine_content_context_for_description( + description=self.player_message.raw, + ) ) TM.scene.set_content_context(context_context) - + # Render the `computer` template and send it to the LLM for processing # The LLM will return a list of calls that the simulation suite will process # The calls are pseudo code that the simulation suite will interpret and execute @@ -153,90 +160,98 @@ def game(TM): player_instruction=self.player_message.raw, scene=TM.scene, ) - + self.calls = calls = calls.split("\n") - + calls = self.prepare_calls(calls) - + TM.log.debug("SIMULATION SUITE CALLS", callse=calls) - + # calls that are processed processed = [] - + for call in calls: processed_call = self.process_call(call) if processed_call: processed.append(processed_call) - if processed: TM.log.debug("SIMULATION SUITE CALLS", calls=processed) - TM.game_state.set_var("instr.has_issued_instructions", "yes", commit=False) - - TM.signals.status("busy", "Simulation suite altering environment.", as_scene_message=True) + TM.game_state.set_var( + "instr.has_issued_instructions", "yes", commit=False + ) + + TM.signals.status( + "busy", "Simulation suite altering environment.", as_scene_message=True + ) compiled = "\n".join(processed) - + if not self.simulation_reset and compiled: - # send the compiled calls to the narrator to generate a narrative based # on them narration = TM.agents.narrator.action_to_narration( action_name="progress_story", narrative_direction=f"The computer calls the following functions:\n\n```\n{compiled}\n```\n\nand the simulation adjusts the environment according to the user's wishes.\n\nWrite the narrative that describes the changes to the player in the context of the simulation starting up. YOU MUST NOT REFERENCE THE COMPUTER OR THE SIMULATION.", - emit_message=True + emit_message=True, ) - + # on the first narration we update the scene description and remove any mention of the computer # or the simulation from the previous narration - is_initial_narration = TM.game_state.get_var("instr.intro_narration", False) + is_initial_narration = TM.game_state.get_var( + "instr.intro_narration", False + ) if not is_initial_narration: TM.scene.set_description(narration.raw) TM.scene.set_intro(narration.raw) - TM.log.debug("SIMULATION SUITE: initial narration", intro=narration.raw) + TM.log.debug( + "SIMULATION SUITE: initial narration", intro=narration.raw + ) TM.scene.pop_history(typ="narrator", all=True, reverse=True) TM.scene.pop_history(typ="director", all=True, reverse=True) TM.game_state.set_var("instr.intro_narration", True, commit=False) - + self.update_world_state = True - + self.set_simulation_title(compiled) - def set_simulation_title(self, compiled_calls): - """ Generates a fitting title for the simulation based on the user's instructions """ - - TM.log.debug("SIMULATION SUITE: set simulation title", name=TM.scene.title, compiled_calls=compiled_calls) - + + TM.log.debug( + "SIMULATION SUITE: set simulation title", + name=TM.scene.title, + compiled_calls=compiled_calls, + ) + if not compiled_calls: return - + if TM.scene.title != "Simulation Suite": # name already changed, no need to do it again return - + title = TM.agents.creator.contextual_generate_from_args( "scene:simulation title", "Create a fitting title for the simulated scenario that the user has requested. You response MUST be a short but exciting, descriptive title.", - length=75 + length=75, ) - + title = title.strip('"').strip() - + TM.scene.set_title(title) - + def prepare_calls(self, calls): """ Loops through calls and if a `set_player_name` call and a `set_player_persona` call are both found, ensure that the `set_player_name` call is processed first by moving it in front of the `set_player_persona` call. """ - + set_player_name_call_exists = -1 set_player_persona_call_exists = -1 - + i = 0 for call in calls: if "set_player_name" in call: @@ -244,351 +259,445 @@ def game(TM): elif "set_player_persona" in call: set_player_persona_call_exists = i i = i + 1 - - if set_player_name_call_exists > -1 and set_player_persona_call_exists > -1: + if set_player_name_call_exists > -1 and set_player_persona_call_exists > -1: if set_player_name_call_exists > set_player_persona_call_exists: - calls.insert(set_player_persona_call_exists, calls.pop(set_player_name_call_exists)) - TM.log.debug("SIMULATION SUITE: prepare calls - moved set_player_persona call", calls=calls) - + calls.insert( + set_player_persona_call_exists, + calls.pop(set_player_name_call_exists), + ) + TM.log.debug( + "SIMULATION SUITE: prepare calls - moved set_player_persona call", + calls=calls, + ) + return calls - def process_call(self, call:str) -> str: + def process_call(self, call: str) -> str: """ Processes a simulation call - + Simulation alls are pseudo functions that are called by the simulation suite - + We grab the function name by splitting against ( and taking the first element if the SimulationSuite has a method with the name _call_{function_name} then we call it - + if a function name could be found but we do not have a method to call we dont do anything but we still return it as procssed as the AI can still interpret it as something later on """ - + if "(" not in call: return None - + function_name = call.split("(")[0] - + if hasattr(self, f"call_{function_name}"): - TM.log.debug("SIMULATION SUITE CALL", call=call, function_name=function_name) - + TM.log.debug( + "SIMULATION SUITE CALL", call=call, function_name=function_name + ) + inject = f"The computer executes the function `{call}`" - + return getattr(self, f"call_{function_name}")(call, inject) - + return call - def call_set_simulation_goal(self, call:str, inject:str) -> str: + def call_set_simulation_goal(self, call: str, inject: str) -> str: """ Set's the simulation goal as a permanent pin """ - TM.signals.status("busy", "Simulation suite setting goal.", as_scene_message=True) - TM.agents.world_state.save_world_entry( - entry_id="sim.goal", - text=self.player_message.raw, - meta={}, - pin=True + TM.signals.status( + "busy", "Simulation suite setting goal.", as_scene_message=True ) - + TM.agents.world_state.save_world_entry( + entry_id="sim.goal", text=self.player_message.raw, meta={}, pin=True + ) + TM.agents.director.log_action( action=parse_sim_call_arguments(call), action_description="The computer sets the goal for the simulation.", ) - + return call - - def call_change_environment(self, call:str, inject:str) -> str: + + def call_change_environment(self, call: str, inject: str) -> str: """ Simulation changes the environment, this is entirely interpreted by the AI and we dont need to do any logic on our end, so we just return the call """ - + TM.agents.director.log_action( action=parse_sim_call_arguments(call), - action_description="The computer changes the environment of the simulation." + action_description="The computer changes the environment of the simulation.", ) - + return call - - - def call_answer_question(self, call:str, inject:str) -> str: + + def call_answer_question(self, call: str, inject: str) -> str: """ The player asked the simulation a query, we need to process this and have the AI produce an answer """ - + TM.agents.narrator.action_to_narration( action_name="progress_story", narrative_direction=f"The computer calls the following function:\n\n{call}\n\nand answers the player's question.", - emit_message=True + emit_message=True, ) - - - def call_set_player_persona(self, call:str, inject:str) -> str: - + + def call_set_player_persona(self, call: str, inject: str) -> str: """ The simulation suite is altering the player persona """ - + player_character = TM.scene.get_player_character() - - TM.signals.status("busy", "Simulation suite altering user persona.", as_scene_message=True) - character_attributes = TM.agents.world_state.extract_character_sheet( - name=player_character.name, text=inject, alteration_instructions=self.player_message.raw + + TM.signals.status( + "busy", "Simulation suite altering user persona.", as_scene_message=True ) - TM.scene.set_character_attributes(player_character.name, character_attributes) - - character_description = TM.agents.creator.determine_character_description(player_character.name) - - TM.scene.set_character_description(player_character.name, character_description) - - TM.log.debug("SIMULATION SUITE: transform player", attributes=character_attributes, description=character_description) - + character_attributes = TM.agents.world_state.extract_character_sheet( + name=player_character.name, + text=inject, + alteration_instructions=self.player_message.raw, + ) + TM.scene.set_character_attributes( + player_character.name, character_attributes + ) + + character_description = TM.agents.creator.determine_character_description( + player_character.name + ) + + TM.scene.set_character_description( + player_character.name, character_description + ) + + TM.log.debug( + "SIMULATION SUITE: transform player", + attributes=character_attributes, + description=character_description, + ) + TM.agents.director.log_action( action=parse_sim_call_arguments(call), - action_description="The computer transforms the player persona." + action_description="The computer transforms the player persona.", ) - + return call - - def call_set_player_name(self, call:str, inject:str) -> str: - + + def call_set_player_name(self, call: str, inject: str) -> str: """ The simulation suite is altering the player name """ player_character = TM.scene.get_player_character() - - TM.signals.status("busy", "Simulation suite adjusting user identity.", as_scene_message=True) - character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits.") + + TM.signals.status( + "busy", + "Simulation suite adjusting user identity.", + as_scene_message=True, + ) + character_name = TM.agents.creator.determine_character_name( + instructions=f"{inject} - What is a fitting name for the player persona? Respond with the current name if it still fits." + ) TM.log.debug("SIMULATION SUITE: player name", character_name=character_name) if character_name != player_character.name: TM.scene.set_character_name(player_character.name, character_name) - + TM.agents.director.log_action( action=parse_sim_call_arguments(call), - action_description=f"The computer changes the player's identity to {character_name}." + action_description=f"The computer changes the player's identity to {character_name}.", ) - - return call - - def call_add_ai_character(self, call:str, inject:str) -> str: - + return call + + def call_add_ai_character(self, call: str, inject: str) -> str: # sometimes the AI will call this function an pass an inanimate object as the parameter # we need to determine if this is the case and just ignore it - is_inanimate = TM.agents.world_state.answer_query_true_or_false(f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)", call) - + is_inanimate = TM.agents.world_state.answer_query_true_or_false( + f"does the function `{call}` add an inanimate object, concept or abstract idea? (ANYTHING THAT IS NOT A CHARACTER THAT COULD BE PORTRAYED BY AN ACTOR)", + call, + ) + if is_inanimate: - TM.log.debug("SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped", call=call) + TM.log.debug( + "SIMULATION SUITE: add npc - inanimate object / abstact idea - skipped", + call=call, + ) return - + # sometimes the AI will ask if the function adds a group of characters, we need to # determine if this is the case - adds_group = TM.agents.world_state.answer_query_true_or_false(f"does the function `{call}` add MULTIPLE ai characters?", call) - + adds_group = TM.agents.world_state.answer_query_true_or_false( + f"does the function `{call}` add MULTIPLE ai characters?", call + ) + TM.log.debug("SIMULATION SUITE: add npc", adds_group=adds_group) - - TM.signals.status("busy", "Simulation suite adding character.", as_scene_message=True) - + + TM.signals.status( + "busy", "Simulation suite adding character.", as_scene_message=True + ) + if not adds_group: - character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.") + character_name = TM.agents.creator.determine_character_name( + instructions=f"{inject} - what is the name of the character to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name." + ) else: - character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.", group=True) - + character_name = TM.agents.creator.determine_character_name( + instructions=f"{inject} - what is the name of the group of characters to be added to the scene? If no name can extracted from the text, extract a short descriptive name instead. Respond only with the name.", + group=True, + ) + # sometimes add_ai_character and change_ai_character are called in the same instruction targeting # the same character, if this happens we need to combine into a single add_ai_character call - - has_change_ai_character_call = TM.agents.world_state.answer_query_true_or_false(f"Are there any calls to `change_ai_character` in the instruction for {character_name}?", "\n".join(self.calls)) - + + has_change_ai_character_call = TM.agents.world_state.answer_query_true_or_false( + f"Are there any calls to `change_ai_character` in the instruction for {character_name}?", + "\n".join(self.calls), + ) + if has_change_ai_character_call: - - combined_arg = TM.prompt.request( - "combine-add-and-alter-ai-character", - dedupe_enabled=False, - calls="\n".join(self.calls), - character_name=character_name, - scene=TM.scene, - ).replace("COMBINED ARGUMENT:", "").strip() - + combined_arg = ( + TM.prompt.request( + "combine-add-and-alter-ai-character", + dedupe_enabled=False, + calls="\n".join(self.calls), + character_name=character_name, + scene=TM.scene, + ) + .replace("COMBINED ARGUMENT:", "") + .strip() + ) + call = f"add_ai_character({combined_arg})" inject = f"The computer executes the function `{call}`" - - - TM.signals.status("busy", f"Simulation suite adding character: {character_name}", as_scene_message=True) - + + TM.signals.status( + "busy", + f"Simulation suite adding character: {character_name}", + as_scene_message=True, + ) + TM.log.debug("SIMULATION SUITE: add npc", name=character_name) - - npc = TM.agents.director.persist_character(character_name=character_name, content=self.player_message.raw+f"\n\n{inject}", determine_name=False) - + + npc = TM.agents.director.persist_character( + character_name=character_name, + content=self.player_message.raw + f"\n\n{inject}", + determine_name=False, + ) + self.added_npcs.append(npc.name) - + TM.agents.world_state.add_detail_reinforcement( character_name=npc.name, detail="Goal", instructions=f"Generate a goal for {npc.name}, based on the user's chosen simulation", interval=25, - run_immediately=True + run_immediately=True, ) - + TM.log.debug("SIMULATION SUITE: added npc", npc=npc) - + TM.agents.visual.generate_character_portrait(character_name=npc.name) - + TM.agents.director.log_action( action=parse_sim_call_arguments(call), - action_description=f"The computer adds {npc.name} to the simulation." + action_description=f"The computer adds {npc.name} to the simulation.", ) - - return call + + return call #### - def call_remove_ai_character(self, call:str, inject:str) -> str: - TM.signals.status("busy", "Simulation suite removing character.", as_scene_message=True) - - character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character being removed?", allowed_names=TM.scene.npc_character_names) - + def call_remove_ai_character(self, call: str, inject: str) -> str: + TM.signals.status( + "busy", "Simulation suite removing character.", as_scene_message=True + ) + + character_name = TM.agents.creator.determine_character_name( + instructions=f"{inject} - what is the name of the character being removed?", + allowed_names=TM.scene.npc_character_names, + ) + npc = TM.scene.get_character(character_name) - + if npc: TM.log.debug("SIMULATION SUITE: remove npc", npc=npc.name) - TM.agents.world_state.deactivate_character(action_name="deactivate_character", character_name=npc.name) - + TM.agents.world_state.deactivate_character( + action_name="deactivate_character", character_name=npc.name + ) + TM.agents.director.log_action( action=parse_sim_call_arguments(call), - action_description=f"The computer removes {npc.name} from the simulation." + action_description=f"The computer removes {npc.name} from the simulation.", ) - + return call - def call_change_ai_character(self, call:str, inject:str) -> str: - TM.signals.status("busy", "Simulation suite altering character.", as_scene_message=True) - - character_name = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character receiving the changes (before the change)?", allowed_names=TM.scene.npc_character_names) - + def call_change_ai_character(self, call: str, inject: str) -> str: + TM.signals.status( + "busy", "Simulation suite altering character.", as_scene_message=True + ) + + character_name = TM.agents.creator.determine_character_name( + instructions=f"{inject} - what is the name of the character receiving the changes (before the change)?", + allowed_names=TM.scene.npc_character_names, + ) + if character_name in self.added_npcs: # we dont want to change the character if it was just added return - - character_name_after = TM.agents.creator.determine_character_name(instructions=f"{inject} - what is the name of the character receiving the changes (after the changes)?") - + + character_name_after = TM.agents.creator.determine_character_name( + instructions=f"{inject} - what is the name of the character receiving the changes (after the changes)?" + ) + npc = TM.scene.get_character(character_name) - + if npc: - TM.signals.status("busy", f"Changing {character_name} -> {character_name_after}", as_scene_message=True) - + TM.signals.status( + "busy", + f"Changing {character_name} -> {character_name_after}", + as_scene_message=True, + ) + TM.log.debug("SIMULATION SUITE: transform npc", npc=npc) - + character_attributes = TM.agents.world_state.extract_character_sheet( name=npc.name, text=inject, - alteration_instructions=self.player_message.raw + alteration_instructions=self.player_message.raw, ) - + TM.scene.set_character_attributes(npc.name, character_attributes) - character_description = TM.agents.creator.determine_character_description(npc.name) - + character_description = ( + TM.agents.creator.determine_character_description(npc.name) + ) + TM.scene.set_character_description(npc.name, character_description) - TM.log.debug("SIMULATION SUITE: transform npc", attributes=character_attributes, description=character_description) - + TM.log.debug( + "SIMULATION SUITE: transform npc", + attributes=character_attributes, + description=character_description, + ) + if character_name_after != character_name: TM.scene.set_character_name(npc.name, character_name_after) - + TM.agents.director.log_action( action=parse_sim_call_arguments(call), - action_description=f"The computer transforms {npc.name}." + action_description=f"The computer transforms {npc.name}.", ) - + return call - - def call_end_simulation(self, call:str, inject:str) -> str: - + + def call_end_simulation(self, call: str, inject: str) -> str: player_character = TM.scene.get_player_character() - - explicit_command = TM.agents.world_state.answer_query_true_or_false("has the player explicitly asked to end the simulation?", self.player_message.raw) - + + explicit_command = TM.agents.world_state.answer_query_true_or_false( + "has the player explicitly asked to end the simulation?", + self.player_message.raw, + ) + if explicit_command: - TM.signals.status("busy", "Simulation suite ending current simulation.", as_scene_message=True) + TM.signals.status( + "busy", + "Simulation suite ending current simulation.", + as_scene_message=True, + ) TM.agents.narrator.action_to_narration( action_name="progress_story", narrative_direction=f"Narrate the computer ending the simulation, dissolving the environment and all artificial characters, erasing all memory of it and finally returning the player to the inactive simulation suite. List of artificial characters: {', '.join(TM.scene.npc_character_names)}. The player is also transformed back to their normal, non-descript persona as the form of {player_character.name} ceases to exist.", - emit_message=True + emit_message=True, ) TM.scene.restore() self.simulation_reset = True - + TM.game_state.unset_var("instr.has_issued_instructions") TM.game_state.unset_var("instr.lastprocessed_call") TM.game_state.unset_var("instr.simulation_started") - + TM.agents.director.log_action( action=parse_sim_call_arguments(call), - action_description="The computer ends the simulation." + action_description="The computer ends the simulation.", ) - + def finalize_round(self): - # track rounds rounds = TM.game_state.get_var("instr.rounds", 0) - + # increase rounds TM.game_state.set_var("instr.rounds", rounds + 1, commit=False) - - has_issued_instructions = TM.game_state.has_var("instr.has_issued_instructions") - + + has_issued_instructions = TM.game_state.has_var( + "instr.has_issued_instructions" + ) + if self.update_world_state: self.run_update_world_state() - + if self.player_message_is_instruction: TM.scene.hide_message(self.player_message.id) - TM.game_state.set_var("instr.lastprocessed_call", self.player_message.id, commit=False) - TM.signals.status("success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True) - + TM.game_state.set_var( + "instr.lastprocessed_call", self.player_message.id, commit=False + ) + TM.signals.status( + "success", MSG_PROCESSED_INSTRUCTIONS, as_scene_message=True + ) + elif self.player_message and not has_issued_instructions: # simulation started, player message is NOT an instruction, and player has not given # any instructions self.guide_player() elif self.player_message and not TM.scene.npc_character_names: - # simulation started, player message is NOT an instruction, but there are no npcs to interact with + # simulation started, player message is NOT an instruction, but there are no npcs to interact with self.narrate_round() - - elif rounds % AUTO_NARRATE_INTERVAL == 0 and rounds and TM.scene.npc_character_names and has_issued_instructions: + + elif ( + rounds % AUTO_NARRATE_INTERVAL == 0 + and rounds + and TM.scene.npc_character_names + and has_issued_instructions + ): # every N rounds, narrate the round self.narrate_round() - + def guide_player(self): TM.agents.narrator.action_to_narration( - action_name="paraphrase", - narration=MSG_HELP, - emit_message=True + action_name="paraphrase", narration=MSG_HELP, emit_message=True ) - + def narrate_round(self): TM.agents.narrator.action_to_narration( action_name="progress_story", narrative_direction=PROMPT_NARRATE_ROUND, - emit_message=True + emit_message=True, ) - + def run_update_world_state(self, force=False): TM.log.debug("SIMULATION SUITE: update world state", force=force) - TM.signals.status("busy", "Simulation suite updating world state.", as_scene_message=True) + TM.signals.status( + "busy", "Simulation suite updating world state.", as_scene_message=True + ) TM.agents.world_state.update_world_state(force=force) - TM.signals.status("success", "Simulation suite updated world state.", as_scene_message=True) + TM.signals.status( + "success", + "Simulation suite updated world state.", + as_scene_message=True, + ) SimulationSuite().run() - + + def on_generation_cancelled(TM, exc): - """ Called when user pressed the cancel button during the simulation suite loop. """ - - TM.signals.status("success", "Simulation suite instructions cancelled", as_scene_message=True) + + TM.signals.status( + "success", "Simulation suite instructions cancelled", as_scene_message=True + ) rounds = TM.game_state.get_var("instr.rounds", 0) - TM.log.debug("SIMULATION SUITE: command cancelled", rounds=rounds) \ No newline at end of file + TM.log.debug("SIMULATION SUITE: command cancelled", rounds=rounds) diff --git a/src/talemate/__init__.py b/src/talemate/__init__.py index e55cb9d6..431f1ae6 100644 --- a/src/talemate/__init__.py +++ b/src/talemate/__init__.py @@ -1,5 +1,5 @@ -from .tale_mate import * +from .tale_mate import * # noqa: F401, F403 from .version import VERSION -__version__ = VERSION \ No newline at end of file +__version__ = VERSION diff --git a/src/talemate/agents/__init__.py b/src/talemate/agents/__init__.py index 356ac7bf..b9c388e4 100644 --- a/src/talemate/agents/__init__.py +++ b/src/talemate/agents/__init__.py @@ -1,12 +1,12 @@ -from .base import Agent -from .conversation import ConversationAgent -from .creator import CreatorAgent -from .director import DirectorAgent -from .editor import EditorAgent -from .memory import ChromaDBMemoryAgent, MemoryAgent -from .narrator import NarratorAgent -from .registry import AGENT_CLASSES, get_agent_class, register -from .summarize import SummarizeAgent -from .tts import TTSAgent -from .visual import VisualAgent -from .world_state import WorldStateAgent +from .base import Agent # noqa: F401 +from .conversation import ConversationAgent # noqa: F401 +from .creator import CreatorAgent # noqa: F401 +from .director import DirectorAgent # noqa: F401 +from .editor import EditorAgent # noqa: F401 +from .memory import ChromaDBMemoryAgent, MemoryAgent # noqa: F401 +from .narrator import NarratorAgent # noqa: F401 +from .registry import AGENT_CLASSES, get_agent_class, register # noqa: F401 +from .summarize import SummarizeAgent # noqa: F401 +from .tts import TTSAgent # noqa: F401 +from .visual import VisualAgent # noqa: F401 +from .world_state import WorldStateAgent # noqa: F401 diff --git a/src/talemate/agents/base.py b/src/talemate/agents/base.py index adaf0d6b..b8b7d1d2 100644 --- a/src/talemate/agents/base.py +++ b/src/talemate/agents/base.py @@ -6,11 +6,10 @@ from inspect import signature import re from abc import ABC from functools import wraps -from typing import TYPE_CHECKING, Callable, List, Optional, Union +from typing import Callable, Union import uuid import pydantic import structlog -from blinker import signal import talemate.emit.async_signals import talemate.instance as instance @@ -39,6 +38,7 @@ __all__ = [ log = structlog.get_logger("talemate.agents.base") + class AgentActionConditional(pydantic.BaseModel): attribute: str value: int | float | str | bool | list[int | float | str | bool] | None = None @@ -48,6 +48,7 @@ class AgentActionNote(pydantic.BaseModel): type: str text: str + class AgentActionConfig(pydantic.BaseModel): type: str label: str @@ -65,7 +66,7 @@ class AgentActionConfig(pydantic.BaseModel): condition: Union[AgentActionConditional, None] = None title: Union[str, None] = None value_migration: Union[Callable, None] = pydantic.Field(default=None, exclude=True) - + note_on_value: dict[str, AgentActionNote] = pydantic.Field(default_factory=dict) class Config: @@ -85,37 +86,37 @@ class AgentAction(pydantic.BaseModel): quick_toggle: bool = False experimental: bool = False + class AgentDetail(pydantic.BaseModel): value: Union[str, None] = None description: Union[str, None] = None icon: Union[str, None] = None color: str = "grey" + class DynamicInstruction(pydantic.BaseModel): title: str content: str - + def __str__(self) -> str: return "\n".join( - [ - f"<|SECTION:{self.title}|>", - self.content, - "<|CLOSE_SECTION|>" - ] + [f"<|SECTION:{self.title}|>", self.content, "<|CLOSE_SECTION|>"] ) - -def args_and_kwargs_to_dict(fn, args: list, kwargs: dict, filter:list[str] = None) -> dict: + +def args_and_kwargs_to_dict( + fn, args: list, kwargs: dict, filter: list[str] = None +) -> dict: """ Takes a list of arguments and a dict of keyword arguments and returns a dict mapping parameter names to their values. - + Args: fn: The function whose parameters we want to map args: List of positional arguments kwargs: Dictionary of keyword arguments filter: List of parameter names to include in the result, if None all parameters are included - + Returns: Dict mapping parameter names to their values """ @@ -124,34 +125,36 @@ def args_and_kwargs_to_dict(fn, args: list, kwargs: dict, filter:list[str] = Non bound_args.apply_defaults() rv = dict(bound_args.arguments) rv.pop("self", None) - + if filter: for key in list(rv.keys()): if key not in filter: rv.pop(key) - + return rv class store_context_state: """ Flag to store a function's arguments in the agent's context state. - + Any arguments passed to the function will be stored in the agent's context - + If no arguments are passed, all arguments will be stored. - + Keyword arguments can be passed to store additional values in the context state. """ + def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs - + def __call__(self, fn): fn.store_context_state = self.args fn.store_context_state_kwargs = self.kwargs return fn + def set_processing(fn): """ decorator that emits the agent status as processing while the function @@ -165,31 +168,38 @@ def set_processing(fn): async def wrapper(self, *args, **kwargs): with ClientContext(): scene = active_scene.get() - + if scene: scene.continue_actions() - + if getattr(scene, "config", None): - set_client_context_attribute("app_config_system_prompts", scene.config.get("system_prompts", {})) - + set_client_context_attribute( + "app_config_system_prompts", scene.config.get("system_prompts", {}) + ) + with ActiveAgent(self, fn, args, kwargs) as active_agent_context: try: await self.emit_status(processing=True) - + # Now pass the complete args list if getattr(fn, "store_context_state", None) is not None: all_args = args_and_kwargs_to_dict( - fn, [self] + list(args), kwargs, getattr(fn, "store_context_state", []) + fn, + [self] + list(args), + kwargs, + getattr(fn, "store_context_state", []), ) if getattr(fn, "store_context_state_kwargs", None) is not None: - all_args.update(getattr(fn, "store_context_state_kwargs", {})) - + all_args.update( + getattr(fn, "store_context_state_kwargs", {}) + ) + all_args[f"fn_{fn.__name__}"] = True - + active_agent_context.state_params = all_args - + self.set_context_states(**all_args) - + return await fn(self, *args, **kwargs) finally: try: @@ -215,14 +225,16 @@ class Agent(ABC): websocket_handler = None essential = True ready_check_error = None - + @classmethod - def init_actions(cls, actions: dict[str, AgentAction] | None = None) -> dict[str, AgentAction]: + def init_actions( + cls, actions: dict[str, AgentAction] | None = None + ) -> dict[str, AgentAction]: if actions is None: actions = {} - + return actions - + @property def agent_details(self): if hasattr(self, "client"): @@ -230,10 +242,6 @@ class Agent(ABC): return self.client.name return None - @property - def verbose_name(self): - return self.agent_type.capitalize() - @property def ready(self): if not getattr(self.client, "enabled", True): @@ -315,67 +323,68 @@ class Agent(ABC): return {} return {k: v.model_dump() for k, v in self.actions.items()} - + # scene state - - + def context_fingerpint(self, extra: list[str] = []) -> str | None: active_agent_context = active_agent.get() - + if not active_agent_context: return None - + if self.scene.history: fingerprint = f"{self.scene.history[-1].fingerprint}-{active_agent_context.first.fingerprint}" else: fingerprint = f"START-{active_agent_context.first.fingerprint}" - + for extra_key in extra: fingerprint += f"-{hash(extra_key)}" - + return fingerprint - - - def get_scene_state(self, key:str, default=None): + + def get_scene_state(self, key: str, default=None): agent_state = self.scene.agent_state.get(self.agent_type, {}) return agent_state.get(key, default) - + def set_scene_states(self, **kwargs): agent_state = self.scene.agent_state.get(self.agent_type, {}) for key, value in kwargs.items(): agent_state[key] = value self.scene.agent_state[self.agent_type] = agent_state - + def dump_scene_state(self): return self.scene.agent_state.get(self.agent_type, {}) - + # active agent context state - - def get_context_state(self, key:str, default=None): + + def get_context_state(self, key: str, default=None): key = f"{self.agent_type}__{key}" try: return active_agent.get().state.get(key, default) except AttributeError: log.warning("get_context_state error", agent=self.agent_type, key=key) return default - + def set_context_states(self, **kwargs): try: - items = {f"{self.agent_type}__{k}": v for k, v in kwargs.items()} active_agent.get().state.update(items) - log.debug("set_context_states", agent=self.agent_type, state=active_agent.get().state) + log.debug( + "set_context_states", + agent=self.agent_type, + state=active_agent.get().state, + ) except AttributeError: log.error("set_context_states error", agent=self.agent_type, kwargs=kwargs) - + def dump_context_state(self): try: return active_agent.get().state except AttributeError: return {} - + ### - + async def _handle_ready_check(self, fut: asyncio.Future): callback_failure = getattr(self, "on_ready_check_failure", None) if fut.cancelled(): @@ -425,41 +434,50 @@ class Agent(ABC): if not action.config: continue - for config_key, config in action.config.items(): + for config_key, _config in action.config.items(): try: - config.value = ( + _config.value = ( kwargs.get("actions", {}) .get(action_key, {}) .get("config", {}) .get(config_key, {}) - .get("value", config.value) + .get("value", _config.value) ) - if config.value_migration and callable(config.value_migration): - config.value = config.value_migration(config.value) + if _config.value_migration and callable(_config.value_migration): + _config.value = _config.value_migration(_config.value) except AttributeError: pass - async def save_config(self, app_config: config.Config | None = None): """ Saves the agent config to the config file. - + If no config object is provided, the config is loaded from the config file. """ - + if not app_config: - app_config:config.Config = config.load_config(as_model=True) - + app_config: config.Config = config.load_config(as_model=True) + app_config.agents[self.agent_type] = config.Agent( name=self.agent_type, client=self.client.name if self.client else None, enabled=self.enabled, - actions={action_key: config.AgentAction( - enabled=action.enabled, - config={config_key: config.AgentActionConfig(value=config_obj.value) for config_key, config_obj in action.config.items()} - ) for action_key, action in self.actions.items()} + actions={ + action_key: config.AgentAction( + enabled=action.enabled, + config={ + config_key: config.AgentActionConfig(value=config_obj.value) + for config_key, config_obj in action.config.items() + }, + ) + for action_key, action in self.actions.items() + }, + ) + log.debug( + "saving agent config", + agent=self.agent_type, + config=app_config.agents[self.agent_type], ) - log.debug("saving agent config", agent=self.agent_type, config=app_config.agents[self.agent_type]) config.save_config(app_config) async def on_game_loop_start(self, event: GameLoopStartEvent): @@ -474,19 +492,19 @@ class Agent(ABC): if not action.config: continue - for _, config in action.config.items(): - if config.scope == "scene": + for _, _config in action.config.items(): + if _config.scope == "scene": # if default_value is None, just use the `type` of the current # value - if config.default_value is None: - default_value = type(config.value)() + if _config.default_value is None: + default_value = type(_config.value)() else: - default_value = config.default_value + default_value = _config.default_value log.debug( - "resetting config", config=config, default_value=default_value + "resetting config", config=_config, default_value=default_value ) - config.value = default_value + _config.value = default_value await self.emit_status() @@ -518,7 +536,9 @@ class Agent(ABC): await asyncio.sleep(0.01) - async def _handle_background_processing(self, fut: asyncio.Future, error_handler = None): + async def _handle_background_processing( + self, fut: asyncio.Future, error_handler=None + ): try: if fut.cancelled(): return @@ -541,7 +561,7 @@ class Agent(ABC): self.processing_bg -= 1 await self.emit_status() - async def set_background_processing(self, task: asyncio.Task, error_handler = None): + async def set_background_processing(self, task: asyncio.Task, error_handler=None): log.info("set_background_processing", agent=self.agent_type) if not hasattr(self, "processing_bg"): self.processing_bg = 0 @@ -550,7 +570,9 @@ class Agent(ABC): await self.emit_status() task.add_done_callback( - lambda fut: asyncio.create_task(self._handle_background_processing(fut, error_handler)) + lambda fut: asyncio.create_task( + self._handle_background_processing(fut, error_handler) + ) ) def connect(self, scene): @@ -623,7 +645,6 @@ class Agent(ABC): """ return False - @set_processing async def delegate(self, fn: Callable, *args, **kwargs): """ @@ -631,20 +652,22 @@ class Agent(ABC): by the agent. """ return await fn(*args, **kwargs) - - async def emit_message(self, header:str, message:str | list[dict], meta: dict = None, **data): + + async def emit_message( + self, header: str, message: str | list[dict], meta: dict = None, **data + ): if not data: data = {} - + if not meta: meta = {} - + if "uuid" not in data: data["uuid"] = str(uuid.uuid4()) - + if "agent" not in data: data["agent"] = self.agent_type - + data["header"] = header emit( "agent_message", @@ -653,17 +676,22 @@ class Agent(ABC): meta=meta, websocket_passthrough=True, ) - + + @dataclasses.dataclass class AgentEmission: agent: Agent + @dataclasses.dataclass class AgentTemplateEmission(AgentEmission): template_vars: dict = dataclasses.field(default_factory=dict) response: str = None - dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list) - + dynamic_instructions: list[DynamicInstruction] = dataclasses.field( + default_factory=list + ) + + @dataclasses.dataclass class RagBuildSubInstructionEmission(AgentEmission): sub_instruction: str | None = None diff --git a/src/talemate/agents/context.py b/src/talemate/agents/context.py index ec88438e..953e7baa 100644 --- a/src/talemate/agents/context.py +++ b/src/talemate/agents/context.py @@ -1,13 +1,9 @@ import contextvars import uuid -import hashlib -from typing import TYPE_CHECKING, Callable +from typing import Callable import pydantic -if TYPE_CHECKING: - from talemate.tale_mate import Character - __all__ = [ "active_agent", ] @@ -25,7 +21,6 @@ class ActiveAgentContext(pydantic.BaseModel): state: dict = pydantic.Field(default_factory=dict) state_params: dict = pydantic.Field(default_factory=dict) previous: "ActiveAgentContext" = None - class Config: arbitrary_types_allowed = True @@ -35,29 +30,30 @@ class ActiveAgentContext(pydantic.BaseModel): return self.previous.first if self.previous else self @property - def action(self): + def action(self): name = self.fn.__name__ if name == "delegate": return self.fn_args[0].__name__ return name - + @property def fingerprint(self) -> int: if hasattr(self, "_fingerprint"): return self._fingerprint self._fingerprint = hash(frozenset(self.state_params.items())) return self._fingerprint - + def __str__(self): return f"{self.agent.verbose_name}.{self.action}" - + class ActiveAgent: def __init__(self, agent, fn, args=None, kwargs=None): - self.agent = ActiveAgentContext(agent=agent, fn=fn, fn_args=args or tuple(), fn_kwargs=kwargs or {}) + self.agent = ActiveAgentContext( + agent=agent, fn=fn, fn_args=args or tuple(), fn_kwargs=kwargs or {} + ) def __enter__(self): - previous_agent = active_agent.get() if previous_agent: @@ -70,7 +66,7 @@ class ActiveAgent: self.agent.agent_stack_uid = str(uuid.uuid4()) self.token = active_agent.set(self.agent) - + return self.agent def __exit__(self, *args, **kwargs): diff --git a/src/talemate/agents/conversation/__init__.py b/src/talemate/agents/conversation/__init__.py index 7677eea1..29ae195d 100644 --- a/src/talemate/agents/conversation/__init__.py +++ b/src/talemate/agents/conversation/__init__.py @@ -1,7 +1,6 @@ from __future__ import annotations import dataclasses -import random import re from datetime import datetime from typing import TYPE_CHECKING, Optional @@ -10,14 +9,12 @@ import structlog import talemate.client as client import talemate.emit.async_signals -import talemate.instance as instance import talemate.util as util from talemate.client.context import ( client_context_attribute, set_client_context_attribute, set_conversation_context_attribute, ) -from talemate.events import GameLoopEvent from talemate.exceptions import LLMAccuracyError from talemate.prompts import Prompt from talemate.scene_message import CharacterMessage, DirectorMessage @@ -37,7 +34,7 @@ from talemate.agents.memory.rag import MemoryRAGMixin from talemate.agents.context import active_agent from .websocket_handler import ConversationWebsocketHandler -import talemate.agents.conversation.nodes +import talemate.agents.conversation.nodes # noqa: F401 if TYPE_CHECKING: from talemate.tale_mate import Actor, Character @@ -50,21 +47,20 @@ class ConversationAgentEmission(AgentEmission): actor: Actor character: Character response: str - dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list) + dynamic_instructions: list[DynamicInstruction] = dataclasses.field( + default_factory=list + ) talemate.emit.async_signals.register( - "agent.conversation.before_generate", + "agent.conversation.before_generate", "agent.conversation.inject_instructions", - "agent.conversation.generated" + "agent.conversation.generated", ) @register() -class ConversationAgent( - MemoryRAGMixin, - Agent -): +class ConversationAgent(MemoryRAGMixin, Agent): """ An agent that can be used to have a conversation with the AI @@ -135,8 +131,6 @@ class ConversationAgent( max=20, step=1, ), - - }, ), "auto_break_repetition": AgentAction( @@ -159,7 +153,7 @@ class ConversationAgent( description="Use the writing style selected in the scene settings", value=True, ), - } + }, ), } MemoryRAGMixin.add_actions(actions) @@ -205,7 +199,6 @@ class ConversationAgent( @property def agent_details(self) -> dict: - details = { "client": AgentDetail( icon="mdi-network-outline", @@ -231,22 +224,24 @@ class ConversationAgent( @property def generation_settings_actor_instructions_offset(self): - return self.actions["generation_override"].config["actor_instructions_offset"].value - + return ( + self.actions["generation_override"] + .config["actor_instructions_offset"] + .value + ) + @property def generation_settings_response_length(self): return self.actions["generation_override"].config["length"].value - + @property def generation_settings_override_enabled(self): return self.actions["generation_override"].enabled - + @property def content_use_writing_style(self) -> bool: return self.actions["content"].config["use_writing_style"].value - - def connect(self, scene): super().connect(scene) @@ -276,7 +271,7 @@ class ConversationAgent( main_character = scene.main_character.character character_names = [c.name for c in scene.characters] - + if main_character: try: character_names.remove(main_character.name) @@ -296,21 +291,22 @@ class ConversationAgent( director_message = isinstance(scene_and_dialogue[-1], DirectorMessage) except IndexError: director_message = False - - + inject_instructions_emission = ConversationAgentEmission( agent=self, - response="", - actor=None, - character=character, + response="", + actor=None, + character=character, ) await talemate.emit.async_signals.get( "agent.conversation.inject_instructions" ).send(inject_instructions_emission) - + agent_context = active_agent.get() - agent_context.state["dynamic_instructions"] = inject_instructions_emission.dynamic_instructions - + agent_context.state["dynamic_instructions"] = ( + inject_instructions_emission.dynamic_instructions + ) + conversation_format = self.conversation_format prompt = Prompt.get( f"conversation.dialogue-{conversation_format}", @@ -319,26 +315,30 @@ class ConversationAgent( "max_tokens": self.client.max_token_length, "scene_and_dialogue_budget": scene_and_dialogue_budget, "scene_and_dialogue": scene_and_dialogue, - "memory": None, # DEPRECATED VARIABLE + "memory": None, # DEPRECATED VARIABLE "characters": list(scene.get_characters()), "main_character": main_character, "formatted_names": formatted_names, "talking_character": character, "partial_message": char_message, "director_message": director_message, - "extra_instructions": self.generation_settings_task_instructions, #backward compatibility + "extra_instructions": self.generation_settings_task_instructions, # backward compatibility "task_instructions": self.generation_settings_task_instructions, "actor_instructions": self.generation_settings_actor_instructions, "actor_instructions_offset": self.generation_settings_actor_instructions_offset, "direct_instruction": instruction, "decensor": self.client.decensor_enabled, - "response_length": self.generation_settings_response_length if self.generation_settings_override_enabled else None, + "response_length": self.generation_settings_response_length + if self.generation_settings_override_enabled + else None, }, ) return str(prompt) - async def build_prompt(self, character, char_message: str = "", instruction:str = None): + async def build_prompt( + self, character, char_message: str = "", instruction: str = None + ): fn = self.build_prompt_default return await fn(character, char_message=char_message, instruction=instruction) @@ -376,12 +376,12 @@ class ConversationAgent( set_client_context_attribute("nuke_repetition", nuke_repetition) @set_processing - @store_context_state('instruction') + @store_context_state("instruction") async def converse( - self, + self, actor, - instruction:str = None, - emit_signals:bool = True, + instruction: str = None, + emit_signals: bool = True, ) -> list[CharacterMessage]: """ Have a conversation with the AI @@ -398,7 +398,9 @@ class ConversationAgent( self.set_generation_overrides() - result = await self.client.send_prompt(await self.build_prompt(character, instruction=instruction)) + result = await self.client.send_prompt( + await self.build_prompt(character, instruction=instruction) + ) result = self.clean_result(result, character) @@ -454,7 +456,7 @@ class ConversationAgent( # movie script format # {uppercase character name} # {dialogue} - total_result = total_result.replace(f"{character.name.upper()}\n", f"") + total_result = total_result.replace(f"{character.name.upper()}\n", "") # chat format # {character name}: {dialogue} @@ -464,7 +466,7 @@ class ConversationAgent( total_result = util.clean_dialogue(total_result, main_name=character.name) # Check if total_result starts with character name, if not, prepend it - if not total_result.startswith(character.name+":"): + if not total_result.startswith(character.name + ":"): total_result = f"{character.name}: {total_result}" total_result = total_result.strip() @@ -481,11 +483,11 @@ class ConversationAgent( log.debug("conversation agent", response=response) emission = ConversationAgentEmission( - agent=self, - actor=actor, - character=character, + agent=self, + actor=actor, + character=character, response=response, - ) + ) if emit_signals: await talemate.emit.async_signals.get("agent.conversation.generated").send( emission diff --git a/src/talemate/agents/conversation/nodes.py b/src/talemate/agents/conversation/nodes.py index 655e0e7f..4ef71654 100644 --- a/src/talemate/agents/conversation/nodes.py +++ b/src/talemate/agents/conversation/nodes.py @@ -1,74 +1,75 @@ import structlog from typing import TYPE_CHECKING, ClassVar -from talemate.game.engine.nodes.core import Node, GraphState, UNRESOLVED +from talemate.game.engine.nodes.core import GraphState, UNRESOLVED from talemate.game.engine.nodes.registry import register from talemate.game.engine.nodes.agent import AgentNode, AgentSettingsNode from talemate.context import active_scene from talemate.client.context import ConversationContext, ClientContext -import talemate.events as events if TYPE_CHECKING: from talemate.tale_mate import Scene, Character log = structlog.get_logger("talemate.game.engine.nodes.agents.conversation") + @register("agents/conversation/Settings") class ConversationSettings(AgentSettingsNode): """ Base node to render conversation agent settings. """ - _agent_name:ClassVar[str] = "conversation" - + + _agent_name: ClassVar[str] = "conversation" + def __init__(self, title="Conversation Settings", **kwargs): super().__init__(title=title, **kwargs) + @register("agents/conversation/Generate") class GenerateConversation(AgentNode): """ Generate a conversation between two characters """ - - _agent_name:ClassVar[str] = "conversation" - + + _agent_name: ClassVar[str] = "conversation" + def __init__(self, title="Generate Conversation", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("character", socket_type="character") self.add_input("instruction", socket_type="str", optional=True) - + self.set_property("trigger_conversation_generated", True) - + self.add_output("generated", socket_type="str") self.add_output("message", socket_type="message_object") - + async def run(self, state: GraphState): - character:"Character" = self.get_input_value("character") - scene:"Scene" = active_scene.get() + character: "Character" = self.get_input_value("character") + scene: "Scene" = active_scene.get() instruction = self.get_input_value("instruction") - trigger_conversation_generated = self.get_property("trigger_conversation_generated") - + trigger_conversation_generated = self.get_property( + "trigger_conversation_generated" + ) + other_characters = [c.name for c in scene.characters if c != character] - + conversation_context = ConversationContext( talking_character=character.name, other_characters=other_characters, ) - + if instruction == UNRESOLVED: instruction = None - + with ClientContext(conversation=conversation_context): messages = await self.agent.converse( - character.actor, + character.actor, instruction=instruction, emit_signals=trigger_conversation_generated, ) - + message = messages[0] - - self.set_output_values({ - "generated": message.message, - "message": message - }) \ No newline at end of file + + self.set_output_values({"generated": message.message, "message": message}) diff --git a/src/talemate/agents/conversation/websocket_handler.py b/src/talemate/agents/conversation/websocket_handler.py index 3458aa70..a7ec9d88 100644 --- a/src/talemate/agents/conversation/websocket_handler.py +++ b/src/talemate/agents/conversation/websocket_handler.py @@ -18,23 +18,25 @@ __all__ = [ log = structlog.get_logger("talemate.server.conversation") + class RequestActorActionPayload(pydantic.BaseModel): - character:str = "" - instructions:str = "" - emit_signals:bool = True - instructions_through_director:bool = True + character: str = "" + instructions: str = "" + emit_signals: bool = True + instructions_through_director: bool = True + class ConversationWebsocketHandler(Plugin): """ Handles narrator actions """ - + router = "conversation" - + @property def agent(self) -> "ConversationAgent": return get_agent("conversation") - + @set_loading("Generating actor action", cancellable=True, as_async=True) async def handle_request_actor_action(self, data: dict): """ @@ -43,38 +45,37 @@ class ConversationWebsocketHandler(Plugin): payload = RequestActorActionPayload(**data) character = None actor = None - + if payload.character: character = self.scene.get_character(payload.character) actor = character.actor else: actor = random.choice(list(self.scene.get_npc_characters())).actor - + if not actor: log.error("handle_request_actor_action: No actor found") return - + character = actor.character - + if payload.instructions_through_director: director_message = DirectorMessage( payload.instructions, source="player", - meta={"character": character.name} + meta={"character": character.name}, ) emit("director", message=director_message, character=character) self.scene.push_history(director_message) generated_messages = await self.agent.converse( - actor, - emit_signals=payload.emit_signals + actor, emit_signals=payload.emit_signals ) else: generated_messages = await self.agent.converse( - actor, + actor, instruction=payload.instructions, - emit_signals=payload.emit_signals + emit_signals=payload.emit_signals, ) - + for message in generated_messages: self.scene.push_history(message) - emit("character", message=message, character=character) \ No newline at end of file + emit("character", message=message, character=character) diff --git a/src/talemate/agents/creator/__init__.py b/src/talemate/agents/creator/__init__.py index 0fcef2cc..de68bd0f 100644 --- a/src/talemate/agents/creator/__init__.py +++ b/src/talemate/agents/creator/__init__.py @@ -1,13 +1,9 @@ from __future__ import annotations -import json -import os - import talemate.client as client from talemate.agents.base import Agent, set_processing from talemate.agents.registry import register from talemate.agents.memory.rag import MemoryRAGMixin -from talemate.emit import emit from talemate.prompts import Prompt from .assistant import AssistantMixin @@ -16,7 +12,8 @@ from .scenario import ScenarioCreatorMixin from talemate.agents.base import AgentAction -import talemate.agents.creator.nodes +import talemate.agents.creator.nodes # noqa: F401 + @register() class CreatorAgent( @@ -51,7 +48,7 @@ class CreatorAgent( @set_processing async def generate_title(self, text: str): title = await Prompt.request( - f"creator.generate-title", + "creator.generate-title", self.client, "create_short", vars={ diff --git a/src/talemate/agents/creator/assistant.py b/src/talemate/agents/creator/assistant.py index 5ebe8ca3..0dbc28d3 100644 --- a/src/talemate/agents/creator/assistant.py +++ b/src/talemate/agents/creator/assistant.py @@ -40,6 +40,7 @@ async_signals.register( "agent.creator.autocomplete.after", ) + @dataclasses.dataclass class ContextualGenerateEmission(AgentTemplateEmission): """ @@ -48,15 +49,16 @@ class ContextualGenerateEmission(AgentTemplateEmission): content_generation_context: "ContentGenerationContext | None" = None character: "Character | None" = None - + @property def context_type(self) -> str: return self.content_generation_context.computed_context[0] - + @property def context_name(self) -> str: return self.content_generation_context.computed_context[1] + @dataclasses.dataclass class AutocompleteEmission(AgentTemplateEmission): """ @@ -67,6 +69,7 @@ class AutocompleteEmission(AgentTemplateEmission): type: str = "" character: "Character | None" = None + class ContentGenerationContext(pydantic.BaseModel): """ A context for generating content. @@ -104,7 +107,6 @@ class ContentGenerationContext(pydantic.BaseModel): @property def spice(self) -> str: - spice_level = self.generation_options.spice_level if self.template and not getattr(self.template, "supports_spice", False): @@ -148,7 +150,6 @@ class ContentGenerationContext(pydantic.BaseModel): @property def style(self): - if self.template and not getattr(self.template, "supports_style", False): # template supplied that doesn't support style return "" @@ -165,11 +166,12 @@ class ContentGenerationContext(pydantic.BaseModel): def get_state(self, key: str) -> str | int | float | bool | None: return self.state.get(key) + class AssistantMixin: """ Creator mixin that allows quick contextual generation of content. """ - + @classmethod def add_actions(cls, actions: dict[str, AgentAction]): actions["autocomplete"] = AgentAction( @@ -198,15 +200,15 @@ class AssistantMixin: max=256, step=16, ), - } + }, ) - + # property helpers - + @property def autocomplete_dialogue_suggestion_length(self): return self.actions["autocomplete"].config["dialogue_suggestion_length"].value - + @property def autocomplete_narrative_suggestion_length(self): return self.actions["autocomplete"].config["narrative_suggestion_length"].value @@ -255,7 +257,7 @@ class AssistantMixin: history_aware=history_aware, information=information, ) - + for key, value in kwargs.items(): generation_context.set_state(key, value) @@ -279,9 +281,13 @@ class AssistantMixin: f"Contextual generate: {context_typ} - {context_name}", generation_context=generation_context, ) - - character = self.scene.get_character(generation_context.character) if generation_context.character else None - + + character = ( + self.scene.get_character(generation_context.character) + if generation_context.character + else None + ) + template_vars = { "scene": self.scene, "max_tokens": self.client.max_token_length, @@ -295,27 +301,29 @@ class AssistantMixin: "character": character, "template": generation_context.template, } - + emission = ContextualGenerateEmission( agent=self, content_generation_context=generation_context, character=character, template_vars=template_vars, ) - - await async_signals.get("agent.creator.contextual_generate.before").send(emission) + + await async_signals.get("agent.creator.contextual_generate.before").send( + emission + ) template_vars["dynamic_instructions"] = emission.dynamic_instructions content = await Prompt.request( - f"creator.contextual-generate", + "creator.contextual-generate", self.client, kind, vars=template_vars, ) - + emission.response = content - + if not generation_context.partial: content = util.strip_partial_sentences(content) @@ -329,22 +337,29 @@ class AssistantMixin: if not content.startswith(generation_context.character + ":"): content = generation_context.character + ": " + content content = util.strip_partial_sentences(content) - - character = self.scene.get_character(generation_context.character) - - if not character: - log.warning("Character not found", character=generation_context.character) - return content - - emission.response = await editor.cleanup_character_message(content, character) - await async_signals.get("agent.creator.contextual_generate.after").send(emission) - return emission.response - - emission.response = content.strip().strip("*").strip() - - await async_signals.get("agent.creator.contextual_generate.after").send(emission) - return emission.response + character = self.scene.get_character(generation_context.character) + + if not character: + log.warning( + "Character not found", character=generation_context.character + ) + return content + + emission.response = await editor.cleanup_character_message( + content, character + ) + await async_signals.get("agent.creator.contextual_generate.after").send( + emission + ) + return emission.response + + emission.response = content.strip().strip("*").strip() + + await async_signals.get("agent.creator.contextual_generate.after").send( + emission + ) + return emission.response @set_processing async def generate_character_attribute( @@ -357,10 +372,10 @@ class AssistantMixin: ) -> str: """ Wrapper for contextual_generate that generates a character attribute. - """ + """ if not generation_options: generation_options = GenerationOptions() - + return await self.contextual_generate_from_args( context=f"character attribute:{attribute_name}", character=character.name, @@ -368,7 +383,7 @@ class AssistantMixin: original=original, **generation_options.model_dump(), ) - + @set_processing async def generate_character_detail( self, @@ -381,11 +396,11 @@ class AssistantMixin: ) -> str: """ Wrapper for contextual_generate that generates a character detail. - """ + """ if not generation_options: generation_options = GenerationOptions() - + return await self.contextual_generate_from_args( context=f"character detail:{detail_name}", character=character.name, @@ -394,7 +409,7 @@ class AssistantMixin: length=length, **generation_options.model_dump(), ) - + @set_processing async def generate_thematic_list( self, @@ -408,11 +423,11 @@ class AssistantMixin: """ if not generation_options: generation_options = GenerationOptions() - + i = 0 - + result = [] - + while i < iterations: i += 1 _result = await self.contextual_generate_from_args( @@ -420,14 +435,14 @@ class AssistantMixin: instructions=instructions, length=length, original="\n".join(result) if result else None, - extend=i>1, + extend=i > 1, **generation_options.model_dump(), ) - + _result = json.loads(_result) - + result = list(set(result + _result)) - + return result @set_processing @@ -443,34 +458,34 @@ class AssistantMixin: """ if not response_length: response_length = self.autocomplete_dialogue_suggestion_length - + # continuing recent character message non_anchor, anchor = util.split_anchor_text(input, 10) - + self.scene.log.debug( - "autocomplete_anchor", - anchor=anchor, - non_anchor=non_anchor, - input=input + "autocomplete_anchor", anchor=anchor, non_anchor=non_anchor, input=input ) continuing_message = False message = None - + try: message = self.scene.history[-1] - if isinstance(message, CharacterMessage) and message.character_name == character.name: + if ( + isinstance(message, CharacterMessage) + and message.character_name == character.name + ): continuing_message = input.strip() == message.without_name.strip() except IndexError: pass if input.strip().endswith('"'): - prefix = ' *' - elif input.strip().endswith('*'): + prefix = " *" + elif input.strip().endswith("*"): prefix = ' "' else: - prefix = '' - + prefix = "" + template_vars = { "scene": self.scene, "max_tokens": self.client.max_token_length, @@ -484,7 +499,7 @@ class AssistantMixin: "non_anchor": non_anchor, "prefix": prefix, } - + emission = AutocompleteEmission( agent=self, input=input, @@ -492,13 +507,13 @@ class AssistantMixin: character=character, template_vars=template_vars, ) - + await async_signals.get("agent.creator.autocomplete.before").send(emission) template_vars["dynamic_instructions"] = emission.dynamic_instructions response = await Prompt.request( - f"creator.autocomplete-dialogue", + "creator.autocomplete-dialogue", self.client, f"create_{response_length}", vars=template_vars, @@ -506,21 +521,22 @@ class AssistantMixin: dedupe_enabled=False, ) - response = response.replace("...", "").lstrip("").rstrip().replace("END-OF-LINE", "") - - + response = ( + response.replace("...", "").lstrip("").rstrip().replace("END-OF-LINE", "") + ) + if prefix: response = prefix + response - + emission.response = response - + await async_signals.get("agent.creator.autocomplete.after").send(emission) if not response: if emit_signal: emit("autocomplete_suggestion", "") return "" - + response = util.strip_partial_sentences(response).replace("*", "") if response.startswith(input): @@ -550,14 +566,14 @@ class AssistantMixin: # Split the input text into non-anchor and anchor parts non_anchor, anchor = util.split_anchor_text(input, 10) - + self.scene.log.debug( - "autocomplete_narrative_anchor", + "autocomplete_narrative_anchor", anchor=anchor, non_anchor=non_anchor, - input=input + input=input, ) - + template_vars = { "scene": self.scene, "max_tokens": self.client.max_token_length, @@ -567,20 +583,20 @@ class AssistantMixin: "anchor": anchor, "non_anchor": non_anchor, } - + emission = AutocompleteEmission( agent=self, input=input, type="narrative", template_vars=template_vars, ) - + await async_signals.get("agent.creator.autocomplete.before").send(emission) template_vars["dynamic_instructions"] = emission.dynamic_instructions response = await Prompt.request( - f"creator.autocomplete-narrative", + "creator.autocomplete-narrative", self.client, f"create_{response_length}", vars=template_vars, @@ -593,7 +609,7 @@ class AssistantMixin: response = response[len(input) :] emission.response = response - + await async_signals.get("agent.creator.autocomplete.after").send(emission) self.scene.log.debug( @@ -614,78 +630,75 @@ class AssistantMixin: """ Allows to fork a new scene from a specific message in the current scene. - + All content after the message will be removed and the context database will be re imported ensuring a clean state. - + All state reinforcements will be reset to their most recent state before the message. """ - + emit("status", "Creating scene fork ...", status="busy") try: if not save_name: # build a save name uuid_str = str(uuid.uuid4())[:8] save_name = f"{uuid_str}-forked" - - log.info(f"Forking scene", message_id=message_id, save_name=save_name) - + + log.info("Forking scene", message_id=message_id, save_name=save_name) + world_state = get_agent("world_state") - + # does a message with the given id exist? index = self.scene.message_index(message_id) if index is None: raise ValueError(f"Message with id {message_id} not found.") - + # truncate scene.history keeping index as the last element - self.scene.history = self.scene.history[:index + 1] - + self.scene.history = self.scene.history[: index + 1] + # truncate scene.archived_history keeping the element where `end` is < `index` # as the last element self.scene.archived_history = [ - x for x in self.scene.archived_history if "end" not in x or x["end"] < index + x + for x in self.scene.archived_history + if "end" not in x or x["end"] < index ] - + # the same needs to be done for layered history # where each layer is truncated based on what's left in the previous layer # using similar logic as above (checking `end` vs `index`) # layer 0 checks archived_history - + new_layered_history = [] for layer_number, layer in enumerate(self.scene.layered_history): - if layer_number == 0: index = len(self.scene.archived_history) - 1 else: index = len(new_layered_history[layer_number - 1]) - 1 - - new_layer = [ - x for x in layer if x["end"] < index - ] + + new_layer = [x for x in layer if x["end"] < index] new_layered_history.append(new_layer) - + self.scene.layered_history = new_layered_history # save the scene await self.scene.save(copy_name=save_name) - - log.info(f"Scene forked", save_name=save_name) - + + log.info("Scene forked", save_name=save_name) + # re-emit history await self.scene.emit_history() - - emit("status", f"Updating world state ...", status="busy") + + emit("status", "Updating world state ...", status="busy") # reset state reinforcements - await world_state.update_reinforcements(force = True, reset= True) - + await world_state.update_reinforcements(force=True, reset=True) + # update world state await self.scene.world_state.request_update() - - emit("status", f"Scene forked", status="success") - except Exception as e: + + emit("status", "Scene forked", status="success") + except Exception: log.error("Scene fork failed", exc=traceback.format_exc()) emit("status", "Scene fork failed", status="error") - - \ No newline at end of file diff --git a/src/talemate/agents/creator/character.py b/src/talemate/agents/creator/character.py index 7d3b81a1..e7f3cd35 100644 --- a/src/talemate/agents/creator/character.py +++ b/src/talemate/agents/creator/character.py @@ -7,8 +7,6 @@ import structlog from talemate.agents.base import set_processing from talemate.prompts import Prompt -import talemate.game.focal as focal - if TYPE_CHECKING: from talemate.tale_mate import Character @@ -18,14 +16,13 @@ DEFAULT_CONTENT_CONTEXT = "a fun and engaging adventure aimed at an adult audien class CharacterCreatorMixin: - @set_processing async def determine_content_context_for_character( self, character: Character, ): content_context = await Prompt.request( - f"creator.determine-content-context", + "creator.determine-content-context", self.client, "create_192", vars={ @@ -42,7 +39,7 @@ class CharacterCreatorMixin: information: str = "", ): instructions = await Prompt.request( - f"creator.determine-character-dialogue-instructions", + "creator.determine-character-dialogue-instructions", self.client, "create_concise", vars={ @@ -63,7 +60,7 @@ class CharacterCreatorMixin: character: Character, ): attributes = await Prompt.request( - f"creator.determine-character-attributes", + "creator.determine-character-attributes", self.client, "analyze_long", vars={ @@ -81,7 +78,7 @@ class CharacterCreatorMixin: instructions: str = "", ) -> str: name = await Prompt.request( - f"creator.determine-character-name", + "creator.determine-character-name", self.client, "analyze_freeform_short", vars={ @@ -97,14 +94,14 @@ class CharacterCreatorMixin: @set_processing async def determine_character_description( - self, + self, character: Character, text: str = "", instructions: str = "", information: str = "", ): description = await Prompt.request( - f"creator.determine-character-description", + "creator.determine-character-description", self.client, "create", vars={ @@ -125,7 +122,7 @@ class CharacterCreatorMixin: goal_instructions: str, ): goals = await Prompt.request( - f"creator.determine-character-goals", + "creator.determine-character-goals", self.client, "create", vars={ @@ -141,4 +138,4 @@ class CharacterCreatorMixin: log.debug("determine_character_goals", goals=goals, character=character) await character.set_detail("goals", goals.strip()) - return goals.strip() \ No newline at end of file + return goals.strip() diff --git a/src/talemate/agents/creator/nodes.py b/src/talemate/agents/creator/nodes.py index 4ee8ab2d..951b331f 100644 --- a/src/talemate/agents/creator/nodes.py +++ b/src/talemate/agents/creator/nodes.py @@ -1,145 +1,153 @@ import structlog -from typing import ClassVar, TYPE_CHECKING +from typing import ClassVar from talemate.context import active_scene -from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES +from talemate.game.engine.nodes.core import ( + GraphState, + PropertyField, + UNRESOLVED, +) from talemate.game.engine.nodes.registry import register from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode -if TYPE_CHECKING: - from talemate.tale_mate import Scene - log = structlog.get_logger("talemate.game.engine.nodes.agents.creator") + @register("agents/creator/Settings") class CreatorSettings(AgentSettingsNode): """ Base node to render creator agent settings. """ - _agent_name:ClassVar[str] = "creator" - + + _agent_name: ClassVar[str] = "creator" + def __init__(self, title="Creator Settings", **kwargs): super().__init__(title=title, **kwargs) - - + + @register("agents/creator/DetermineContentContext") class DetermineContentContext(AgentNode): """ Determines the context for the content creation. """ - _agent_name:ClassVar[str] = "creator" - + + _agent_name: ClassVar[str] = "creator" + class Fields: description = PropertyField( name="description", description="Description of the context", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + def __init__(self, title="Determine Content Context", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("description", socket_type="str", optional=True) - + self.set_property("description", UNRESOLVED) - + self.add_output("content_context", socket_type="str") - + async def run(self, state: GraphState): context = await self.agent.determine_content_context_for_description( self.require_input("description") ) - self.set_output_values({ - "content_context": context - }) - + self.set_output_values({"content_context": context}) + + @register("agents/creator/DetermineCharacterDescription") class DetermineCharacterDescription(AgentNode): """ Determines the description for a character. - + Inputs: - + - state: The current state of the graph - character: The character to determine the description for - extra_context: Extra context to use in determining the - + Outputs: - + - description: The determined description """ - _agent_name:ClassVar[str] = "creator" - + + _agent_name: ClassVar[str] = "creator" + def __init__(self, title="Determine Character Description", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("character", socket_type="character") self.add_input("extra_context", socket_type="str", optional=True) - + self.add_output("description", socket_type="str") - + async def run(self, state: GraphState): - character = self.require_input("character") extra_context = self.get_input_value("extra_context") - + if extra_context is UNRESOLVED: extra_context = "" - - description = await self.agent.determine_character_description(character, extra_context) - - self.set_output_values({ - "description": description - }) - + + description = await self.agent.determine_character_description( + character, extra_context + ) + + self.set_output_values({"description": description}) + + @register("agents/creator/DetermineCharacterDialogueInstructions") class DetermineCharacterDialogueInstructions(AgentNode): """ Determines the dialogue instructions for a character. """ - _agent_name:ClassVar[str] = "creator" - + + _agent_name: ClassVar[str] = "creator" + class Fields: instructions = PropertyField( name="instructions", description="Any additional instructions to use in determining the dialogue instructions", type="text", - default="" + default="", ) - + def __init__(self, title="Determine Character Dialogue Instructions", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("character", socket_type="character") self.add_input("instructions", socket_type="str", optional=True) - + self.set_property("instructions", "") - + self.add_output("dialogue_instructions", socket_type="str") - + async def run(self, state: GraphState): character = self.require_input("character") instructions = self.normalized_input_value("instructions") - - dialogue_instructions = await self.agent.determine_character_dialogue_instructions(character, instructions) - - self.set_output_values({ - "dialogue_instructions": dialogue_instructions - }) + + dialogue_instructions = ( + await self.agent.determine_character_dialogue_instructions( + character, instructions + ) + ) + + self.set_output_values({"dialogue_instructions": dialogue_instructions}) + @register("agents/creator/ContextualGenerate") class ContextualGenerate(AgentNode): """ Generates text based on the given context. - + Inputs: - + - state: The current state of the graph - context_type: The type of context to use in generating the text - context_name: The name of the context to use in generating the text @@ -150,9 +158,9 @@ class ContextualGenerate(AgentNode): - partial: The partial text to use in generating the text - uid: The uid to use in generating the text - generation_options: The generation options to use in generating the text - + Properties: - + - context_type: The type of context to use in generating the text - context_name: The name of the context to use in generating the text - instructions: The instructions to use in generating the text @@ -161,82 +169,82 @@ class ContextualGenerate(AgentNode): - uid: The uid to use in generating the text - context_aware: Whether to use the context in generating the text - history_aware: Whether to use the history in generating the text - + Outputs: - + - state: The updated state of the graph - text: The generated text """ - _agent_name:ClassVar[str] = "creator" - + + _agent_name: ClassVar[str] = "creator" + class Fields: context_type = PropertyField( name="context_type", description="The type of context to use in generating the text", type="str", choices=[ - "character attribute", - "character detail", - "character dialogue", - "scene intro", - "scene intent", - "scene phase intent", + "character attribute", + "character detail", + "character dialogue", + "scene intro", + "scene intent", + "scene phase intent", "scene type description", - "scene type instructions", - "general", - "list", - "scene", + "scene type instructions", + "general", + "list", + "scene", "world context", ], - default="general" + default="general", ) context_name = PropertyField( name="context_name", description="The name of the context to use in generating the text", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) instructions = PropertyField( name="instructions", description="The instructions to use in generating the text", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) length = PropertyField( name="length", description="The length of the text to generate", type="int", - default=100 + default=100, ) character = PropertyField( name="character", description="The character to generate the text for", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) uid = PropertyField( name="uid", description="The uid to use in generating the text", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) context_aware = PropertyField( name="context_aware", description="Whether to use the context in generating the text", type="bool", - default=True + default=True, ) history_aware = PropertyField( name="history_aware", description="Whether to use the history in generating the text", type="bool", - default=True + default=True, ) - - + def __init__(self, title="Contextual Generate", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("context_type", socket_type="str", optional=True) @@ -247,8 +255,10 @@ class ContextualGenerate(AgentNode): self.add_input("original", socket_type="str", optional=True) self.add_input("partial", socket_type="str", optional=True) self.add_input("uid", socket_type="str", optional=True) - self.add_input("generation_options", socket_type="generation_options", optional=True) - + self.add_input( + "generation_options", socket_type="generation_options", optional=True + ) + self.set_property("context_type", "general") self.set_property("context_name", UNRESOLVED) self.set_property("instructions", UNRESOLVED) @@ -257,10 +267,10 @@ class ContextualGenerate(AgentNode): self.set_property("uid", UNRESOLVED) self.set_property("context_aware", True) self.set_property("history_aware", True) - + self.add_output("state") self.add_output("text", socket_type="str") - + async def run(self, state: GraphState): scene = active_scene.get() context_type = self.require_input("context_type") @@ -274,12 +284,12 @@ class ContextualGenerate(AgentNode): generation_options = self.normalized_input_value("generation_options") context_aware = self.normalized_input_value("context_aware") history_aware = self.normalized_input_value("history_aware") - + context = f"{context_type}:{context_name}" if context_name else context_type - + if isinstance(character, scene.Character): character = character.name - + text = await self.agent.contextual_generate_from_args( context=context, instructions=instructions, @@ -288,31 +298,32 @@ class ContextualGenerate(AgentNode): original=original, partial=partial or "", uid=uid, - writing_style=generation_options.writing_style if generation_options else None, + writing_style=generation_options.writing_style + if generation_options + else None, spices=generation_options.spices if generation_options else None, spice_level=generation_options.spice_level if generation_options else 0.0, context_aware=context_aware, - history_aware=history_aware + history_aware=history_aware, ) - - self.set_output_values({ - "state": state, - "text": text - }) - + + self.set_output_values({"state": state, "text": text}) + + @register("agents/creator/GenerateThematicList") class GenerateThematicList(AgentNode): """ Generates a list of thematic items based on the instructions. """ - _agent_name:ClassVar[str] = "creator" - + + _agent_name: ClassVar[str] = "creator" + class Fields: instructions = PropertyField( name="instructions", description="The instructions to use in generating the list", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) iterations = PropertyField( name="iterations", @@ -323,27 +334,24 @@ class GenerateThematicList(AgentNode): min=1, max=10, ) - + def __init__(self, title="Generate Thematic List", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("instructions", socket_type="str", optional=True) self.set_property("instructions", UNRESOLVED) self.set_property("iterations", 1) - + self.add_output("state") self.add_output("list", socket_type="list") - + async def run(self, state: GraphState): instructions = self.normalized_input_value("instructions") iterations = self.require_number_input("iterations") - + list = await self.agent.generate_thematic_list(instructions, iterations) - - self.set_output_values({ - "state": state, - "list": list - }) \ No newline at end of file + + self.set_output_values({"state": state, "list": list}) diff --git a/src/talemate/agents/creator/scenario.py b/src/talemate/agents/creator/scenario.py index ba058971..67fd7708 100644 --- a/src/talemate/agents/creator/scenario.py +++ b/src/talemate/agents/creator/scenario.py @@ -10,7 +10,7 @@ class ScenarioCreatorMixin: @set_processing async def determine_scenario_description(self, text: str): description = await Prompt.request( - f"creator.determine-scenario-description", + "creator.determine-scenario-description", self.client, "analyze_long", vars={ @@ -25,7 +25,7 @@ class ScenarioCreatorMixin: description: str, ): content_context = await Prompt.request( - f"creator.determine-content-context", + "creator.determine-content-context", self.client, "create_short", vars={ diff --git a/src/talemate/agents/director/__init__.py b/src/talemate/agents/director/__init__.py index eb36de75..0da53d74 100644 --- a/src/talemate/agents/director/__init__.py +++ b/src/talemate/agents/director/__init__.py @@ -1,16 +1,12 @@ from __future__ import annotations -import random from typing import TYPE_CHECKING, List import structlog import traceback -import talemate.emit.async_signals import talemate.instance as instance -from talemate.agents.conversation import ConversationAgentEmission from talemate.emit import emit -from talemate.prompts import Prompt from talemate.scene_message import DirectorMessage from talemate.util import random_color from talemate.character import deactivate_character @@ -27,7 +23,7 @@ from .legacy_scene_instructions import LegacySceneInstructionsMixin from .auto_direct import AutoDirectMixin from .websocket_handler import DirectorWebsocketHandler -import talemate.agents.director.nodes +import talemate.agents.director.nodes # noqa: F401 if TYPE_CHECKING: from talemate import Character, Scene @@ -42,7 +38,7 @@ class DirectorAgent( GenerateChoicesMixin, AutoDirectMixin, LegacySceneInstructionsMixin, - Agent + Agent, ): agent_type = "director" verbose_name = "Director" @@ -74,7 +70,7 @@ class DirectorAgent( ], ), }, - ), + ), } MemoryRAGMixin.add_actions(actions) GenerateChoicesMixin.add_actions(actions) @@ -112,7 +108,6 @@ class DirectorAgent( created_characters = [] for character_name in self.scene.world_state.characters.keys(): - if exclude and character_name.lower() in exclude: continue @@ -148,13 +143,12 @@ class DirectorAgent( memory = instance.get_agent("memory") scene: "Scene" = self.scene any_attribute_templates = False - + loading_status = LoadingStatus(max_steps=None, cancellable=True) - + # Start of character creation log.debug("persist_character", name=name) - # Determine the character's name (or clarify if it's already set) if determine_name: loading_status("Determining character name") @@ -162,7 +156,7 @@ class DirectorAgent( log.debug("persist_character", adjusted_name=name) # Create the blank character - character:Character = self.scene.Character(name=name) + character: Character = self.scene.Character(name=name) # Add the character to the scene character.color = random_color() @@ -170,32 +164,44 @@ class DirectorAgent( character=character, agent=instance.get_agent("conversation") ) await self.scene.add_actor(actor) - - try: + try: # Apply any character generation templates if templates: loading_status("Applying character generation templates") - templates = scene.world_state_manager.template_collection.collect_all(templates) + templates = scene.world_state_manager.template_collection.collect_all( + templates + ) log.debug("persist_character", applying_templates=templates) await scene.world_state_manager.apply_templates( - templates.values(), + templates.values(), character_name=character.name, - information=content + information=content, ) - + # if any of the templates are attribute templates, then we no longer need to # generate a character sheet - any_attribute_templates = any(template.template_type == "character_attribute" for template in templates.values()) - log.debug("persist_character", any_attribute_templates=any_attribute_templates) - - if any_attribute_templates and augment_attributes and generate_attributes: - log.debug("persist_character", augmenting_attributes=augment_attributes) + any_attribute_templates = any( + template.template_type == "character_attribute" + for template in templates.values() + ) + log.debug( + "persist_character", any_attribute_templates=any_attribute_templates + ) + + if ( + any_attribute_templates + and augment_attributes + and generate_attributes + ): + log.debug( + "persist_character", augmenting_attributes=augment_attributes + ) loading_status("Augmenting character attributes") additional_attributes = await world_state.extract_character_sheet( name=name, text=content, - augmentation_instructions=augment_attributes + augmentation_instructions=augment_attributes, ) character.base_attributes.update(additional_attributes) @@ -212,25 +218,26 @@ class DirectorAgent( log.debug("persist_character", attributes=attributes) character.base_attributes = attributes - + # Generate a description for the character if not description: loading_status("Generating character description") - description = await creator.determine_character_description(character, information=content) + description = await creator.determine_character_description( + character, information=content + ) character.description = description log.debug("persist_character", description=description) # Generate a dialogue instructions for the character loading_status("Generating acting instructions") - dialogue_instructions = await creator.determine_character_dialogue_instructions( - character, - information=content + dialogue_instructions = ( + await creator.determine_character_dialogue_instructions( + character, information=content + ) ) character.dialogue_instructions = dialogue_instructions - log.debug( - "persist_character", dialogue_instructions=dialogue_instructions - ) - + log.debug("persist_character", dialogue_instructions=dialogue_instructions) + # Narrate the character's entry if the option is selected if active and narrate_entry: loading_status("Narrating character entry") @@ -240,24 +247,26 @@ class DirectorAgent( "narrate_character_entry", emit_message=True, character=character, - narrative_direction=narrate_entry_direction + narrative_direction=narrate_entry_direction, ) - + # Deactivate the character if not active if not active: await deactivate_character(scene, character) - + # Commit the character's details to long term memory await character.commit_to_memory(memory) self.scene.emit_status() self.scene.world_state.emit() - - loading_status.done(message=f"{character.name} added to scene", status="success") + + loading_status.done( + message=f"{character.name} added to scene", status="success" + ) return character except GenerationCancelled: loading_status.done(message="Character creation cancelled", status="idle") await scene.remove_actor(actor) - except Exception as e: + except Exception: loading_status.done(message="Character creation failed", status="error") await scene.remove_actor(actor) log.error("Error persisting character", error=traceback.format_exc()) @@ -276,4 +285,4 @@ class DirectorAgent( def allow_repetition_break( self, kind: str, agent_function_name: str, auto: bool = False ): - return False \ No newline at end of file + return False diff --git a/src/talemate/agents/director/auto_direct.py b/src/talemate/agents/director/auto_direct.py index c1fbe52f..061e9f50 100644 --- a/src/talemate/agents/director/auto_direct.py +++ b/src/talemate/agents/director/auto_direct.py @@ -1,23 +1,22 @@ from typing import TYPE_CHECKING import structlog -import pydantic from talemate.agents.base import ( set_processing, AgentAction, AgentActionConfig, - AgentEmission, - AgentTemplateEmission, ) -from talemate.status import set_loading import talemate.game.focal as focal -from talemate.prompts import Prompt import talemate.emit.async_signals as async_signals -from talemate.scene_message import CharacterMessage, TimePassageMessage, DirectorMessage, NarratorMessage -from talemate.scene.schema import ScenePhase, SceneType, SceneIntent +from talemate.scene_message import ( + CharacterMessage, + TimePassageMessage, + NarratorMessage, +) +from talemate.scene.schema import ScenePhase, SceneType from talemate.scene.intent import set_scene_phase from talemate.world_state.manager import WorldStateManager from talemate.world_state.templates.scene import SceneType as TemplateSceneType -import talemate.agents.director.auto_direct_nodes +import talemate.agents.director.auto_direct_nodes # noqa: F401 from talemate.world_state.templates.base import TypedCollection if TYPE_CHECKING: @@ -26,15 +25,15 @@ if TYPE_CHECKING: log = structlog.get_logger("talemate.agents.conversation.direct") -#talemate.emit.async_signals.register( -#) +# talemate.emit.async_signals.register( +# ) + class AutoDirectMixin: - """ Director agent mixin for automatic scene direction. """ - + @classmethod def add_actions(cls, actions: dict[str, AgentAction]): actions["auto_direct"] = AgentAction( @@ -108,113 +107,117 @@ class AutoDirectMixin: ), }, ) - + # config property helpers - + @property def auto_direct_enabled(self) -> bool: return self.actions["auto_direct"].enabled - + @property def auto_direct_max_auto_turns(self) -> int: return self.actions["auto_direct"].config["max_auto_turns"].value - + @property def auto_direct_max_idle_turns(self) -> int: return self.actions["auto_direct"].config["max_idle_turns"].value - + @property def auto_direct_max_repeat_turns(self) -> int: return self.actions["auto_direct"].config["max_repeat_turns"].value - + @property def auto_direct_instruct_actors(self) -> bool: return self.actions["auto_direct"].config["instruct_actors"].value - + @property def auto_direct_instruct_narrator(self) -> bool: return self.actions["auto_direct"].config["instruct_narrator"].value - + @property def auto_direct_instruct_frequency(self) -> int: return self.actions["auto_direct"].config["instruct_frequency"].value - + @property def auto_direct_evaluate_scene_intention(self) -> int: return self.actions["auto_direct"].config["evaluate_scene_intention"].value - + @property def auto_direct_instruct_any(self) -> bool: """ Will check whether actor or narrator instructions are enabled. - + For narrator instructions to be enabled instruct_narrator needs to be enabled as well. - + Returns: bool: True if either actor or narrator instructions are enabled. """ - + return self.auto_direct_instruct_actors or self.auto_direct_instruct_narrator - - + # signal connect def connect(self, scene): super().connect(scene) async_signals.get("game_loop").connect(self.on_game_loop) - + async def on_game_loop(self, event): if not self.auto_direct_enabled: return - + if self.auto_direct_evaluate_scene_intention > 0: evaluation_due = self.get_scene_state("evaluated_scene_intention", 0) if evaluation_due == 0: await self.auto_direct_set_scene_intent() - self.set_scene_states(evaluated_scene_intention=self.auto_direct_evaluate_scene_intention) + self.set_scene_states( + evaluated_scene_intention=self.auto_direct_evaluate_scene_intention + ) else: self.set_scene_states(evaluated_scene_intention=evaluation_due - 1) # helpers - - def auto_direct_is_due_for_instruction(self, actor_name:str) -> bool: + + def auto_direct_is_due_for_instruction(self, actor_name: str) -> bool: """ Check if the actor is due for instruction. """ - + if self.auto_direct_instruct_frequency == 1: return True - + messages_since_last_instruction = 0 - + def count_messages(message): nonlocal messages_since_last_instruction if message.typ in ["character", "narrator"]: messages_since_last_instruction += 1 - - + last_instruction = self.scene.last_message_of_type( "director", character_name=actor_name, max_iterations=25, on_iterate=count_messages, ) - - log.debug("auto_direct_is_due_for_instruction", messages_since_last_instruction=messages_since_last_instruction, last_instruction=last_instruction.id if last_instruction else None) - + + log.debug( + "auto_direct_is_due_for_instruction", + messages_since_last_instruction=messages_since_last_instruction, + last_instruction=last_instruction.id if last_instruction else None, + ) + if not last_instruction: return True - + return messages_since_last_instruction >= self.auto_direct_instruct_frequency - + def auto_direct_candidates(self) -> list["Character"]: """ Returns a list of characters who are valid candidates to speak next. based on the max_idle_turns, max_repeat_turns, and the most recent character. """ - - scene:"Scene" = self.scene - + + scene: "Scene" = self.scene + most_recent_character = None repeat_count = 0 last_player_turn = None @@ -223,89 +226,105 @@ class AutoDirectMixin: active_charcters = list(scene.characters) active_character_names = [character.name for character in active_charcters] instruct_narrator = self.auto_direct_instruct_narrator - + # if there is only one character then they are the only candidate if len(active_charcters) == 1: return active_charcters - + BACKLOG_LIMIT = 50 - + player_character_active = scene.player_character_exists - + # check the last BACKLOG_LIMIT entries in the scene history and collect into # a dictionary of character names and the number of turns since they last spoke. - + len_history = len(scene.history) - 1 num = 0 for idx in range(len_history, -1, -1): message = scene.history[idx] turns = len_history - idx - + num += 1 - + if num > BACKLOG_LIMIT: break - + if isinstance(message, TimePassageMessage): break - + if not isinstance(message, (CharacterMessage, NarratorMessage)): continue - + # if character message but character is not in the active characters list then skip - if isinstance(message, CharacterMessage) and message.character_name not in active_character_names: + if ( + isinstance(message, CharacterMessage) + and message.character_name not in active_character_names + ): continue - + if isinstance(message, NarratorMessage): if not instruct_narrator: continue character = scene.narrator_character_object else: character = scene.get_character(message.character_name) - + if not character: continue - + if character.is_player and last_player_turn is None: last_player_turn = turns elif not character.is_player and last_player_turn is None: consecutive_auto_turns += 1 - + if not most_recent_character: most_recent_character = character repeat_count += 1 elif character == most_recent_character: repeat_count += 1 - + if character.name not in candidates: candidates[character.name] = turns - + # add any characters that have not spoken yet for character in active_charcters: if character.name not in candidates: candidates[character.name] = 0 - + # explicitly add narrator if enabled and not already in candidates if instruct_narrator and scene.narrator_character_object: narrator = scene.narrator_character_object if narrator.name not in candidates: candidates[narrator.name] = 0 - - log.debug(f"auto_direct_candidates: {candidates}", most_recent_character=most_recent_character, repeat_count=repeat_count, last_player_turn=last_player_turn, consecutive_auto_turns=consecutive_auto_turns) - + + log.debug( + f"auto_direct_candidates: {candidates}", + most_recent_character=most_recent_character, + repeat_count=repeat_count, + last_player_turn=last_player_turn, + consecutive_auto_turns=consecutive_auto_turns, + ) + if not most_recent_character: log.debug("auto_direct_candidates: No most recent character found.") return list(scene.characters) - + # if player has not spoken in a while then they are favored - if player_character_active and consecutive_auto_turns >= self.auto_direct_max_auto_turns: - log.debug("auto_direct_candidates: User controlled character has not spoken in a while.") + if ( + player_character_active + and consecutive_auto_turns >= self.auto_direct_max_auto_turns + ): + log.debug( + "auto_direct_candidates: User controlled character has not spoken in a while." + ) return [scene.get_player_character()] - + # check if most recent character has spoken too many times in a row # if so then remove from candidates if repeat_count >= self.auto_direct_max_repeat_turns: - log.debug("auto_direct_candidates: Most recent character has spoken too many times in a row.", most_recent_character=most_recent_character + log.debug( + "auto_direct_candidates: Most recent character has spoken too many times in a row.", + most_recent_character=most_recent_character, ) candidates.pop(most_recent_character.name, None) @@ -314,27 +333,34 @@ class AutoDirectMixin: favored_candidates = [] for name, turns in candidates.items(): if turns > self.auto_direct_max_idle_turns: - log.debug("auto_direct_candidates: Character has gone too long without speaking.", character_name=name, turns=turns) + log.debug( + "auto_direct_candidates: Character has gone too long without speaking.", + character_name=name, + turns=turns, + ) favored_candidates.append(scene.get_character(name)) - + if favored_candidates: return favored_candidates - - return [scene.get_character(character_name) for character_name in candidates.keys()] + + return [ + scene.get_character(character_name) for character_name in candidates.keys() + ] # actions - + @set_processing - async def auto_direct_set_scene_intent(self, require:bool=False) -> ScenePhase | None: - - async def set_scene_intention(type:str, intention:str) -> ScenePhase: + async def auto_direct_set_scene_intent( + self, require: bool = False + ) -> ScenePhase | None: + async def set_scene_intention(type: str, intention: str) -> ScenePhase: await set_scene_phase(self.scene, type, intention) self.scene.emit_status() return self.scene.intent_state.phase - + async def do_nothing(*args, **kwargs) -> None: return None - + focal_handler = focal.Focal( self.client, callbacks=[ @@ -355,57 +381,63 @@ class AutoDirectMixin: ], max_calls=1, scene=self.scene, - scene_type_ids=", ".join([f'"{scene_type.id}"' for scene_type in self.scene.intent_state.scene_types.values()]), + scene_type_ids=", ".join( + [ + f'"{scene_type.id}"' + for scene_type in self.scene.intent_state.scene_types.values() + ] + ), retries=1, require=require, ) - + await focal_handler.request( "director.direct-determine-scene-intent", ) - + return self.scene.intent_state.phase - + @set_processing async def auto_direct_generate_scene_types( - self, - instructions:str, - max_scene_types:int=1, + self, + instructions: str, + max_scene_types: int = 1, ): - - world_state_manager:WorldStateManager = self.scene.world_state_manager - - scene_type_templates:TypedCollection = await world_state_manager.get_templates(types=["scene_type"]) - - async def add_from_template(id:str) -> SceneType: - template:TemplateSceneType | None = scene_type_templates.find_by_name(id) + world_state_manager: WorldStateManager = self.scene.world_state_manager + + scene_type_templates: TypedCollection = await world_state_manager.get_templates( + types=["scene_type"] + ) + + async def add_from_template(id: str) -> SceneType: + template: TemplateSceneType | None = scene_type_templates.find_by_name(id) if not template: - log.warning("auto_direct_generate_scene_types: Template not found.", name=id) + log.warning( + "auto_direct_generate_scene_types: Template not found.", name=id + ) return None return template.apply_to_scene(self.scene) - + async def generate_scene_type( - id:str = None, - name:str = None, - description:str = None, - instructions:str = None, + id: str = None, + name: str = None, + description: str = None, + instructions: str = None, ) -> SceneType: - if not id or not name: return None - + scene_type = SceneType( id=id, name=name, description=description, instructions=instructions, ) - + self.scene.intent_state.scene_types[id] = scene_type - + return scene_type - - + focal_handler = focal.Focal( self.client, callbacks=[ @@ -435,7 +467,7 @@ class AutoDirectMixin: instructions=instructions, scene_type_templates=scene_type_templates.templates, ) - + await focal_handler.request( "director.generate-scene-types", - ) \ No newline at end of file + ) diff --git a/src/talemate/agents/director/auto_direct_nodes.py b/src/talemate/agents/director/auto_direct_nodes.py index 47fcefed..d3f4fe78 100644 --- a/src/talemate/agents/director/auto_direct_nodes.py +++ b/src/talemate/agents/director/auto_direct_nodes.py @@ -1,17 +1,13 @@ import structlog -from typing import TYPE_CHECKING, ClassVar +from typing import ClassVar from talemate.game.engine.nodes.core import GraphState, PropertyField from talemate.game.engine.nodes.registry import register from talemate.game.engine.nodes.agent import AgentNode from talemate.scene.schema import ScenePhase -from talemate.context import active_scene - -if TYPE_CHECKING: - from talemate.tale_mate import Scene - from talemate.agents.director import DirectorAgent log = structlog.get_logger("talemate.game.engine.nodes.agents.director") + @register("agents/director/auto-direct/Candidates") class AutoDirectCandidates(AgentNode): """ @@ -19,52 +15,51 @@ class AutoDirectCandidates(AgentNode): next action, based on the director's auto-direct settings and the recent scene history. """ - _agent_name:ClassVar[str] = "director" - + + _agent_name: ClassVar[str] = "director" + def __init__(self, title="Auto Direct Candidates", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_output("characters", socket_type="list") - + async def run(self, state: GraphState): candidates = self.agent.auto_direct_candidates() - self.set_output_values({ - "characters": candidates - }) - - + self.set_output_values({"characters": candidates}) + + @register("agents/director/auto-direct/DetermineSceneIntent") class DetermineSceneIntent(AgentNode): """ Determines the scene intent based on the current scene state. """ - _agent_name:ClassVar[str] = "director" - + + _agent_name: ClassVar[str] = "director" + def __init__(self, title="Determine Scene Intent", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_output("state") self.add_output("scene_phase", socket_type="scene_intent/scene_phase") - + async def run(self, state: GraphState): - phase:ScenePhase = await self.agent.auto_direct_set_scene_intent() - - self.set_output_values({ - "state": state, - "scene_phase": phase - }) + phase: ScenePhase = await self.agent.auto_direct_set_scene_intent() + + self.set_output_values({"state": state, "scene_phase": phase}) + @register("agents/director/auto-direct/GenerateSceneTypes") class GenerateSceneTypes(AgentNode): """ Generates scene types based on the current scene state. """ - _agent_name:ClassVar[str] = "director" - + + _agent_name: ClassVar[str] = "director" + class Fields: instructions = PropertyField( name="instructions", @@ -72,17 +67,17 @@ class GenerateSceneTypes(AgentNode): description="The instructions for the scene types", default="", ) - + max_scene_types = PropertyField( name="max_scene_types", type="int", description="The maximum number of scene types to generate", default=1, ) - + def __init__(self, title="Generate Scene Types", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("instructions", socket_type="str", optional=True) @@ -90,29 +85,26 @@ class GenerateSceneTypes(AgentNode): self.set_property("instructions", "") self.set_property("max_scene_types", 1) self.add_output("state") - + async def run(self, state: GraphState): instructions = self.normalized_input_value("instructions") max_scene_types = self.normalized_input_value("max_scene_types") - + scene_types = await self.agent.auto_direct_generate_scene_types( - instructions=instructions, - max_scene_types=max_scene_types + instructions=instructions, max_scene_types=max_scene_types ) - - self.set_output_values({ - "state": state, - "scene_types": scene_types - }) - + + self.set_output_values({"state": state, "scene_types": scene_types}) + @register("agents/director/auto-direct/IsDueForInstruction") class IsDueForInstruction(AgentNode): """ Checks if the actor is due for instruction based on the auto-direct settings. """ - _agent_name:ClassVar[str] = "director" - + + _agent_name: ClassVar[str] = "director" + class Fields: actor_name = PropertyField( name="actor_name", @@ -120,24 +112,21 @@ class IsDueForInstruction(AgentNode): description="The name of the actor to check instruction timing for", default="", ) - + def __init__(self, title="Is Due For Instruction", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("actor_name", socket_type="str") - + self.set_property("actor_name", "") - + self.add_output("is_due", socket_type="bool") self.add_output("actor_name", socket_type="str") - + async def run(self, state: GraphState): actor_name = self.require_input("actor_name") - + is_due = self.agent.auto_direct_is_due_for_instruction(actor_name) - - self.set_output_values({ - "is_due": is_due, - "actor_name": actor_name - }) \ No newline at end of file + + self.set_output_values({"is_due": is_due, "actor_name": actor_name}) diff --git a/src/talemate/agents/director/generate_choices.py b/src/talemate/agents/director/generate_choices.py index 429b84a3..0120f2f3 100644 --- a/src/talemate/agents/director/generate_choices.py +++ b/src/talemate/agents/director/generate_choices.py @@ -1,14 +1,12 @@ from typing import TYPE_CHECKING import random import structlog -from functools import wraps import dataclasses from talemate.agents.base import ( set_processing, AgentAction, AgentActionConfig, AgentTemplateEmission, - DynamicInstruction, ) from talemate.events import GameLoopStartEvent from talemate.scene_message import NarratorMessage, CharacterMessage @@ -24,7 +22,7 @@ __all__ = [ log = structlog.get_logger() talemate.emit.async_signals.register( - "agent.director.generate_choices.before_generate", + "agent.director.generate_choices.before_generate", "agent.director.generate_choices.inject_instructions", "agent.director.generate_choices.generated", ) @@ -32,18 +30,19 @@ talemate.emit.async_signals.register( if TYPE_CHECKING: from talemate.tale_mate import Character + @dataclasses.dataclass class GenerateChoicesEmission(AgentTemplateEmission): character: "Character | None" = None choices: list[str] = dataclasses.field(default_factory=list) + class GenerateChoicesMixin: - """ Director agent mixin that provides functionality for automatically guiding the actors or the narrator during the scene progression. """ - + @classmethod def add_actions(cls, actions: dict[str, AgentAction]): actions["_generate_choices"] = AgentAction( @@ -65,7 +64,6 @@ class GenerateChoicesMixin: max=1, step=0.1, ), - "num_choices": AgentActionConfig( type="number", label="Number of Actions", @@ -75,33 +73,31 @@ class GenerateChoicesMixin: max=10, step=1, ), - "never_auto_progress": AgentActionConfig( type="bool", label="Never Auto Progress on Action Selection", description="If enabled, the scene will not auto progress after you select an action.", value=False, ), - "instructions": AgentActionConfig( type="blob", label="Instructions", description="Provide some instructions to the director for generating actions.", value="", ), - } + }, ) - + # config property helpers - + @property def generate_choices_enabled(self): return self.actions["_generate_choices"].enabled - + @property def generate_choices_chance(self): return self.actions["_generate_choices"].config["chance"].value - + @property def generate_choices_num_choices(self): return self.actions["_generate_choices"].config["num_choices"].value @@ -115,24 +111,25 @@ class GenerateChoicesMixin: return self.actions["_generate_choices"].config["instructions"].value # signal connect - + def connect(self, scene): super().connect(scene) - talemate.emit.async_signals.get("player_turn_start").connect(self.on_player_turn_start) - + talemate.emit.async_signals.get("player_turn_start").connect( + self.on_player_turn_start + ) + async def on_player_turn_start(self, event: GameLoopStartEvent): if not self.enabled: return - + if self.generate_choices_enabled: - # look backwards through history and abort if we encounter # a character message with source "player" before either # a character message with a different source or a narrator message # # this is so choices aren't generated when the player message was # the most recent content in the scene - + for i in range(len(self.scene.history) - 1, -1, -1): message = self.scene.history[i] if isinstance(message, NarratorMessage): @@ -141,12 +138,11 @@ class GenerateChoicesMixin: if message.source == "player": return break - + if random.random() < self.generate_choices_chance: - await self.generate_choices() - + await self.generate_choices() + # methods - @set_processing async def generate_choices( @@ -154,20 +150,23 @@ class GenerateChoicesMixin: instructions: str = None, character: "Character | str | None" = None, ): - emission: GenerateChoicesEmission = GenerateChoicesEmission(agent=self) if isinstance(character, str): character = self.scene.get_character(character) - + if not character: character = self.scene.get_player_character() - + emission.character = character - - await talemate.emit.async_signals.get("agent.director.generate_choices.before_generate").send(emission) - await talemate.emit.async_signals.get("agent.director.generate_choices.inject_instructions").send(emission) - + + await talemate.emit.async_signals.get( + "agent.director.generate_choices.before_generate" + ).send(emission) + await talemate.emit.async_signals.get( + "agent.director.generate_choices.inject_instructions" + ).send(emission) + response = await Prompt.request( "director.generate-choices", self.client, @@ -178,7 +177,9 @@ class GenerateChoicesMixin: "character": character, "num_choices": self.generate_choices_num_choices, "instructions": instructions or self.generate_choices_instructions, - "dynamic_instructions": emission.dynamic_instructions if emission else None, + "dynamic_instructions": emission.dynamic_instructions + if emission + else None, }, ) @@ -187,10 +188,10 @@ class GenerateChoicesMixin: choices = util.extract_list(choice_text) # strip quotes choices = [choice.strip().strip('"') for choice in choices] - + # limit to num_choices - choices = choices[:self.generate_choices_num_choices] - + choices = choices[: self.generate_choices_num_choices] + except Exception as e: log.error("generate_choices failed", error=str(e), response=response) return @@ -198,15 +199,17 @@ class GenerateChoicesMixin: emit( "player_choice", response, - data = { + data={ "choices": choices, "character": character.name, }, - websocket_passthrough=True + websocket_passthrough=True, ) - + emission.response = response emission.choices = choices - await talemate.emit.async_signals.get("agent.director.generate_choices.generated").send(emission) - - return emission.response \ No newline at end of file + await talemate.emit.async_signals.get( + "agent.director.generate_choices.generated" + ).send(emission) + + return emission.response diff --git a/src/talemate/agents/director/guide.py b/src/talemate/agents/director/guide.py index 7b7424c5..34df323f 100644 --- a/src/talemate/agents/director/guide.py +++ b/src/talemate/agents/director/guide.py @@ -17,12 +17,11 @@ from talemate.util import strip_partial_sentences if TYPE_CHECKING: from talemate.tale_mate import Character from talemate.agents.summarize.analyze_scene import SceneAnalysisEmission - from talemate.agents.editor.revision import RevisionAnalysisEmission log = structlog.get_logger() talemate.emit.async_signals.register( - "agent.director.guide.before_generate", + "agent.director.guide.before_generate", "agent.director.guide.inject_instructions", "agent.director.guide.generated", ) @@ -32,6 +31,7 @@ talemate.emit.async_signals.register( class DirectorGuidanceEmission(AgentTemplateEmission): pass + def set_processing(fn): """ Custom decorator that emits the agent status as processing while the function @@ -42,29 +42,33 @@ def set_processing(fn): @wraps(fn) async def wrapper(self, *args, **kwargs): emission: DirectorGuidanceEmission = DirectorGuidanceEmission(agent=self) - - await talemate.emit.async_signals.get("agent.director.guide.before_generate").send(emission) - await talemate.emit.async_signals.get("agent.director.guide.inject_instructions").send(emission) - + + await talemate.emit.async_signals.get( + "agent.director.guide.before_generate" + ).send(emission) + await talemate.emit.async_signals.get( + "agent.director.guide.inject_instructions" + ).send(emission) + agent_context = active_agent.get() agent_context.state["dynamic_instructions"] = emission.dynamic_instructions - + response = await fn(self, *args, **kwargs) emission.response = response - await talemate.emit.async_signals.get("agent.director.guide.generated").send(emission) + await talemate.emit.async_signals.get("agent.director.guide.generated").send( + emission + ) return emission.response return wrapper - class GuideSceneMixin: - """ Director agent mixin that provides functionality for automatically guiding the actors or the narrator during the scene progression. """ - + @classmethod def add_actions(cls, actions: dict[str, AgentAction]): actions["guide_scene"] = AgentAction( @@ -81,13 +85,13 @@ class GuideSceneMixin: type="bool", label="Guide actors", description="Guide the actors in the scene. This happens during every actor turn.", - value=True + value=True, ), "guide_narrator": AgentActionConfig( type="bool", label="Guide narrator", description="Guide the narrator during the scene. This happens during the narrator's turn.", - value=True + value=True, ), "guidance_length": AgentActionConfig( type="text", @@ -101,7 +105,7 @@ class GuideSceneMixin: {"label": "Medium (512)", "value": "512"}, {"label": "Medium Long (768)", "value": "768"}, {"label": "Long (1024)", "value": "1024"}, - ] + ], ), "cache_guidance": AgentActionConfig( type="bool", @@ -109,57 +113,57 @@ class GuideSceneMixin: description="Will not regenerate the guidance until the scene moves forward or the analysis changes.", value=False, quick_toggle=True, - ) - } + ), + }, ) - + # config property helpers - + @property def guide_scene(self) -> bool: return self.actions["guide_scene"].enabled - + @property def guide_actors(self) -> bool: return self.actions["guide_scene"].config["guide_actors"].value - + @property def guide_narrator(self) -> bool: return self.actions["guide_scene"].config["guide_narrator"].value - + @property def guide_scene_guidance_length(self) -> int: return int(self.actions["guide_scene"].config["guidance_length"].value) - + @property def guide_scene_cache_guidance(self) -> bool: return self.actions["guide_scene"].config["cache_guidance"].value - + # signal connect - + def connect(self, scene): super().connect(scene) - talemate.emit.async_signals.get("agent.summarization.scene_analysis.after").connect( - self.on_summarization_scene_analysis_after - ) - talemate.emit.async_signals.get("agent.summarization.scene_analysis.cached").connect( - self.on_summarization_scene_analysis_after - ) - talemate.emit.async_signals.get("agent.editor.revision-analysis.before").connect( - self.on_editor_revision_analysis_before - ) - - async def on_summarization_scene_analysis_after(self, emission: "SceneAnalysisEmission"): - + talemate.emit.async_signals.get( + "agent.summarization.scene_analysis.after" + ).connect(self.on_summarization_scene_analysis_after) + talemate.emit.async_signals.get( + "agent.summarization.scene_analysis.cached" + ).connect(self.on_summarization_scene_analysis_after) + talemate.emit.async_signals.get( + "agent.editor.revision-analysis.before" + ).connect(self.on_editor_revision_analysis_before) + + async def on_summarization_scene_analysis_after( + self, emission: "SceneAnalysisEmission" + ): if not self.guide_scene: return - + guidance = None - + cached_guidance = await self.get_cached_guidance(emission.response) - + if emission.analysis_type == "narration" and self.guide_narrator: - if cached_guidance: guidance = cached_guidance else: @@ -167,15 +171,14 @@ class GuideSceneMixin: emission.response, response_length=self.guide_scene_guidance_length, ) - + if not guidance: log.warning("director.guide_scene.narration: Empty resonse") return - + self.set_context_states(narrator_guidance=guidance) - - elif emission.analysis_type == "conversation" and self.guide_actors: - + + elif emission.analysis_type == "conversation" and self.guide_actors: if cached_guidance: guidance = cached_guidance else: @@ -184,94 +187,110 @@ class GuideSceneMixin: emission.template_vars.get("character"), response_length=self.guide_scene_guidance_length, ) - + if not guidance: log.warning("director.guide_scene.conversation: Empty resonse") return - + self.set_context_states(actor_guidance=guidance) if guidance: await self.set_cached_guidance( - emission.response, - guidance, - emission.analysis_type, - emission.template_vars.get("character") + emission.response, + guidance, + emission.analysis_type, + emission.template_vars.get("character"), ) - + async def on_editor_revision_analysis_before(self, emission: AgentTemplateEmission): cached_guidance = await self.get_cached_guidance(emission.response) if cached_guidance: - emission.dynamic_instructions.append(DynamicInstruction( - title="Guidance", - content=cached_guidance - )) - + emission.dynamic_instructions.append( + DynamicInstruction(title="Guidance", content=cached_guidance) + ) + # helpers - + def _cache_key(self) -> str: - return f"cached_guidance" - - async def get_cached_guidance(self, analysis:str | None = None) -> str | None: + return "cached_guidance" + + async def get_cached_guidance(self, analysis: str | None = None) -> str | None: """ Returns the cached guidance for the given analysis. - + If analysis is not provided, it will return the cached guidance for the last analysis regardless of the fingerprint. """ - + if not self.guide_scene_cache_guidance: return None - + key = self._cache_key() cached_guidance = self.get_scene_state(key) - + if cached_guidance: if not analysis: return cached_guidance.get("guidance") elif cached_guidance.get("fp") == self.context_fingerpint(extra=[analysis]): return cached_guidance.get("guidance") - + return None - - async def set_cached_guidance(self, analysis:str, guidance: str, analysis_type: str, character: "Character | None" = None): + + async def set_cached_guidance( + self, + analysis: str, + guidance: str, + analysis_type: str, + character: "Character | None" = None, + ): """ Sets the cached guidance for the given analysis. """ key = self._cache_key() - self.set_scene_states(**{ - key: { - "fp": self.context_fingerpint(extra=[analysis]), - "guidance": guidance, - "analysis_type": analysis_type, - "character": character.name if character else None, + self.set_scene_states( + **{ + key: { + "fp": self.context_fingerpint(extra=[analysis]), + "guidance": guidance, + "analysis_type": analysis_type, + "character": character.name if character else None, + } } - }) - + ) + async def get_cached_character_guidance(self, character_name: str) -> str | None: """ Returns the cached guidance for the given character. """ key = self._cache_key() cached_guidance = self.get_scene_state(key) - + if not cached_guidance: return None - - if cached_guidance.get("character") == character_name and cached_guidance.get("analysis_type") == "conversation": + + if ( + cached_guidance.get("character") == character_name + and cached_guidance.get("analysis_type") == "conversation" + ): return cached_guidance.get("guidance") - + return None - + # methods - + @set_processing - async def guide_actor_off_of_scene_analysis(self, analysis: str, character: "Character", response_length: int = 256): + async def guide_actor_off_of_scene_analysis( + self, analysis: str, character: "Character", response_length: int = 256 + ): """ Guides the actor based on the scene analysis. """ - - log.debug("director.guide_actor_off_of_scene_analysis", analysis=analysis, character=character) + + log.debug( + "director.guide_actor_off_of_scene_analysis", + analysis=analysis, + character=character, + ) response = await Prompt.request( "director.guide-conversation", self.client, @@ -285,19 +304,17 @@ class GuideSceneMixin: }, ) return strip_partial_sentences(response).strip() - + @set_processing async def guide_narrator_off_of_scene_analysis( - self, - analysis: str, - response_length: int = 256 + self, analysis: str, response_length: int = 256 ): """ Guides the narrator based on the scene analysis. """ - + log.debug("director.guide_narrator_off_of_scene_analysis", analysis=analysis) - + response = await Prompt.request( "director.guide-narration", self.client, @@ -309,4 +326,4 @@ class GuideSceneMixin: "max_tokens": self.client.max_token_length, }, ) - return strip_partial_sentences(response).strip() \ No newline at end of file + return strip_partial_sentences(response).strip() diff --git a/src/talemate/agents/director/legacy_scene_instructions.py b/src/talemate/agents/director/legacy_scene_instructions.py index a42bb796..52bc606c 100644 --- a/src/talemate/agents/director/legacy_scene_instructions.py +++ b/src/talemate/agents/director/legacy_scene_instructions.py @@ -1,4 +1,3 @@ -from typing import TYPE_CHECKING import structlog from talemate.agents.base import ( set_processing, @@ -9,24 +8,27 @@ from talemate.events import GameLoopActorIterEvent, SceneStateEvent log = structlog.get_logger("talemate.agents.conversation.legacy_scene_instructions") + class LegacySceneInstructionsMixin( GameInstructionsMixin, ): """ Legacy support for scoped api instructions in scenes. - + This is being replaced by node based in structions, but kept for backwards compatibility. - + THIS WILL BE DEPRECATED IN THE FUTURE. """ - + # signal connect - + def connect(self, scene): super().connect(scene) - talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.LSI_on_player_dialog) + talemate.emit.async_signals.get("game_loop_actor_iter").connect( + self.LSI_on_player_dialog + ) talemate.emit.async_signals.get("scene_init").connect(self.LSI_on_scene_init) - + async def LSI_on_scene_init(self, event: SceneStateEvent): """ LEGACY: If game state instructions specify to be run at the start of the game loop @@ -57,8 +59,10 @@ class LegacySceneInstructionsMixin( if not event.actor.character.is_player: return - - log.warning(f"LSI_on_player_dialog is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.") + + log.warning( + "LSI_on_player_dialog is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future." + ) if event.game_loop.had_passive_narration: log.debug( @@ -77,17 +81,17 @@ class LegacySceneInstructionsMixin( not self.scene.npc_character_names or self.scene.game_state.ops.always_direct ) - - log.warning(f"LSI_direct is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.", always_direct=always_direct) + + log.warning( + "LSI_direct is being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.", + always_direct=always_direct, + ) next_direct = self.next_direct_scene TURNS = 5 - if ( - next_direct % TURNS != 0 - or next_direct == 0 - ): + if next_direct % TURNS != 0 or next_direct == 0: if not always_direct: log.info("direct", skip=True, next_direct=next_direct) self.next_direct_scene += 1 @@ -112,8 +116,10 @@ class LegacySceneInstructionsMixin( async def LSI_direct_scene(self): """ LEGACY: Direct the scene based scoped api scene instructions. - This is being replaced by node based instructions, but kept for + This is being replaced by node based instructions, but kept for backwards compatibility. """ - log.warning(f"Direct python scene instructions are being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future.") - await self.run_scene_instructions(self.scene) \ No newline at end of file + log.warning( + "Direct python scene instructions are being DEPRECATED. Please use the new node based instructions. Support for this will be removed in the future." + ) + await self.run_scene_instructions(self.scene) diff --git a/src/talemate/agents/director/nodes.py b/src/talemate/agents/director/nodes.py index e97ec9ab..569826e6 100644 --- a/src/talemate/agents/director/nodes.py +++ b/src/talemate/agents/director/nodes.py @@ -1,29 +1,34 @@ import structlog -from typing import ClassVar, TYPE_CHECKING -from talemate.context import active_scene -from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES +from typing import ClassVar +from talemate.game.engine.nodes.core import ( + GraphState, + PropertyField, + TYPE_CHOICES, +) from talemate.game.engine.nodes.registry import register from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode -if TYPE_CHECKING: - from talemate.tale_mate import Scene -TYPE_CHOICES.extend([ - "director/direction", -]) +TYPE_CHOICES.extend( + [ + "director/direction", + ] +) log = structlog.get_logger("talemate.game.engine.nodes.agents.director") + @register("agents/director/Settings") class DirectorSettings(AgentSettingsNode): """ Base node to render director agent settings. """ - _agent_name:ClassVar[str] = "director" - + + _agent_name: ClassVar[str] = "director" + def __init__(self, title="Director Settings", **kwargs): super().__init__(title=title, **kwargs) - + @register("agents/director/PersistCharacter") class PersistCharacter(AgentNode): @@ -31,45 +36,44 @@ class PersistCharacter(AgentNode): Persists a character that currently only exists as part of the given context as a real character that can actively participate in the scene. """ - _agent_name:ClassVar[str] = "director" - + + _agent_name: ClassVar[str] = "director" + class Fields: determine_name = PropertyField( name="determine_name", type="bool", description="Whether to determine the name of the character", - default=True + default=True, ) - + def __init__(self, title="Persist Character", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("character_name", socket_type="str") self.add_input("context", socket_type="str", optional=True) self.add_input("attributes", socket_type="dict,str", optional=True) - + self.set_property("determine_name", True) - + self.add_output("state") self.add_output("character", socket_type="character") - + async def run(self, state: GraphState): - character_name = self.get_input_value("character_name") context = self.normalized_input_value("context") attributes = self.normalized_input_value("attributes") determine_name = self.normalized_input_value("determine_name") - + character = await self.agent.persist_character( name=character_name, content=context, - attributes="\n".join([f"{k}: {v}" for k, v in attributes.items()]) if attributes else None, - determine_name=determine_name + attributes="\n".join([f"{k}: {v}" for k, v in attributes.items()]) + if attributes + else None, + determine_name=determine_name, ) - - self.set_output_values({ - "state": state, - "character": character - }) \ No newline at end of file + + self.set_output_values({"state": state, "character": character}) diff --git a/src/talemate/agents/director/websocket_handler.py b/src/talemate/agents/director/websocket_handler.py index 4561509d..b8617a62 100644 --- a/src/talemate/agents/director/websocket_handler.py +++ b/src/talemate/agents/director/websocket_handler.py @@ -1,7 +1,6 @@ import pydantic import asyncio import structlog -import traceback from typing import TYPE_CHECKING from talemate.instance import get_agent @@ -18,27 +17,31 @@ __all__ = [ log = structlog.get_logger("talemate.server.director") + class InstructionPayload(pydantic.BaseModel): - instructions:str = "" + instructions: str = "" + class SelectChoicePayload(pydantic.BaseModel): choice: str - character:str = "" + character: str = "" + class CharacterPayload(InstructionPayload): - character:str = "" + character: str = "" + class PersistCharacterPayload(pydantic.BaseModel): name: str templates: list[str] | None = None narrate_entry: bool = True narrate_entry_direction: str = "" - + active: bool = True determine_name: bool = True augment_attributes: str = "" generate_attributes: bool = True - + content: str = "" description: str = "" @@ -47,13 +50,13 @@ class DirectorWebsocketHandler(Plugin): """ Handles director actions """ - + router = "director" - + @property def director(self): return get_agent("director") - + @set_loading("Generating dynamic actions", cancellable=True, as_async=True) async def handle_request_dynamic_choices(self, data: dict): """ @@ -61,21 +64,21 @@ class DirectorWebsocketHandler(Plugin): """ payload = CharacterPayload(**data) await self.director.generate_choices(**payload.model_dump()) - + async def handle_select_choice(self, data: dict): payload = SelectChoicePayload(**data) - + log.debug("selecting choice", payload=payload) - + if payload.character: character = self.scene.get_character(payload.character) else: character = self.scene.get_player_character() - + if not character: log.error("handle_select_choice: could not find character", payload=payload) return - + # hijack the interaction state try: interaction_state = interaction.get() @@ -83,20 +86,23 @@ class DirectorWebsocketHandler(Plugin): # no interaction state log.error("handle_select_choice: no interaction state", payload=payload) return - + interaction_state.from_choice = payload.choice interaction_state.act_as = character.name if not character.is_player else None interaction_state.input = f"@{payload.choice}" - + async def handle_persist_character(self, data: dict): payload = PersistCharacterPayload(**data) scene: "Scene" = self.scene - + if not payload.content: payload.content = scene.snapshot(lines=15) - + # add as asyncio task - task = asyncio.create_task(self.director.persist_character(**payload.model_dump())) + task = asyncio.create_task( + self.director.persist_character(**payload.model_dump()) + ) + async def handle_task_done(task): if task.exception(): log.error("Error persisting character", error=task.exception()) @@ -112,4 +118,3 @@ class DirectorWebsocketHandler(Plugin): await self.signal_operation_done() task.add_done_callback(lambda task: asyncio.create_task(handle_task_done(task))) - diff --git a/src/talemate/agents/editor/__init__.py b/src/talemate/agents/editor/__init__.py index c007af48..f86c9f27 100644 --- a/src/talemate/agents/editor/__init__.py +++ b/src/talemate/agents/editor/__init__.py @@ -40,7 +40,7 @@ class EditorAgent( agent_type = "editor" verbose_name = "Editor" websocket_handler = EditorWebsocketHandler - + @classmethod def init_actions(cls) -> dict[str, AgentAction]: actions = { @@ -56,9 +56,9 @@ class EditorAgent( description="The formatting to use for exposition.", value="novel", choices=[ - {"label": "Chat RP: \"Speech\" *narration*", "value": "chat"}, - {"label": "Novel: \"Speech\" narration", "value": "novel"}, - ] + {"label": 'Chat RP: "Speech" *narration*', "value": "chat"}, + {"label": 'Novel: "Speech" narration', "value": "novel"}, + ], ), "narrator": AgentActionConfig( type="bool", @@ -81,7 +81,7 @@ class EditorAgent( description="Attempt to add extra detail and exposition to the dialogue. Runs automatically after each AI dialogue.", ), } - + MemoryRAGMixin.add_actions(actions) RevisionMixin.add_actions(actions) return actions @@ -90,7 +90,7 @@ class EditorAgent( self.client = client self.is_enabled = True self.actions = EditorAgent.init_actions() - + @property def enabled(self): return self.is_enabled @@ -102,11 +102,11 @@ class EditorAgent( @property def experimental(self): return True - + @property def fix_exposition_enabled(self): return self.actions["fix_exposition"].enabled - + @property def fix_exposition_formatting(self): return self.actions["fix_exposition"].config["formatting"].value @@ -114,11 +114,10 @@ class EditorAgent( @property def fix_exposition_narrator(self): return self.actions["fix_exposition"].config["narrator"].value - + @property def fix_exposition_user_input(self): return self.actions["fix_exposition"].config["user_input"].value - def connect(self, scene): super().connect(scene) @@ -134,7 +133,7 @@ class EditorAgent( formatting = "md" else: formatting = None - + if self.fix_exposition_formatting == "chat": text = text.replace("**", "*") text = text.replace("[", "*").replace("]", "*") @@ -143,15 +142,14 @@ class EditorAgent( text = text.replace("*", "") text = text.replace("[", "").replace("]", "") text = text.replace("(", "").replace(")", "") - - cleaned = util.ensure_dialog_format( - text, - talking_character=character.name if character else None, - formatting=formatting - ) - - return cleaned + cleaned = util.ensure_dialog_format( + text, + talking_character=character.name if character else None, + formatting=formatting, + ) + + return cleaned async def on_conversation_generated(self, emission: ConversationAgentEmission): """ @@ -181,11 +179,13 @@ class EditorAgent( emission.response = edit @set_processing - async def cleanup_character_message(self, content: str, character: Character, force: bool = False): + async def cleanup_character_message( + self, content: str, character: Character, force: bool = False + ): """ Edits a text to make sure all narrative exposition and emotes is encased in * """ - + # if not content was generated, return it as is if not content: return content @@ -210,14 +210,14 @@ class EditorAgent( content = util.clean_dialogue(content, main_name=character.name) content = util.strip_partial_sentences(content) - + # if there are uneven quotation marks, fix them by adding a closing quote if '"' in content and content.count('"') % 2 != 0: content += '"' - + if not self.fix_exposition_enabled and not exposition_fixed: return content - + content = self.fix_exposition_in_text(content, character) return content @@ -225,7 +225,7 @@ class EditorAgent( @set_processing async def clean_up_narration(self, content: str, force: bool = False): content = util.strip_partial_sentences(content) - if (self.fix_exposition_enabled and self.fix_exposition_narrator or force): + if self.fix_exposition_enabled and self.fix_exposition_narrator or force: content = self.fix_exposition_in_text(content, None) if self.fix_exposition_formatting == "chat": if '"' not in content and "*" not in content: @@ -234,25 +234,27 @@ class EditorAgent( return content @set_processing - async def cleanup_user_input(self, text: str, as_narration: bool = False, force: bool = False): - + async def cleanup_user_input( + self, text: str, as_narration: bool = False, force: bool = False + ): # special prefix characters - when found, never edit PREFIX_CHARACTERS = ("!", "@", "/") if text.startswith(PREFIX_CHARACTERS): return text - - if (not self.fix_exposition_user_input or not self.fix_exposition_enabled) and not force: + + if ( + not self.fix_exposition_user_input or not self.fix_exposition_enabled + ) and not force: return text - + if not as_narration: if self.fix_exposition_formatting == "chat": if '"' not in text and "*" not in text: text = f'"{text}"' else: return await self.clean_up_narration(text) - + return self.fix_exposition_in_text(text) - @set_processing async def add_detail(self, content: str, character: Character): @@ -279,4 +281,4 @@ class EditorAgent( response = util.clean_dialogue(response, main_name=character.name) response = util.strip_partial_sentences(response) - return response \ No newline at end of file + return response diff --git a/src/talemate/agents/editor/nodes.py b/src/talemate/agents/editor/nodes.py index 27dc2b8a..efa08387 100644 --- a/src/talemate/agents/editor/nodes.py +++ b/src/talemate/agents/editor/nodes.py @@ -1,70 +1,37 @@ import structlog from typing import ClassVar, TYPE_CHECKING -from talemate.context import active_scene -from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES +from talemate.game.engine.nodes.core import ( + GraphState, + PropertyField, +) from talemate.game.engine.nodes.registry import register from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode if TYPE_CHECKING: - from talemate.tale_mate import Scene from talemate.agents.editor import EditorAgent log = structlog.get_logger("talemate.game.engine.nodes.agents.editor") + @register("agents/editor/Settings") class EditorSettings(AgentSettingsNode): """ Base node to render editor agent settings. """ - _agent_name:ClassVar[str] = "editor" - + + _agent_name: ClassVar[str] = "editor" + def __init__(self, title="Editor Settings", **kwargs): super().__init__(title=title, **kwargs) - + @register("agents/editor/CleanUpUserInput") class CleanUpUserInput(AgentNode): """ Cleans up user input. """ - _agent_name:ClassVar[str] = "editor" - - class Fields: - force = PropertyField( - name="force", - description="Force clean up", - type="bool", - default=False, - ) - - - def __init__(self, title="Clean Up User Input", **kwargs): - super().__init__(title=title, **kwargs) - - def setup(self): - self.add_input("user_input", socket_type="str") - self.add_input("as_narration", socket_type="bool") - - self.set_property("force", False) - - self.add_output("cleaned_user_input", socket_type="str") - - async def run(self, state: GraphState): - editor:"EditorAgent" = self.agent - user_input = self.get_input_value("user_input") - force = self.get_property("force") - as_narration = self.get_input_value("as_narration") - cleaned_user_input = await editor.cleanup_user_input(user_input, as_narration=as_narration, force=force) - self.set_output_values({ - "cleaned_user_input": cleaned_user_input, - }) - -@register("agents/editor/CleanUpNarration") -class CleanUpNarration(AgentNode): - """ - Cleans up narration. - """ - _agent_name:ClassVar[str] = "editor" + + _agent_name: ClassVar[str] = "editor" class Fields: force = PropertyField( @@ -73,31 +40,41 @@ class CleanUpNarration(AgentNode): type="bool", default=False, ) - - def __init__(self, title="Clean Up Narration", **kwargs): + + def __init__(self, title="Clean Up User Input", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): - self.add_input("narration", socket_type="str") + self.add_input("user_input", socket_type="str") + self.add_input("as_narration", socket_type="bool") + self.set_property("force", False) - self.add_output("cleaned_narration", socket_type="str") - + + self.add_output("cleaned_user_input", socket_type="str") + async def run(self, state: GraphState): - editor:"EditorAgent" = self.agent - narration = self.get_input_value("narration") + editor: "EditorAgent" = self.agent + user_input = self.get_input_value("user_input") force = self.get_property("force") - cleaned_narration = await editor.cleanup_narration(narration, force=force) - self.set_output_values({ - "cleaned_narration": cleaned_narration, - }) - -@register("agents/editor/CleanUoCharacterMessage") -class CleanUpCharacterMessage(AgentNode): + as_narration = self.get_input_value("as_narration") + cleaned_user_input = await editor.cleanup_user_input( + user_input, as_narration=as_narration, force=force + ) + self.set_output_values( + { + "cleaned_user_input": cleaned_user_input, + } + ) + + +@register("agents/editor/CleanUpNarration") +class CleanUpNarration(AgentNode): """ - Cleans up character message. + Cleans up narration. """ - _agent_name:ClassVar[str] = "editor" - + + _agent_name: ClassVar[str] = "editor" + class Fields: force = PropertyField( name="force", @@ -105,22 +82,62 @@ class CleanUpCharacterMessage(AgentNode): type="bool", default=False, ) - + + def __init__(self, title="Clean Up Narration", **kwargs): + super().__init__(title=title, **kwargs) + + def setup(self): + self.add_input("narration", socket_type="str") + self.set_property("force", False) + self.add_output("cleaned_narration", socket_type="str") + + async def run(self, state: GraphState): + editor: "EditorAgent" = self.agent + narration = self.get_input_value("narration") + force = self.get_property("force") + cleaned_narration = await editor.cleanup_narration(narration, force=force) + self.set_output_values( + { + "cleaned_narration": cleaned_narration, + } + ) + + +@register("agents/editor/CleanUoCharacterMessage") +class CleanUpCharacterMessage(AgentNode): + """ + Cleans up character message. + """ + + _agent_name: ClassVar[str] = "editor" + + class Fields: + force = PropertyField( + name="force", + description="Force clean up", + type="bool", + default=False, + ) + def __init__(self, title="Clean Up Character Message", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("text", socket_type="str") self.add_input("character", socket_type="character") self.set_property("force", False) self.add_output("cleaned_character_message", socket_type="str") - + async def run(self, state: GraphState): - editor:"EditorAgent" = self.agent + editor: "EditorAgent" = self.agent text = self.get_input_value("text") force = self.get_property("force") character = self.get_input_value("character") - cleaned_character_message = await editor.cleanup_character_message(text, character, force=force) - self.set_output_values({ - "cleaned_character_message": cleaned_character_message, - }) \ No newline at end of file + cleaned_character_message = await editor.cleanup_character_message( + text, character, force=force + ) + self.set_output_values( + { + "cleaned_character_message": cleaned_character_message, + } + ) diff --git a/src/talemate/agents/editor/revision.py b/src/talemate/agents/editor/revision.py index db95403d..d7c87a14 100644 --- a/src/talemate/agents/editor/revision.py +++ b/src/talemate/agents/editor/revision.py @@ -30,12 +30,11 @@ from talemate.agents.summarize import SummarizeEmission from talemate.agents.summarize.layered_history import LayeredHistoryFinalizeEmission from talemate.scene_message import CharacterMessage from talemate.util.dedupe import ( - dedupe_sentences, - SimilarityMatch, - compile_text_to_sentences, - split_sentences_on_comma, + SimilarityMatch, + compile_text_to_sentences, + split_sentences_on_comma, dedupe_sentences_from_matches, - similarity_matches + similarity_matches, ) from talemate.util.diff import dmp_inline_diff from talemate.util import count_tokens @@ -80,12 +79,15 @@ automatic_revision_condition = AgentActionConditional( ## CONTEXT + class RevisionContextState(pydantic.BaseModel): message_id: int | None = None + revision_disabled_context = ContextVar("revision_disabled", default=False) revision_context = ContextVar("revision_context", default=RevisionContextState()) + class RevisionDisabled: def __enter__(self): self.token = revision_disabled_context.set(True) @@ -93,12 +95,15 @@ class RevisionDisabled: def __exit__(self, exc_type, exc_value, traceback): revision_disabled_context.reset(self.token) + class RevisionContext: def __init__(self, message_id: int | None = None): self.message_id = message_id def __enter__(self): - self.token = revision_context.set(RevisionContextState(message_id=self.message_id)) + self.token = revision_context.set( + RevisionContextState(message_id=self.message_id) + ) def __exit__(self, exc_type, exc_value, traceback): revision_context.reset(self.token) @@ -106,34 +111,39 @@ class RevisionContext: ## SCHEMAS + class Issues(pydantic.BaseModel): repetition: list[dict] = pydantic.Field(default_factory=list) repetition_matches: list[SimilarityMatch] = pydantic.Field(default_factory=list) bad_prose: list[PhraseDetection] = pydantic.Field(default_factory=list) repetition_log: list[str] = pydantic.Field(default_factory=list) bad_prose_log: list[str] = pydantic.Field(default_factory=list) - + @property def log(self) -> list[str]: return self.repetition_log + self.bad_prose_log + class RevisionInformation(pydantic.BaseModel): text: str | None = None revision_method: Literal["dedupe", "rewrite", "unslop"] | None = None character: object = None context_type: str | None = None context_name: str | None = None - loading_status: LoadingStatus | None = pydantic.Field(default_factory=LoadingStatus, exclude=True) + loading_status: LoadingStatus | None = pydantic.Field( + default_factory=LoadingStatus, exclude=True + ) summarization_history: list[str] | None = None - + class Config: arbitrary_types_allowed = True + CONTEXTUAL_GENERATION_TYPES = [ "character attribute", "character detail", - #"scene intent", - #"scene phase intent", + # "scene intent", + # "scene phase intent", "world context", "scene intro", ] @@ -147,23 +157,25 @@ async_signals.register( "agent.editor.revision-revise.after", ) + @dataclasses.dataclass class RevisionEmission(AgentTemplateEmission): """ Emission for the revision agent """ - + info: RevisionInformation = dataclasses.field(default_factory=RevisionInformation) issues: Issues = dataclasses.field(default_factory=Issues) + ## MIXIN + class RevisionMixin: - """ Editor agent mixin that handles editing of dialogue and narration based on criteria and instructions """ - + @classmethod def add_actions(cls, actions: dict[str, AgentAction]): actions["revision"] = AgentAction( @@ -188,31 +200,37 @@ class RevisionMixin: condition=automatic_revision_condition, description="Which types of messages to automatically revise.", value=["character", "narrator"], - value_migration=lambda v: ["character", "narrator"] if v is True else [] if v is False else v, - choices=sorted([ - { - "label": "Character Messages", - "value": "character", - "help": "Automatically revise actor actions.", - }, - { - "label": "Narration Messages", - "value": "narrator", - "help": "Automatically revise narrator actions.", - }, - { - "label": "Contextual generation", - "value": "contextual_generation", - "help": "Automatically revise generated context (character attributes, details, etc).", - }, - { - "label": "Summarization", - "value": "summarization", - "help": "Automatically revise summarization.", - } - ], key=lambda x: x["label"]) + value_migration=lambda v: ["character", "narrator"] + if v is True + else [] + if v is False + else v, + choices=sorted( + [ + { + "label": "Character Messages", + "value": "character", + "help": "Automatically revise actor actions.", + }, + { + "label": "Narration Messages", + "value": "narrator", + "help": "Automatically revise narrator actions.", + }, + { + "label": "Contextual generation", + "value": "contextual_generation", + "help": "Automatically revise generated context (character attributes, details, etc).", + }, + { + "label": "Summarization", + "value": "summarization", + "help": "Automatically revise summarization.", + }, + ], + key=lambda x: x["label"], + ), ), - "revision_method": AgentActionConfig( type="text", label="Revision method", @@ -226,17 +244,17 @@ class RevisionMixin: note_on_value={ "dedupe": AgentActionNote( type="primary", - text="This will attempt to dedupe the text if repetition is detected. Will remove content without substituting it, so may cause sentence structure or logic issues." + text="This will attempt to dedupe the text if repetition is detected. Will remove content without substituting it, so may cause sentence structure or logic issues.", ), "unslop": AgentActionNote( type="primary", - text="This calls 1 additional prompt after a generation and will attempt to remove repetition, purple prose, unnatural dialogue, and over-description. May cause details to be lost." + text="This calls 1 additional prompt after a generation and will attempt to remove repetition, purple prose, unnatural dialogue, and over-description. May cause details to be lost.", ), "rewrite": AgentActionNote( type="primary", - text="Each generation will be checked for repetition and unwanted prose. If issues are found, a rewrite of the problematic part(s) will be attempted. (+2 prompts)" - ) - } + text="Each generation will be checked for repetition and unwanted prose. If issues are found, a rewrite of the problematic part(s) will be attempted. (+2 prompts)", + ), + }, ), "split_on_comma": AgentActionConfig( title="Preferences for rewriting", @@ -281,17 +299,20 @@ class RevisionMixin: description="The method to use to detect repetition", value="semantic_similarity", choices=[ - # fuzzy matching (not ai assisted) - # semantic similarity (ai assisted, using memory agent embedding function) - {"label": "Fuzzy matching", "value": "fuzzy"}, - {"label": "Semantic similarity (embeddings)", "value": "semantic_similarity"}, + # fuzzy matching (not ai assisted) + # semantic similarity (ai assisted, using memory agent embedding function) + {"label": "Fuzzy matching", "value": "fuzzy"}, + { + "label": "Semantic similarity (embeddings)", + "value": "semantic_similarity", + }, ], note_on_value={ "semantic_similarity": AgentActionNote( type="warning", - text="Uses the memory agent's embedding function to compare the text. Will use batching when available, but has the potential to do A LOT of calls to the embedding model." + text="Uses the memory agent's embedding function to compare the text. Will use batching when available, but has the potential to do A LOT of calls to the embedding model.", ) - } + }, ), "repetition_threshold": AgentActionConfig( type="number", @@ -320,61 +341,61 @@ class RevisionMixin: max=100, step=1, ), - } + }, ) - + # config property helpers - + @property def revision_enabled(self): return self.actions["revision"].enabled - + @property def revision_automatic_enabled(self) -> bool: return self.actions["revision"].config["automatic_revision"].value - + @property def revision_automatic_targets(self) -> list[str]: return self.actions["revision"].config["automatic_revision_targets"].value - + @property def revision_method(self): return self.actions["revision"].config["revision_method"].value - + @property def revision_repetition_detection_method(self): return self.actions["revision"].config["repetition_detection_method"].value - + @property def revision_repetition_threshold(self): return self.actions["revision"].config["repetition_threshold"].value - + @property def revision_repetition_range(self): return self.actions["revision"].config["repetition_range"].value - + @property def revision_repetition_min_length(self): return self.actions["revision"].config["repetition_min_length"].value - + @property def revision_split_on_comma(self): return self.actions["revision"].config["split_on_comma"].value - + @property def revision_min_issues(self): return self.actions["revision"].config["min_issues"].value - + @property def revision_detect_bad_prose_enabled(self): return self.actions["revision"].config["detect_bad_prose"].value - + @property def revision_detect_bad_prose_threshold(self): return self.actions["revision"].config["detect_bad_prose_threshold"].value - + # signal connect - + def connect(self, scene): async_signals.get("agent.conversation.generated").connect( self.revision_on_generation @@ -393,103 +414,134 @@ class RevisionMixin: ) # connect to the super class AFTER so these run first. super().connect(scene) - - + async def revision_on_generation( - self, - emission: ConversationAgentEmission | NarratorAgentEmission | ContextualGenerateEmission | SummarizeEmission | LayeredHistoryFinalizeEmission, + self, + emission: ConversationAgentEmission + | NarratorAgentEmission + | ContextualGenerateEmission + | SummarizeEmission + | LayeredHistoryFinalizeEmission, ): """ Called when a conversation or narrator message is generated """ - + if not self.revision_enabled or not self.revision_automatic_enabled: return - - if isinstance(emission, ContextualGenerateEmission) and "contextual_generation" not in self.revision_automatic_targets: + + if ( + isinstance(emission, ContextualGenerateEmission) + and "contextual_generation" not in self.revision_automatic_targets + ): return - - if isinstance(emission, ConversationAgentEmission) and "character" not in self.revision_automatic_targets: + + if ( + isinstance(emission, ConversationAgentEmission) + and "character" not in self.revision_automatic_targets + ): return - - if isinstance(emission, NarratorAgentEmission) and "narrator" not in self.revision_automatic_targets: + + if ( + isinstance(emission, NarratorAgentEmission) + and "narrator" not in self.revision_automatic_targets + ): return - + if isinstance(emission, SummarizeEmission): - if emission.summarization_type == "dialogue" and "summarization" not in self.revision_automatic_targets: + if ( + emission.summarization_type == "dialogue" + and "summarization" not in self.revision_automatic_targets + ): return if emission.summarization_type == "events": - # event summarization is very pragmatic and doesn't really benefit + # event summarization is very pragmatic and doesn't really benefit # from revision, so we skip it return - - if isinstance(emission, LayeredHistoryFinalizeEmission) and "summarization" not in self.revision_automatic_targets: + + if ( + isinstance(emission, LayeredHistoryFinalizeEmission) + and "summarization" not in self.revision_automatic_targets + ): return - + try: if revision_disabled_context.get(): - log.debug("revision_on_generation: revision disabled through context", emission=emission) + log.debug( + "revision_on_generation: revision disabled through context", + emission=emission, + ) return except LookupError: pass - + info = RevisionInformation( - text = emission.response, - character = getattr(emission, "character", None), - context_type = getattr(emission, "context_type", None), - context_name = getattr(emission, "context_name", None), + text=emission.response, + character=getattr(emission, "character", None), + context_type=getattr(emission, "context_type", None), + context_name=getattr(emission, "context_name", None), ) - + if isinstance(emission, (SummarizeEmission, LayeredHistoryFinalizeEmission)): info.summarization_history = emission.summarization_history or [] - - if isinstance(emission, ContextualGenerateEmission) and info.context_type not in CONTEXTUAL_GENERATION_TYPES: + + if ( + isinstance(emission, ContextualGenerateEmission) + and info.context_type not in CONTEXTUAL_GENERATION_TYPES + ): return - + revised_text = await self.revision_revise(info) - + emission.response = revised_text - log.info("Revision done", type=type(emission).__name__, revised=revised_text, original=info.text) + log.info( + "Revision done", + type=type(emission).__name__, + revised=revised_text, + original=info.text, + ) # helpers - + async def revision_collect_repetition_range(self) -> list[str]: """ Collect the range of text to revise against by going through the scene's history and collecting narrator and character messages """ - - scene:"Scene" = self.scene - + + scene: "Scene" = self.scene + ctx = revision_context.get() - + messages = scene.collect_messages( typ=["narrator", "character"], max_messages=self.revision_repetition_range, - start_idx=scene.message_index(ctx.message_id) -1 if ctx.message_id else None + start_idx=scene.message_index(ctx.message_id) - 1 + if ctx.message_id + else None, ) - + return_messages = [] - + for message in messages: if isinstance(message, CharacterMessage): return_messages.append(message.without_name) else: return_messages.append(message.message) - + return return_messages # actions - + @set_processing async def revision_revise( - self, + self, info: RevisionInformation, ): """ Revise the text based on the revision method """ - + try: if self.revision_method == "dedupe": return await self.revision_dedupe(info) @@ -500,89 +552,98 @@ class RevisionMixin: except GenerationCancelled: log.warning("revision_revise: generation cancelled", text=info.text) return info.text - except Exception as e: + except Exception: import traceback + log.error("revision_revise: error", error=traceback.format_exc()) return info.text finally: info.loading_status.done() - - - async def _revision_evaluate_semantic_similarity(self, text: str, character: "Character | None" = None) -> list[SimilarityMatch]: + + async def _revision_evaluate_semantic_similarity( + self, text: str, character: "Character | None" = None + ) -> list[SimilarityMatch]: """ Detect repetition using semantic similarity """ - + memory_agent = get_agent("memory") - character_name_prefix = text.startswith(f"{character.name}: ") if character else False - + character_name_prefix = ( + text.startswith(f"{character.name}: ") if character else False + ) + if character_name_prefix: - text = text[len(character.name) + 2:] - - compare_against:list[str] = await self.revision_collect_repetition_range() - + text = text[len(character.name) + 2 :] + + compare_against: list[str] = await self.revision_collect_repetition_range() + text_sentences = compile_text_to_sentences(text) - + history_sentences = [] for sentence in compare_against: history_sentences.extend(compile_text_to_sentences(sentence)) - + min_length = self.revision_repetition_min_length - + # strip min length sentences from both lists text_sentences = [i for i in text_sentences if len(i[1]) >= min_length] history_sentences = [i for i in history_sentences if len(i[1]) >= min_length] - + result_matrix = await memory_agent.compare_string_lists( [i[1] for i in text_sentences], [i[1] for i in history_sentences], similarity_threshold=self.revision_repetition_threshold / 100, ) - + similarity_matches = [] - + for match in result_matrix["similarity_matches"]: index_text = match[0] index_history = match[1] sentence = text_sentences[index_text][1] matched = history_sentences[index_history][1] - similarity_matches.append(SimilarityMatch( - original=str(sentence), - matched=str(matched), - similarity=round(match[2] * 100, 2), - left_neighbor=text_sentences[index_text - 1][1] if index_text > 0 else None, - right_neighbor=text_sentences[index_text + 1][1] if index_text < len(text_sentences) - 1 else None, - )) - - return list(set(similarity_matches)) - - - async def _revision_evaluate_fuzzy_similarity(self, text: str, character: "Character | None" = None) -> list[SimilarityMatch]: - """ - Detect repetition using fuzzy matching and dedupe - - Will return a tuple with the deduped text and the deduped text - """ - - compare_against:list[str] = await self.revision_collect_repetition_range() - - matches = [] - - for old_text in compare_against: - matches.extend( - similarity_matches( - text, - old_text, - similarity_threshold=self.revision_repetition_threshold, - min_length=self.revision_repetition_min_length, - split_on_comma=self.revision_split_on_comma + similarity_matches.append( + SimilarityMatch( + original=str(sentence), + matched=str(matched), + similarity=round(match[2] * 100, 2), + left_neighbor=text_sentences[index_text - 1][1] + if index_text > 0 + else None, + right_neighbor=text_sentences[index_text + 1][1] + if index_text < len(text_sentences) - 1 + else None, + ) + ) + + return list(set(similarity_matches)) + + async def _revision_evaluate_fuzzy_similarity( + self, text: str, character: "Character | None" = None + ) -> list[SimilarityMatch]: + """ + Detect repetition using fuzzy matching and dedupe + + Will return a tuple with the deduped text and the deduped text + """ + + compare_against: list[str] = await self.revision_collect_repetition_range() + + matches = [] + + for old_text in compare_against: + matches.extend( + similarity_matches( + text, + old_text, + similarity_threshold=self.revision_repetition_threshold, + min_length=self.revision_repetition_min_length, + split_on_comma=self.revision_split_on_comma, ) ) - return list(set(matches)) - - + async def revision_detect_bad_prose(self, text: str) -> list[dict]: """ Detect bad prose in the text @@ -590,52 +651,60 @@ class RevisionMixin: try: sentences = compile_text_to_sentences(text) identified = [] - + writing_style = self.scene.writing_style - + if not writing_style or not writing_style.phrases: return [] - + if self.revision_split_on_comma: - sentences = split_sentences_on_comma([sentence[0] for sentence in sentences]) - + sentences = split_sentences_on_comma( + [sentence[0] for sentence in sentences] + ) + # collect all phrases by method semantic_similarity_phrases = [] regex_phrases = [] - + for phrase in writing_style.phrases: if not phrase.phrase or not phrase.instructions or not phrase.active: continue - + if phrase.match_method == "semantic_similarity": semantic_similarity_phrases.append(phrase) elif phrase.match_method == "regex": regex_phrases.append(phrase) - + # evaulate regex phrases first for phrase in regex_phrases: for sentence in sentences: - identified.extend(await self._revision_detect_bad_prose_regex(sentence, phrase)) - + identified.extend( + await self._revision_detect_bad_prose_regex(sentence, phrase) + ) + # next evaulate semantic similarity phrases at once identified.extend( - await self._revision_detect_bad_prose_semantic_similarity(sentences, semantic_similarity_phrases) + await self._revision_detect_bad_prose_semantic_similarity( + sentences, semantic_similarity_phrases + ) ) return identified except Exception as e: log.error("revision_detect_bad_prose: error", error=e) return [] - - async def _revision_detect_bad_prose_semantic_similarity(self, sentences: list[str], phrases: list[PhraseDetection]) -> list[dict]: + + async def _revision_detect_bad_prose_semantic_similarity( + self, sentences: list[str], phrases: list[PhraseDetection] + ) -> list[dict]: """ Detect bad prose in the text using semantic similarity """ - + memory_agent = get_agent("memory") - + if not memory_agent: return [] - + """ Compare two lists of strings using the current embedding function without touching the database. @@ -646,46 +715,51 @@ class RevisionMixin: - 'distance_matches': list of (i, j, distance) (filtered if threshold set, otherwise all) """ threshold = self.revision_detect_bad_prose_threshold - + phrase_strings = [phrase.phrase for phrase in phrases] - + num_comparisons = len(sentences) * len(phrase_strings) - - log.debug("revision_detect_bad_prose: comparing sentences to phrases", num_comparisons=num_comparisons) - - result_matrix = await memory_agent.compare_string_lists( - sentences, - phrase_strings, - similarity_threshold=threshold + + log.debug( + "revision_detect_bad_prose: comparing sentences to phrases", + num_comparisons=num_comparisons, ) - + + result_matrix = await memory_agent.compare_string_lists( + sentences, phrase_strings, similarity_threshold=threshold + ) + result = [] - + for match in result_matrix["similarity_matches"]: sentence = sentences[match[0]] phrase = phrases[match[1]] - result.append({ - "phrase": sentence, - "instructions": phrase.instructions, - "reason": "Unwanted phrase found", - "matched": phrase.phrase, - "method": "semantic_similarity", - "similarity": match[2], - }) - + result.append( + { + "phrase": sentence, + "instructions": phrase.instructions, + "reason": "Unwanted phrase found", + "matched": phrase.phrase, + "method": "semantic_similarity", + "similarity": match[2], + } + ) + return result - - async def _revision_detect_bad_prose_regex(self, sentence: str, phrase: PhraseDetection) -> list[dict]: + + async def _revision_detect_bad_prose_regex( + self, sentence: str, phrase: PhraseDetection + ) -> list[dict]: """ Detect bad prose in the text using regex """ if str(phrase.classification).lower() != "unwanted": return [] - + pattern = re.compile(phrase.phrase) if not pattern.search(sentence, re.IGNORECASE): return [] - + return [ { "phrase": sentence, @@ -697,8 +771,8 @@ class RevisionMixin: ] async def revision_collect_issues( - self, - text: str, + self, + text: str, character: "Character | None" = None, detect_bad_prose: bool = True, ) -> Issues: @@ -707,38 +781,47 @@ class RevisionMixin: """ writing_style = self.scene.writing_style detect_bad_prose = ( - self.revision_detect_bad_prose_enabled and - writing_style and - detect_bad_prose + self.revision_detect_bad_prose_enabled + and writing_style + and detect_bad_prose ) - + repetition_log = [] bad_prose_log = [] - + repetition = [] bad_prose = [] - + # Step 1 - Detect repetition if self.revision_repetition_detection_method == "fuzzy": - repetition_matches = await self._revision_evaluate_fuzzy_similarity(text, character) + repetition_matches = await self._revision_evaluate_fuzzy_similarity( + text, character + ) elif self.revision_repetition_detection_method == "semantic_similarity": - repetition_matches = await self._revision_evaluate_semantic_similarity(text, character) - + repetition_matches = await self._revision_evaluate_semantic_similarity( + text, character + ) + for match in repetition_matches: - repetition.append({ - "text_a": match.original, - "text_b": match.matched, - "similarity": match.similarity - }) - repetition_log.append(f"Repetition: `{match.original}` -> `{match.matched}` (similarity: {match.similarity})") - + repetition.append( + { + "text_a": match.original, + "text_b": match.matched, + "similarity": match.similarity, + } + ) + repetition_log.append( + f"Repetition: `{match.original}` -> `{match.matched}` (similarity: {match.similarity})" + ) + # Step 2 - Detect bad prose if detect_bad_prose: bad_prose = await self.revision_detect_bad_prose(text) for identified in bad_prose: - bad_prose_log.append(f"Bad prose: `{identified['phrase']}` (reason: {identified['reason']}, matched: {identified['matched']}, instructions: {identified['instructions']})") - - + bad_prose_log.append( + f"Bad prose: `{identified['phrase']}` (reason: {identified['reason']}, matched: {identified['matched']}, instructions: {identified['instructions']})" + ) + return Issues( repetition=repetition, repetition_matches=repetition_matches, @@ -746,148 +829,168 @@ class RevisionMixin: repetition_log=repetition_log, bad_prose_log=bad_prose_log, ) - async def revision_dedupe( - self, + self, info: RevisionInformation, ) -> str: """ Revise the text by deduping """ - + info.revision_method = "dedupe" - + text = info.text character = info.character original_text = text - character_name_prefix = text.startswith(f"{character.name}: ") if character else False + character_name_prefix = ( + text.startswith(f"{character.name}: ") if character else False + ) if character_name_prefix: - text = text[len(character.name) + 2:] - + text = text[len(character.name) + 2 :] + original_length = len(text) - - issues = await self.revision_collect_issues(text, character, detect_bad_prose=False) - + + issues = await self.revision_collect_issues( + text, character, detect_bad_prose=False + ) + if not issues.repetition_matches: return original_text - + emission = RevisionEmission(agent=self, info=info, issues=issues) - + await async_signals.get("agent.editor.revision-revise.before").send(emission) - - emission.response = dedupe_sentences_from_matches(text, issues.repetition_matches) - + + emission.response = dedupe_sentences_from_matches( + text, issues.repetition_matches + ) + await async_signals.get("agent.editor.revision-revise.after").send(emission) - + text = emission.response - + # remove empty quotes and asterisks - text = text.replace("\"\"", "").replace("**", "") - + text = text.replace('""', "").replace("**", "") + deduped_length = len(text) # calculate reduction percentage reduction = round((original_length - deduped_length) / original_length * 100, 2) if reduction > 90: - log.warning("revision_dedupe: reduction is too high, reverting to original text", original_text=original_text, reduction=reduction) - emit("agent_message", - message=f"No text remained after dedupe, reverting to original text - similarity threshold is likely too low.", + log.warning( + "revision_dedupe: reduction is too high, reverting to original text", + original_text=original_text, + reduction=reduction, + ) + emit( + "agent_message", + message="No text remained after dedupe, reverting to original text - similarity threshold is likely too low.", data={ "uuid": str(uuid.uuid4()), "agent": "editor", "header": "Aborted dedupe", "color": "red", - }, + }, meta={ "action": "revision_dedupe", "threshold": self.revision_repetition_threshold, "range": self.revision_repetition_range, }, - websocket_passthrough=True + websocket_passthrough=True, ) return original_text - + if character_name_prefix: text = f"{character.name}: {text}" - + for dedupe in issues.repetition: - text_a = dedupe['text_a'] - text_b = dedupe['text_b'] - + text_a = dedupe["text_a"] + text_b = dedupe["text_b"] + message = f"{text_a} -> {text_b}" - emit("agent_message", + emit( + "agent_message", message=message, data={ "uuid": str(uuid.uuid4()), "agent": "editor", "header": "Removed repetition", "color": "highlight4", - }, + }, meta={ "action": "revision_dedupe", - "similarity": dedupe['similarity'], + "similarity": dedupe["similarity"], "threshold": self.revision_repetition_threshold, "range": self.revision_repetition_range, }, - websocket_passthrough=True + websocket_passthrough=True, ) - + return text - + async def revision_rewrite( - self, + self, info: RevisionInformation, ) -> str: """ Revise the text by rewriting """ - + text = info.text character = info.character loading_status = info.loading_status original_text = text - - character_name_prefix = text.startswith(f"{character.name}: ") if character else False + + character_name_prefix = ( + text.startswith(f"{character.name}: ") if character else False + ) if character_name_prefix: - text = text[len(character.name) + 2:] - + text = text[len(character.name) + 2 :] + issues = await self.revision_collect_issues(text, character) - + if loading_status: loading_status.max_steps = 2 - + num_issues = len(issues.log) - + if not num_issues: return original_text - + if num_issues < self.revision_min_issues: - log.debug("revision_rewrite: not enough issues found, returning original text", issues=num_issues, min_issues=self.revision_min_issues) + log.debug( + "revision_rewrite: not enough issues found, returning original text", + issues=num_issues, + min_issues=self.revision_min_issues, + ) # Not enough issues found, return original text await self.emit_message( "Aborted rewrite", message=[ {"subtitle": "Issues", "content": issues.log}, - {"subtitle": "Message", "content": f"Not enough issues found, returning original text - minimum issues is {self.revision_min_issues}. Found {num_issues} issues."}, + { + "subtitle": "Message", + "content": f"Not enough issues found, returning original text - minimum issues is {self.revision_min_issues}. Found {num_issues} issues.", + }, ], color="orange", ) return original_text - + # Step 4 - Rewrite token_count = count_tokens(text) - + log.debug("revision_rewrite: token_count", token_count=token_count) - + if loading_status: loading_status("Editor - Issues identified, analyzing text...") emission = RevisionEmission( - agent=self, - info=info, + agent=self, + info=info, issues=issues, ) @@ -903,28 +1006,25 @@ class RevisionMixin: "context_type": info.context_type, "context_name": info.context_name, } - - - await async_signals.get("agent.editor.revision-revise.before").send( - emission - ) + + await async_signals.get("agent.editor.revision-revise.before").send(emission) await async_signals.get("agent.editor.revision-analysis.before").send(emission) - + analysis = await Prompt.request( "editor.revision-analysis", self.client, - f"edit_768", + "edit_768", vars=emission.template_vars, dedupe_enabled=False, ) - - async def rewrite_text(text:str) -> str: + + async def rewrite_text(text: str) -> str: return text - + emission.response = analysis await async_signals.get("agent.editor.revision-analysis.after").send(emission) analysis = emission.response - + focal_handler = focal.Focal( self.client, callbacks=[ @@ -943,24 +1043,24 @@ class RevisionMixin: analysis=analysis, text=text, ) - + if loading_status: loading_status("Editor - Rewriting text...") await focal_handler.request( "editor.revision-rewrite", ) - + try: revision = focal_handler.state.calls[0].result except Exception as e: log.error("revision_rewrite: error", error=e) return original_text - + emission.response = revision await async_signals.get("agent.editor.revision-revise.after").send(emission) revision = emission.response - + diff = dmp_inline_diff(text, revision) await self.emit_message( "Rewrite", @@ -981,54 +1081,53 @@ class RevisionMixin: }, color="highlight4", ) - + if character_name_prefix and not revision.startswith(f"{character.name}: "): revision = f"{character.name}: {revision}" - + return revision async def revision_unslop( - self, + self, info: RevisionInformation, response_length: int = 768, ) -> str: """ Unslop the text """ - + text = info.text character = info.character - + original_text = text - - character_name_prefix = text.startswith(f"{character.name}: ") if character else False + + character_name_prefix = ( + text.startswith(f"{character.name}: ") if character else False + ) if character_name_prefix: - text = text[len(character.name) + 2:] - + text = text[len(character.name) + 2 :] + issues = await self.revision_collect_issues(text, character) - summarizer = get_agent("summarizer") scene_analysis = await summarizer.get_cached_analysis( "conversation" if character else "narration" ) - + template = "editor.unslop" if info.context_type: - template = f"editor.unslop-contextual-generation" + template = "editor.unslop-contextual-generation" elif info.summarization_history is not None: template = "editor.unslop-summarization" - - log.debug("revision_unslop: issues", issues=issues, template=template) - - + log.debug("revision_unslop: issues", issues=issues, template=template) + emission = RevisionEmission( agent=self, info=info, issues=issues, ) - + emission.template_vars = { "text": text, "scene_analysis": scene_analysis, @@ -1043,9 +1142,9 @@ class RevisionMixin: "context_name": info.context_name, "summarization_history": info.summarization_history, } - + await async_signals.get("agent.editor.revision-revise.before").send(emission) - + response = await Prompt.request( template, self.client, @@ -1053,31 +1152,34 @@ class RevisionMixin: vars=emission.template_vars, dedupe_enabled=False, ) - + # extract ... - + if "" not in response: log.debug("revision_unslop: no found in response", response=response) return original_text fix = response.split("", 1)[1] - + if "" in fix: fix = fix.split("", 1)[0] elif "<" in fix: - log.error("revision_unslop: no found in response, but other tags found, aborting.", response=response) + log.error( + "revision_unslop: no found in response, but other tags found, aborting.", + response=response, + ) return original_text - + if not fix: log.error("revision_unslop: no fix found", response=response) return original_text - + fix = fix.strip() - + emission.response = fix await async_signals.get("agent.editor.revision-revise.after").send(emission) fix = emission.response - + # send diff to user diff = dmp_inline_diff(text, fix) await self.emit_message( @@ -1092,18 +1194,18 @@ class RevisionMixin: }, color="highlight4", ) - + if character_name_prefix and not fix.startswith(f"{character.name}: "): fix = f"{character.name}: {fix}" - + return fix - + def inject_prompt_paramters( self, prompt_param: dict, kind: str, agent_function_name: str ): super().inject_prompt_paramters(prompt_param, kind, agent_function_name) - + if agent_function_name == "revision_revise": if prompt_param.get("extra_stopping_strings") is None: prompt_param["extra_stopping_strings"] = [] - prompt_param["extra_stopping_strings"] += [""] \ No newline at end of file + prompt_param["extra_stopping_strings"] += [""] diff --git a/src/talemate/agents/editor/websocket_handler.py b/src/talemate/agents/editor/websocket_handler.py index f38d5c66..194d4e5a 100644 --- a/src/talemate/agents/editor/websocket_handler.py +++ b/src/talemate/agents/editor/websocket_handler.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING from talemate.instance import get_agent from talemate.server.websocket_plugin import Plugin -from talemate.status import set_loading from talemate.scene_message import CharacterMessage from talemate.agents.editor.revision import RevisionContext, RevisionInformation @@ -17,47 +16,49 @@ __all__ = [ log = structlog.get_logger("talemate.server.editor") + class RevisionPayload(pydantic.BaseModel): message_id: int + class EditorWebsocketHandler(Plugin): """ Handles editor actions """ - + router = "editor" - + @property def editor(self): return get_agent("editor") - + async def handle_request_revision(self, data: dict): """ Generate clickable actions for the user """ - + editor = self.editor - scene:"Scene" = self.scene - + scene: "Scene" = self.scene + if not editor.revision_enabled: raise Exception("Revision is not enabled") - + payload = RevisionPayload(**data) message = scene.get_message(payload.message_id) - + character = None - + if isinstance(message, CharacterMessage): character = scene.get_character(message.character_name) - + if not message: raise Exception("Message not found") - + with RevisionContext(message.id): info = RevisionInformation( text=message.message, character=character, ) revised = await editor.revision_revise(info) - - scene.edit_message(message.id, revised) \ No newline at end of file + + scene.edit_message(message.id, revised) diff --git a/src/talemate/agents/memory/__init__.py b/src/talemate/agents/memory/__init__.py index 8f34c1cb..446d7a92 100644 --- a/src/talemate/agents/memory/__init__.py +++ b/src/talemate/agents/memory/__init__.py @@ -75,10 +75,9 @@ class MemoryAgent(Agent): @classmethod def init_actions(cls, presets: list[dict] | None = None) -> dict[str, AgentAction]: - if presets is None: presets = [] - + actions = { "_config": AgentAction( enabled=True, @@ -101,23 +100,25 @@ class MemoryAgent(Agent): choices=[ {"value": "cpu", "label": "CPU"}, {"value": "cuda", "label": "CUDA"}, - ] + ], ), }, ), } return actions - + def __init__(self, scene, **kwargs): self.db = None self.scene = scene self.memory_tracker = {} self.config = load_config() self._ready_to_add = False - + handlers["config_saved"].connect(self.on_config_saved) - async_signals.get("client.embeddings_available").connect(self.on_client_embeddings_available) - + async_signals.get("client.embeddings_available").connect( + self.on_client_embeddings_available + ) + self.actions = MemoryAgent.init_actions(presets=self.get_presets) @property @@ -132,30 +133,32 @@ class MemoryAgent(Agent): @property def db_name(self): raise NotImplementedError() - + @property def get_presets(self): - - def _label(embedding:dict): - prefix = embedding['client'] if embedding['client'] else embedding['embeddings'] - if embedding['model']: + def _label(embedding: dict): + prefix = ( + embedding["client"] if embedding["client"] else embedding["embeddings"] + ) + if embedding["model"]: return f"{prefix}: {embedding['model']}" else: return f"{prefix}" - + return [ - {"value": k, "label": _label(v)} for k,v in self.config.get("presets", {}).get("embeddings", {}).items() + {"value": k, "label": _label(v)} + for k, v in self.config.get("presets", {}).get("embeddings", {}).items() ] - + @property def embeddings_config(self): _embeddings = self.actions["_config"].config["embeddings"].value return self.config.get("presets", {}).get("embeddings", {}).get(_embeddings, {}) - + @property def embeddings(self): return self.embeddings_config.get("embeddings", "sentence-transformer") - + @property def using_openai_embeddings(self): return self.embeddings == "openai" @@ -163,39 +166,34 @@ class MemoryAgent(Agent): @property def using_instructor_embeddings(self): return self.embeddings == "instructor" - + @property def using_sentence_transformer_embeddings(self): return self.embeddings == "default" or self.embeddings == "sentence-transformer" - + @property def using_client_api_embeddings(self): return self.embeddings == "client-api" - + @property def using_local_embeddings(self): - return self.embeddings in [ - "instructor", - "sentence-transformer", - "default" - ] - - + return self.embeddings in ["instructor", "sentence-transformer", "default"] + @property def embeddings_client(self): return self.embeddings_config.get("client") - + @property def max_distance(self) -> float: distance = float(self.embeddings_config.get("distance", 1.0)) distance_mod = float(self.embeddings_config.get("distance_mod", 1.0)) - + return distance * distance_mod - + @property def model(self): return self.embeddings_config.get("model") - + @property def distance_function(self): return self.embeddings_config.get("distance_function", "l2") @@ -213,25 +211,28 @@ class MemoryAgent(Agent): """ Returns a unique fingerprint for the current configuration """ - - model_name = self.model.replace('/','-') if self.model else "none" - - return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower() + + model_name = self.model.replace("/", "-") if self.model else "none" + + return f"{self.embeddings}-{model_name}-{self.distance_function}-{self.device}-{self.trust_remote_code}".lower() async def apply_config(self, *args, **kwargs): - _fingerprint = self.fingerprint - + await super().apply_config(*args, **kwargs) - + fingerprint_changed = _fingerprint != self.fingerprint - + # have embeddings or device changed? if fingerprint_changed: - log.warning("memory agent", status="embedding function changed", old=_fingerprint, new=self.fingerprint) + log.warning( + "memory agent", + status="embedding function changed", + old=_fingerprint, + new=self.fingerprint, + ) await self.handle_embeddings_change() - - + @set_processing async def handle_embeddings_change(self): scene = active_scene.get() @@ -239,10 +240,10 @@ class MemoryAgent(Agent): # if sentence-transformer and no model-name, set embeddings to default if self.using_sentence_transformer_embeddings and not self.model: self.actions["_config"].config["embeddings"].value = "default" - + if not scene or not scene.get_helper("memory"): return - + self.close_db(scene) emit("status", "Re-importing context database", status="busy") await scene.commit_to_memory() @@ -257,38 +258,49 @@ class MemoryAgent(Agent): def on_config_saved(self, event): loop = asyncio.get_running_loop() openai_key = self.openai_api_key - + fingerprint = self.fingerprint - + old_presets = self.actions["_config"].config["embeddings"].choices.copy() - + self.config = load_config() new_presets = self.sync_presets() if fingerprint != self.fingerprint: - log.warning("memory agent", status="embedding function changed", old=fingerprint, new=self.fingerprint) + log.warning( + "memory agent", + status="embedding function changed", + old=fingerprint, + new=self.fingerprint, + ) loop.run_until_complete(self.handle_embeddings_change()) - + emit_status = False - + if openai_key != self.openai_api_key: emit_status = True - + if old_presets != new_presets: emit_status = True - + if emit_status: loop.run_until_complete(self.emit_status()) - async def on_client_embeddings_available(self, event: "ClientEmbeddingsStatus"): current_embeddings = self.actions["_config"].config["embeddings"].value - + if current_embeddings == event.client.embeddings_identifier: return - + if not self.using_client_api_embeddings or not self.ready: - log.warning("memory agent - client embeddings available", status="changing embeddings", old=current_embeddings, new=event.client.embeddings_identifier) - self.actions["_config"].config["embeddings"].value = event.client.embeddings_identifier + log.warning( + "memory agent - client embeddings available", + status="changing embeddings", + old=current_embeddings, + new=event.client.embeddings_identifier, + ) + self.actions["_config"].config[ + "embeddings" + ].value = event.client.embeddings_identifier await self.emit_status() await self.handle_embeddings_change() await self.save_config() @@ -301,13 +313,16 @@ class MemoryAgent(Agent): except EmbeddingsModelLoadError: raise except Exception as e: - log.error("memory agent", error="failed to set db", details=traceback.format_exc()) - - if "torchvision::nms does not exist" in str(e): - raise SetDBError("The embeddings you are trying to use require the `torchvision` package to be installed") - - raise SetDBError(str(e)) + log.error( + "memory agent", error="failed to set db", details=traceback.format_exc() + ) + if "torchvision::nms does not exist" in str(e): + raise SetDBError( + "The embeddings you are trying to use require the `torchvision` package to be installed" + ) + + raise SetDBError(str(e)) def close_db(self): raise NotImplementedError() @@ -426,9 +441,9 @@ class MemoryAgent(Agent): with MemoryRequest(query=text, query_params=query) as active_memory_request: active_memory_request.max_distance = self.max_distance return await asyncio.to_thread(self._get, text, character, **query) - #return await loop.run_in_executor( + # return await loop.run_in_executor( # None, functools.partial(self._get, text, character, **query) - #) + # ) def _get(self, text, character=None, **query): raise NotImplementedError() @@ -528,9 +543,7 @@ class MemoryAgent(Agent): continue # Fetch potential memories for this query. - raw_results = await self.get( - formatter(query), limit=limit, **where - ) + raw_results = await self.get(formatter(query), limit=limit, **where) # Apply filter and respect the `iterate` limit for this query. accepted: list[str] = [] @@ -591,9 +604,9 @@ class MemoryAgent(Agent): Returns a dictionary with 'cosine_similarity' and 'euclidean_distance'. """ - + embed_fn = self.embedding_function - + # Embed the two strings vec1 = np.array(embed_fn([string1])[0]) vec2 = np.array(embed_fn([string2])[0]) @@ -604,17 +617,14 @@ class MemoryAgent(Agent): # Compute Euclidean distance euclidean_dist = np.linalg.norm(vec1 - vec2) - return { - "cosine_similarity": cosine_sim, - "euclidean_distance": euclidean_dist - } + return {"cosine_similarity": cosine_sim, "euclidean_distance": euclidean_dist} async def compare_string_lists( self, list_a: list[str], list_b: list[str], similarity_threshold: float = None, - distance_threshold: float = None + distance_threshold: float = None, ) -> dict: """ Compare two lists of strings using the current embedding function without touching the database. @@ -625,15 +635,21 @@ class MemoryAgent(Agent): - 'similarity_matches': list of (i, j, score) (filtered if threshold set, otherwise all) - 'distance_matches': list of (i, j, distance) (filtered if threshold set, otherwise all) """ - if not self.db or not hasattr(self.db, "_embedding_function") or self.db._embedding_function is None: - raise RuntimeError("Embedding function is not initialized. Make sure the database is set.") + if ( + not self.db + or not hasattr(self.db, "_embedding_function") + or self.db._embedding_function is None + ): + raise RuntimeError( + "Embedding function is not initialized. Make sure the database is set." + ) if not list_a or not list_b: return { "cosine_similarity_matrix": np.array([]), "euclidean_distance_matrix": np.array([]), "similarity_matches": [], - "distance_matches": [] + "distance_matches": [], } embed_fn = self.db._embedding_function @@ -653,21 +669,29 @@ class MemoryAgent(Agent): cosine_similarity_matrix = np.dot(vecs_a_norm, vecs_b_norm.T) # Euclidean distance matrix - a_squared = np.sum(vecs_a ** 2, axis=1).reshape(-1, 1) - b_squared = np.sum(vecs_b ** 2, axis=1).reshape(1, -1) - euclidean_distance_matrix = np.sqrt(a_squared + b_squared - 2 * np.dot(vecs_a, vecs_b.T)) + a_squared = np.sum(vecs_a**2, axis=1).reshape(-1, 1) + b_squared = np.sum(vecs_b**2, axis=1).reshape(1, -1) + euclidean_distance_matrix = np.sqrt( + a_squared + b_squared - 2 * np.dot(vecs_a, vecs_b.T) + ) # Prepare matches similarity_matches = [] distance_matches = [] # Populate similarity matches - sim_indices = np.argwhere(cosine_similarity_matrix >= (similarity_threshold if similarity_threshold is not None else -np.inf)) + sim_indices = np.argwhere( + cosine_similarity_matrix + >= (similarity_threshold if similarity_threshold is not None else -np.inf) + ) for i, j in sim_indices: similarity_matches.append((i, j, cosine_similarity_matrix[i, j])) # Populate distance matches - dist_indices = np.argwhere(euclidean_distance_matrix <= (distance_threshold if distance_threshold is not None else np.inf)) + dist_indices = np.argwhere( + euclidean_distance_matrix + <= (distance_threshold if distance_threshold is not None else np.inf) + ) for i, j in dist_indices: distance_matches.append((i, j, euclidean_distance_matrix[i, j])) @@ -675,9 +699,10 @@ class MemoryAgent(Agent): "cosine_similarity_matrix": cosine_similarity_matrix, "euclidean_distance_matrix": euclidean_distance_matrix, "similarity_matches": similarity_matches, - "distance_matches": distance_matches + "distance_matches": distance_matches, } - + + @register(condition=lambda: chromadb is not None) class ChromaDBMemoryAgent(MemoryAgent): requires_llm_client = False @@ -690,32 +715,34 @@ class ChromaDBMemoryAgent(MemoryAgent): if getattr(self, "db_client", None): return True return False - + @property def client_api_ready(self) -> bool: if self.using_client_api_embeddings: - embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client) + embeddings_client: ClientBase | None = instance.get_client( + self.embeddings_client + ) if not embeddings_client: return False - + if not embeddings_client.supports_embeddings: return False - + if not embeddings_client.embeddings_status: return False - + if embeddings_client.current_status not in ["idle", "busy"]: return False - + return True - + return False @property def status(self): if self.using_client_api_embeddings and not self.client_api_ready: return "error" - + if self.ready: return "active" if not getattr(self, "processing", False) else "busy" @@ -726,7 +753,6 @@ class ChromaDBMemoryAgent(MemoryAgent): @property def agent_details(self): - details = { "backend": AgentDetail( icon="mdi-server-outline", @@ -738,23 +764,22 @@ class ChromaDBMemoryAgent(MemoryAgent): value=self.embeddings, description="The embeddings type.", ).model_dump(), - } - + if self.model: details["model"] = AgentDetail( icon="mdi-brain", value=self.model, description="The embeddings model.", ).model_dump() - + if self.embeddings_client: details["client"] = AgentDetail( icon="mdi-network-outline", value=self.embeddings_client, description="The client to use for embeddings.", ).model_dump() - + if self.using_local_embeddings: details["device"] = AgentDetail( icon="mdi-memory", @@ -770,9 +795,11 @@ class ChromaDBMemoryAgent(MemoryAgent): "description": "You must provide an OpenAI API key to use OpenAI embeddings", "color": "error", } - + if self.using_client_api_embeddings: - embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client) + embeddings_client: ClientBase | None = instance.get_client( + self.embeddings_client + ) if not embeddings_client: details["error"] = { @@ -784,7 +811,7 @@ class ChromaDBMemoryAgent(MemoryAgent): return details client_name = embeddings_client.name - + if not embeddings_client.supports_embeddings: error_message = f"{client_name} does not support embeddings" elif embeddings_client.current_status not in ["idle", "busy"]: @@ -793,7 +820,7 @@ class ChromaDBMemoryAgent(MemoryAgent): error_message = f"{client_name} has no embeddings model loaded" else: error_message = None - + if error_message: details["error"] = { "icon": "mdi-alert", @@ -814,8 +841,14 @@ class ChromaDBMemoryAgent(MemoryAgent): @property def embedding_function(self) -> Callable: - if not self.db or not hasattr(self.db, "_embedding_function") or self.db._embedding_function is None: - raise RuntimeError("Embedding function is not initialized. Make sure the database is set.") + if ( + not self.db + or not hasattr(self.db, "_embedding_function") + or self.db._embedding_function is None + ): + raise RuntimeError( + "Embedding function is not initialized. Make sure the database is set." + ) embed_fn = self.db._embedding_function return embed_fn @@ -823,18 +856,18 @@ class ChromaDBMemoryAgent(MemoryAgent): def make_collection_name(self, scene) -> str: # generate plain text collection name collection_name = f"{self.fingerprint}" - + # chromadb collection names have the following rules: # Expected collection name that (1) contains 3-63 characters, (2) starts and ends with an alphanumeric character, (3) otherwise contains only alphanumeric characters, underscores or hyphens (-), (4) contains no two consecutive periods (..) and (5) is not a valid IPv4 address # Step 1: Hash the input string using MD5 md5_hash = hashlib.md5(collection_name.encode()).hexdigest() - + # Step 2: Ensure the result is exactly 32 characters long hashed_collection_name = md5_hash[:32] - + return f"{scene.memory_id}-tm-{hashed_collection_name}" - + async def count(self): await asyncio.sleep(0) return self.db.count() @@ -853,14 +886,17 @@ class ChromaDBMemoryAgent(MemoryAgent): self.collection_name = collection_name = self.make_collection_name(self.scene) log.info( - "chromadb agent", status="setting up db", collection_name=collection_name, embeddings=self.embeddings + "chromadb agent", + status="setting up db", + collection_name=collection_name, + embeddings=self.embeddings, ) - + distance_function = self.distance_function collection_metadata = {"hnsw:space": distance_function} - device = self.actions["_config"].config["device"].value + device = self.actions["_config"].config["device"].value model_name = self.model - + if self.using_openai_embeddings: if not openai_key: raise ValueError( @@ -878,7 +914,9 @@ class ChromaDBMemoryAgent(MemoryAgent): model_name=model_name, ) self.db = self.db_client.get_or_create_collection( - collection_name, embedding_function=openai_ef, metadata=collection_metadata + collection_name, + embedding_function=openai_ef, + metadata=collection_metadata, ) elif self.using_client_api_embeddings: log.info( @@ -886,20 +924,26 @@ class ChromaDBMemoryAgent(MemoryAgent): embeddings="Client API", client=self.embeddings_client, ) - - embeddings_client:ClientBase | None = instance.get_client(self.embeddings_client) + + embeddings_client: ClientBase | None = instance.get_client( + self.embeddings_client + ) if not embeddings_client: - raise ValueError(f"Client API embeddings client {self.embeddings_client} not found") - + raise ValueError( + f"Client API embeddings client {self.embeddings_client} not found" + ) + if not embeddings_client.supports_embeddings: - raise ValueError(f"Client API embeddings client {self.embeddings_client} does not support embeddings") - + raise ValueError( + f"Client API embeddings client {self.embeddings_client} does not support embeddings" + ) + ef = embeddings_client.embeddings_function - + self.db = self.db_client.get_or_create_collection( collection_name, embedding_function=ef, metadata=collection_metadata ) - + elif self.using_instructor_embeddings: log.info( "chromadb", @@ -909,7 +953,9 @@ class ChromaDBMemoryAgent(MemoryAgent): ) ef = embedding_functions.InstructorEmbeddingFunction( - model_name=model_name, device=device, instruction="Represent the document for retrieval:" + model_name=model_name, + device=device, + instruction="Represent the document for retrieval:", ) log.info("chromadb", status="embedding function ready") @@ -919,25 +965,26 @@ class ChromaDBMemoryAgent(MemoryAgent): log.info("chromadb", status="instructor db ready") else: log.info( - "chromadb", - embeddings="SentenceTransformer", + "chromadb", + embeddings="SentenceTransformer", model=model_name, device=device, - distance_function=distance_function + distance_function=distance_function, ) - + try: ef = embedding_functions.SentenceTransformerEmbeddingFunction( model_name=model_name, trust_remote_code=self.trust_remote_code, - device=device + device=device, ) except ValueError as e: if "`trust_remote_code=True` to remove this error" in str(e): - raise EmbeddingsModelLoadError(model_name, "Model requires `Trust remote code` to be enabled") + raise EmbeddingsModelLoadError( + model_name, "Model requires `Trust remote code` to be enabled" + ) raise EmbeddingsModelLoadError(model_name, str(e)) - - + self.db = self.db_client.get_or_create_collection( collection_name, embedding_function=ef, metadata=collection_metadata ) @@ -989,9 +1036,7 @@ class ChromaDBMemoryAgent(MemoryAgent): try: self.db_client.delete_collection(collection_name) except chromadb.errors.NotFoundError as exc: - log.error( - "chromadb agent", error="collection not found", details=exc - ) + log.error("chromadb agent", error="collection not found", details=exc) except ValueError as exc: log.error( "chromadb agent", error="failed to delete collection", details=exc @@ -1049,16 +1094,18 @@ class ChromaDBMemoryAgent(MemoryAgent): if not objects: return - + # track seen documents by id seen_ids = set() for obj in objects: if obj["id"] in seen_ids: - log.warning("chromadb agent", status="duplicate id discarded", id=obj["id"]) + log.warning( + "chromadb agent", status="duplicate id discarded", id=obj["id"] + ) continue seen_ids.add(obj["id"]) - + documents.append(obj["text"]) meta = obj.get("meta", {}) source = meta.get("source", "talemate") @@ -1071,7 +1118,7 @@ class ChromaDBMemoryAgent(MemoryAgent): metadatas.append(meta) uid = obj.get("id", f"{character}-{self.memory_tracker[character]}") ids.append(uid) - + self.db.upsert(documents=documents, metadatas=metadatas, ids=ids) def _delete(self, meta: dict): @@ -1081,11 +1128,11 @@ class ChromaDBMemoryAgent(MemoryAgent): return where = {"$and": [{k: v} for k, v in meta.items()]} - + # if there is only one item in $and reduce it to the key value pair if len(where["$and"]) == 1: where = where["$and"][0] - + self.db.delete(where=where) log.debug("chromadb agent delete", meta=meta, where=where) @@ -1122,16 +1169,16 @@ class ChromaDBMemoryAgent(MemoryAgent): log.error("chromadb agent", error="failed to query", details=e) return [] - #import json - #print(json.dumps(_results["ids"], indent=2)) - #print(json.dumps(_results["distances"], indent=2)) + # import json + # print(json.dumps(_results["ids"], indent=2)) + # print(json.dumps(_results["distances"], indent=2)) results = [] max_distance = self.max_distance - + closest = None - + active_memory_request = memory_request.get() for i in range(len(_results["distances"][0])): @@ -1139,7 +1186,7 @@ class ChromaDBMemoryAgent(MemoryAgent): doc = _results["documents"][0][i] meta = _results["metadatas"][0][i] - + active_memory_request.add_result(doc, distance, meta) if not meta: @@ -1150,7 +1197,7 @@ class ChromaDBMemoryAgent(MemoryAgent): # skip pin_only entries if meta.get("pin_only", False): continue - + if closest is None: closest = {"distance": distance, "doc": doc} elif distance < closest["distance"]: @@ -1172,13 +1219,13 @@ class ChromaDBMemoryAgent(MemoryAgent): if len(results) > limit: break - + log.debug("chromadb agent get", closest=closest, max_distance=max_distance) self.last_query = { "query": text, "closest": closest, } - + return results def convert_ts_to_date_prefix(self, ts): @@ -1197,10 +1244,9 @@ class ChromaDBMemoryAgent(MemoryAgent): return None def _get_document(self, id) -> dict: - if not id: return {} - + result = self.db.get(ids=[id] if isinstance(id, str) else id) documents = {} diff --git a/src/talemate/agents/memory/context.py b/src/talemate/agents/memory/context.py index b0404f0c..d1aa9b67 100644 --- a/src/talemate/agents/memory/context.py +++ b/src/talemate/agents/memory/context.py @@ -1,5 +1,5 @@ """ -Context manager that collects and tracks memory agent requests +Context manager that collects and tracks memory agent requests for profiling and debugging purposes """ @@ -11,91 +11,118 @@ import time from talemate.emit import emit from talemate.agents.context import active_agent -__all__ = [ - "MemoryRequest", - "start_memory_request" - "MemoryRequestState" - "memory_request" -] +__all__ = ["MemoryRequest"] log = structlog.get_logger() DEBUG_MEMORY_REQUESTS = False + class MemoryRequestResult(pydantic.BaseModel): doc: str distance: float meta: dict = pydantic.Field(default_factory=dict) - + + class MemoryRequestState(pydantic.BaseModel): - query:str + query: str results: list[MemoryRequestResult] = pydantic.Field(default_factory=list) accepted_results: list[MemoryRequestResult] = pydantic.Field(default_factory=list) query_params: dict = pydantic.Field(default_factory=dict) closest_distance: float | None = None furthest_distance: float | None = None max_distance: float | None = None - - def add_result(self, doc:str, distance:float, meta:dict): - + + def add_result(self, doc: str, distance: float, meta: dict): if doc is None: return - + self.results.append(MemoryRequestResult(doc=doc, distance=distance, meta=meta)) - self.closest_distance = min(self.closest_distance, distance) if self.closest_distance is not None else distance - self.furthest_distance = max(self.furthest_distance, distance) if self.furthest_distance is not None else distance - - def accept_result(self, doc:str, distance:float, meta:dict): - + self.closest_distance = ( + min(self.closest_distance, distance) + if self.closest_distance is not None + else distance + ) + self.furthest_distance = ( + max(self.furthest_distance, distance) + if self.furthest_distance is not None + else distance + ) + + def accept_result(self, doc: str, distance: float, meta: dict): if doc is None: return - - self.accepted_results.append(MemoryRequestResult(doc=doc, distance=distance, meta=meta)) - + + self.accepted_results.append( + MemoryRequestResult(doc=doc, distance=distance, meta=meta) + ) + @property def closest_text(self): return str(self.results[0].doc) if self.results else None - + + memory_request = contextvars.ContextVar("memory_request", default=None) class MemoryRequest: - - def __init__(self, query:str, query_params:dict=None): + def __init__(self, query: str, query_params: dict = None): self.query = query self.query_params = query_params - + def __enter__(self): - self.state = MemoryRequestState(query=self.query, query_params=self.query_params) + self.state = MemoryRequestState( + query=self.query, query_params=self.query_params + ) self.token = memory_request.set(self.state) self.time_start = time.time() return self.state - + def __exit__(self, *args): - self.time_end = time.time() - + if DEBUG_MEMORY_REQUESTS: max_length = 50 - query = self.state.query[:max_length]+"..." if len(self.state.query) > max_length else self.state.query - log.debug("MemoryRequest", number_of_results=len(self.state.results), query=query) - log.debug("MemoryRequest", number_of_accepted_results=len(self.state.accepted_results), query=query) - + query = ( + self.state.query[:max_length] + "..." + if len(self.state.query) > max_length + else self.state.query + ) + log.debug( + "MemoryRequest", number_of_results=len(self.state.results), query=query + ) + log.debug( + "MemoryRequest", + number_of_accepted_results=len(self.state.accepted_results), + query=query, + ) + for result in self.state.results: # distance to 2 decimal places - log.debug("MemoryRequest RESULT", distance=f"{result.distance:.2f}", doc=result.doc[:max_length]+"...") - + log.debug( + "MemoryRequest RESULT", + distance=f"{result.distance:.2f}", + doc=result.doc[:max_length] + "...", + ) + agent_context = active_agent.get() - - emit("memory_request", data=self.state.model_dump(), meta={ - "agent_stack": agent_context.agent_stack if agent_context else [], - "agent_stack_uid": agent_context.agent_stack_uid if agent_context else None, - "duration": self.time_end - self.time_start, - }, websocket_passthrough=True) - + + emit( + "memory_request", + data=self.state.model_dump(), + meta={ + "agent_stack": agent_context.agent_stack if agent_context else [], + "agent_stack_uid": agent_context.agent_stack_uid + if agent_context + else None, + "duration": self.time_end - self.time_start, + }, + websocket_passthrough=True, + ) + memory_request.reset(self.token) return False - + # decorator that opens a memory request context async def start_memory_request(query): @@ -103,5 +130,7 @@ async def start_memory_request(query): async def wrapper(*args, **kwargs): with MemoryRequest(query): return await fn(*args, **kwargs) + return wrapper - return decorator \ No newline at end of file + + return decorator diff --git a/src/talemate/agents/memory/exceptions.py b/src/talemate/agents/memory/exceptions.py index ffe6fa17..7d27ce52 100644 --- a/src/talemate/agents/memory/exceptions.py +++ b/src/talemate/agents/memory/exceptions.py @@ -1,18 +1,17 @@ -__all__ = [ - 'EmbeddingsModelLoadError', - 'MemoryAgentError', - 'SetDBError' -] +__all__ = ["EmbeddingsModelLoadError", "MemoryAgentError", "SetDBError"] + class MemoryAgentError(Exception): pass + class SetDBError(OSError, MemoryAgentError): - - def __init__(self, details:str): + def __init__(self, details: str): super().__init__(f"Memory Agent - Failed to set up the database: {details}") + class EmbeddingsModelLoadError(ValueError, MemoryAgentError): - - def __init__(self, model_name:str, details:str): - super().__init__(f"Memory Agent - Failed to load embeddings model {model_name}: {details}") \ No newline at end of file + def __init__(self, model_name: str, details: str): + super().__init__( + f"Memory Agent - Failed to load embeddings model {model_name}: {details}" + ) diff --git a/src/talemate/agents/memory/rag.py b/src/talemate/agents/memory/rag.py index 590ca791..200a161d 100644 --- a/src/talemate/agents/memory/rag.py +++ b/src/talemate/agents/memory/rag.py @@ -4,7 +4,6 @@ from talemate.agents.base import ( AgentAction, AgentActionConfig, ) -from talemate.emit import emit import talemate.instance as instance if TYPE_CHECKING: @@ -14,11 +13,10 @@ __all__ = ["MemoryRAGMixin"] log = structlog.get_logger() + class MemoryRAGMixin: - @classmethod def add_actions(cls, actions: dict[str, AgentAction]): - actions["use_long_term_memory"] = AgentAction( enabled=True, container=True, @@ -44,7 +42,7 @@ class MemoryRAGMixin: { "label": "AI compiled question and answers (slow)", "value": "questions", - } + }, ], ), "number_of_queries": AgentActionConfig( @@ -65,7 +63,7 @@ class MemoryRAGMixin: {"label": "Short (256)", "value": "256"}, {"label": "Medium (512)", "value": "512"}, {"label": "Long (1024)", "value": "1024"}, - ] + ], ), "cache": AgentActionConfig( type="bool", @@ -73,16 +71,16 @@ class MemoryRAGMixin: description="Cache the long term memory for faster retrieval.", note="This is a cross-agent cache, assuming they use the same options.", value=True, - ) + ), }, ) - + # config property helpers - + @property def long_term_memory_enabled(self): return self.actions["use_long_term_memory"].enabled - + @property def long_term_memory_retrieval_method(self): return self.actions["use_long_term_memory"].config["retrieval_method"].value @@ -90,60 +88,60 @@ class MemoryRAGMixin: @property def long_term_memory_number_of_queries(self): return self.actions["use_long_term_memory"].config["number_of_queries"].value - + @property def long_term_memory_answer_length(self): return int(self.actions["use_long_term_memory"].config["answer_length"].value) - + @property def long_term_memory_cache(self): return self.actions["use_long_term_memory"].config["cache"].value - + @property def long_term_memory_cache_key(self): """ Build the key from the various options """ - + parts = [ self.long_term_memory_retrieval_method, self.long_term_memory_number_of_queries, - self.long_term_memory_answer_length + self.long_term_memory_answer_length, ] - + return "-".join(map(str, parts)) - - + def connect(self, scene): super().connect(scene) - + # new scene, reset cache scene.rag_cache = {} - + # methods - - async def rag_set_cache(self, content:list[str]): + + async def rag_set_cache(self, content: list[str]): self.scene.rag_cache[self.long_term_memory_cache_key] = { "content": content, - "fingerprint": self.scene.history[-1].fingerprint if self.scene.history else 0 + "fingerprint": self.scene.history[-1].fingerprint + if self.scene.history + else 0, } - + async def rag_get_cache(self) -> list[str] | None: - if not self.long_term_memory_cache: return None - + fingerprint = self.scene.history[-1].fingerprint if self.scene.history else 0 cache = self.scene.rag_cache.get(self.long_term_memory_cache_key) - + if cache and cache["fingerprint"] == fingerprint: return cache["content"] - + return None - + async def rag_build( - self, - character: "Character | None" = None, + self, + character: "Character | None" = None, prompt: str = "", sub_instruction: str = "", ) -> list[str]: @@ -153,37 +151,41 @@ class MemoryRAGMixin: if not self.long_term_memory_enabled: return [] - + cached = await self.rag_get_cache() - + if cached: - log.debug(f"Using cached long term memory", agent=self.agent_type, key=self.long_term_memory_cache_key) + log.debug( + "Using cached long term memory", + agent=self.agent_type, + key=self.long_term_memory_cache_key, + ) return cached memory_context = "" retrieval_method = self.long_term_memory_retrieval_method - + if not sub_instruction: if character: sub_instruction = f"continue the scene as {character.name}" elif hasattr(self, "rag_build_sub_instruction"): sub_instruction = await self.rag_build_sub_instruction() - + if not sub_instruction: sub_instruction = "continue the scene" - + if retrieval_method != "direct": world_state = instance.get_agent("world_state") - + if not prompt: prompt = self.scene.context_history( keep_director=False, budget=int(self.client.max_token_length * 0.75), ) - + if isinstance(prompt, list): prompt = "\n".join(prompt) - + log.debug( "memory_rag_mixin.build_prompt_default_memory", direct=False, @@ -193,20 +195,21 @@ class MemoryRAGMixin: if retrieval_method == "questions": memory_context = ( await world_state.analyze_text_and_extract_context( - prompt, sub_instruction, + prompt, + sub_instruction, include_character_context=True, response_length=self.long_term_memory_answer_length, - num_queries=self.long_term_memory_number_of_queries + num_queries=self.long_term_memory_number_of_queries, ) ).split("\n") elif retrieval_method == "queries": memory_context = ( await world_state.analyze_text_and_extract_context_via_queries( - prompt, sub_instruction, + prompt, + sub_instruction, include_character_context=True, response_length=self.long_term_memory_answer_length, - num_queries=self.long_term_memory_number_of_queries - + num_queries=self.long_term_memory_number_of_queries, ) ) @@ -223,4 +226,4 @@ class MemoryRAGMixin: await self.rag_set_cache(memory_context) - return memory_context \ No newline at end of file + return memory_context diff --git a/src/talemate/agents/narrator/__init__.py b/src/talemate/agents/narrator/__init__.py index 3a482c16..f8bcbc52 100644 --- a/src/talemate/agents/narrator/__init__.py +++ b/src/talemate/agents/narrator/__init__.py @@ -3,7 +3,6 @@ from __future__ import annotations import dataclasses import random from functools import wraps -from inspect import signature from typing import TYPE_CHECKING import structlog @@ -50,15 +49,18 @@ log = structlog.get_logger("talemate.agents.narrator") class NarratorAgentEmission(AgentEmission): generation: list[str] = dataclasses.field(default_factory=list) response: str = dataclasses.field(default="") - dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list) + dynamic_instructions: list[DynamicInstruction] = dataclasses.field( + default_factory=list + ) talemate.emit.async_signals.register( - "agent.narrator.before_generate", + "agent.narrator.before_generate", "agent.narrator.inject_instructions", "agent.narrator.generated", ) + def set_processing(fn): """ Custom decorator that emits the agent status as processing while the function @@ -74,11 +76,15 @@ def set_processing(fn): if self.content_use_writing_style: self.set_context_states(writing_style=self.scene.writing_style) - await talemate.emit.async_signals.get("agent.narrator.before_generate").send(emission) - await talemate.emit.async_signals.get("agent.narrator.inject_instructions").send(emission) - + await talemate.emit.async_signals.get("agent.narrator.before_generate").send( + emission + ) + await talemate.emit.async_signals.get( + "agent.narrator.inject_instructions" + ).send(emission) + agent_context.state["dynamic_instructions"] = emission.dynamic_instructions - + response = await fn(self, *args, **kwargs) emission.response = response await talemate.emit.async_signals.get("agent.narrator.generated").send(emission) @@ -88,10 +94,7 @@ def set_processing(fn): @register() -class NarratorAgent( - MemoryRAGMixin, - Agent -): +class NarratorAgent(MemoryRAGMixin, Agent): """ Handles narration of the story """ @@ -99,7 +102,7 @@ class NarratorAgent( agent_type = "narrator" verbose_name = "Narrator" set_processing = set_processing - + websocket_handler = NarratorWebsocketHandler @classmethod @@ -117,7 +120,7 @@ class NarratorAgent( min=32, max=1024, step=32, - ), + ), "instructions": AgentActionConfig( type="text", label="Instructions", @@ -154,7 +157,7 @@ class NarratorAgent( description="Use the writing style selected in the scene settings", value=True, ), - } + }, ), "narrate_time_passage": AgentAction( enabled=True, @@ -201,7 +204,7 @@ class NarratorAgent( }, ), } - + MemoryRAGMixin.add_actions(actions) return actions @@ -232,28 +235,33 @@ class NarratorAgent( if self.actions["generation_override"].enabled: return self.actions["generation_override"].config["length"].value return 128 - + @property def narrate_time_passage_enabled(self) -> bool: return self.actions["narrate_time_passage"].enabled - + @property def narrate_dialogue_enabled(self) -> bool: return self.actions["narrate_dialogue"].enabled @property - def narrate_dialogue_ai_chance(self) -> float: + def narrate_dialogue_ai_chance(self) -> float: return self.actions["narrate_dialogue"].config["ai_dialog"].value - + @property def narrate_dialogue_player_chance(self) -> float: return self.actions["narrate_dialogue"].config["player_dialog"].value - + @property def content_use_writing_style(self) -> bool: return self.actions["content"].config["use_writing_style"].value - - def clean_result(self, result:str, ensure_dialog_format:bool=True, force_narrative:bool=True) -> str: + + def clean_result( + self, + result: str, + ensure_dialog_format: bool = True, + force_narrative: bool = True, + ) -> str: """ Cleans the result of a narration """ @@ -264,15 +272,14 @@ class NarratorAgent( cleaned = [] for line in result.split("\n"): - # skip lines that start with a # if line.startswith("#"): continue - + log.debug("clean_result", line=line) - + character_dialogue_detected = False - + for character_name in character_names: if not character_name: continue @@ -280,25 +287,24 @@ class NarratorAgent( character_dialogue_detected = True elif line.startswith(f"{character_name.upper()}"): character_dialogue_detected = True - + if character_dialogue_detected: break - + if character_dialogue_detected: break - + cleaned.append(line) result = "\n".join(cleaned) - + result = util.strip_partial_sentences(result) editor = get_agent("editor") - + if ensure_dialog_format or force_narrative: if editor.fix_exposition_enabled and editor.fix_exposition_narrator: result = editor.fix_exposition_in_text(result) - - + return result def connect(self, scene): @@ -324,16 +330,16 @@ class NarratorAgent( event.duration, event.human_duration, event.narrative ) narrator_message = NarratorMessage( - response, - meta = { + response, + meta={ "agent": "narrator", "function": "narrate_time_passage", "arguments": { "duration": event.duration, "time_passed": event.human_duration, "narrative_direction": event.narrative, - } - } + }, + }, ) emit("narrator", narrator_message) self.scene.push_history(narrator_message) @@ -345,7 +351,7 @@ class NarratorAgent( if not self.narrate_dialogue_enabled: return - + if event.game_loop.had_passive_narration: log.debug( "narrate on dialog", @@ -375,14 +381,14 @@ class NarratorAgent( response = await self.narrate_after_dialogue(event.actor.character) narrator_message = NarratorMessage( - response, + response, meta={ "agent": "narrator", "function": "narrate_after_dialogue", "arguments": { "character": event.actor.character.name, - } - } + }, + }, ) emit("narrator", narrator_message) self.scene.push_history(narrator_message) @@ -390,7 +396,7 @@ class NarratorAgent( event.game_loop.had_passive_narration = True @set_processing - @store_context_state('narrative_direction', visual_narration=True) + @store_context_state("narrative_direction", visual_narration=True) async def narrate_scene(self, narrative_direction: str | None = None): """ Narrate the scene @@ -413,13 +419,13 @@ class NarratorAgent( return response @set_processing - @store_context_state('narrative_direction') + @store_context_state("narrative_direction") async def progress_story(self, narrative_direction: str | None = None): """ Narrate scene progression, moving the plot forward. - + Arguments: - + - narrative_direction: A string describing the direction the narrative should take. If not provided, will attempt to subtly move the story forward. """ @@ -431,10 +437,8 @@ class NarratorAgent( if narrative_direction is None: narrative_direction = "Slightly move the current scene forward." - log.debug( - "narrative_direction", narrative_direction=narrative_direction - ) - + log.debug("narrative_direction", narrative_direction=narrative_direction) + response = await Prompt.request( "narrator.narrate-progress", self.client, @@ -453,13 +457,17 @@ class NarratorAgent( log.debug("progress_story", response=response) response = self.clean_result(response.strip()) - + return response @set_processing - @store_context_state('query', query_narration=True) + @store_context_state("query", query_narration=True) async def narrate_query( - self, query: str, at_the_end: bool = False, as_narrative: bool = True, extra_context: str = None + self, + query: str, + at_the_end: bool = False, + as_narrative: bool = True, + extra_context: str = None, ): """ Narrate a specific query @@ -479,20 +487,20 @@ class NarratorAgent( }, ) response = self.clean_result( - response.strip(), - ensure_dialog_format=False, - force_narrative=as_narrative + response.strip(), ensure_dialog_format=False, force_narrative=as_narrative ) return response - + @set_processing - @store_context_state('character', 'narrative_direction', visual_narration=True) - async def narrate_character(self, character:"Character", narrative_direction: str = None): + @store_context_state("character", "narrative_direction", visual_narration=True) + async def narrate_character( + self, character: "Character", narrative_direction: str = None + ): """ Narrate a specific character """ - + response = await Prompt.request( "narrator.narrate-character", self.client, @@ -506,12 +514,14 @@ class NarratorAgent( }, ) - response = self.clean_result(response.strip(), ensure_dialog_format=False, force_narrative=True) + response = self.clean_result( + response.strip(), ensure_dialog_format=False, force_narrative=True + ) return response @set_processing - @store_context_state('narrative_direction', time_narration=True) + @store_context_state("narrative_direction", time_narration=True) async def narrate_time_passage( self, duration: str, time_passed: str, narrative_direction: str ): @@ -528,7 +538,7 @@ class NarratorAgent( "max_tokens": self.client.max_token_length, "duration": duration, "time_passed": time_passed, - "narrative": narrative_direction, # backwards compatibility + "narrative": narrative_direction, # backwards compatibility "narrative_direction": narrative_direction, "extra_instructions": self.extra_instructions, }, @@ -541,7 +551,7 @@ class NarratorAgent( return response @set_processing - @store_context_state('narrative_direction', sensory_narration=True) + @store_context_state("narrative_direction", sensory_narration=True) async def narrate_after_dialogue( self, character: Character, @@ -572,16 +582,16 @@ class NarratorAgent( async def narrate_environment(self, narrative_direction: str = None): """ Narrate the environment - + Wraps narrate_after_dialogue with the player character as the perspective character """ - + pc = self.scene.get_player_character() return await self.narrate_after_dialogue(pc, narrative_direction) @set_processing - @store_context_state('narrative_direction', 'character') + @store_context_state("narrative_direction", "character") async def narrate_character_entry( self, character: Character, narrative_direction: str = None ): @@ -607,11 +617,9 @@ class NarratorAgent( return response @set_processing - @store_context_state('narrative_direction', 'character') + @store_context_state("narrative_direction", "character") async def narrate_character_exit( - self, - character: Character, - narrative_direction: str = None + self, character: Character, narrative_direction: str = None ): """ Narrate a character exiting the scene @@ -679,8 +687,10 @@ class NarratorAgent( later on """ args = parameters.copy() - - if args.get("character") and isinstance(args["character"], self.scene.Character): + + if args.get("character") and isinstance( + args["character"], self.scene.Character + ): args["character"] = args["character"].name return { @@ -701,7 +711,9 @@ class NarratorAgent( fn = getattr(self, action_name) narration = await fn(**kwargs) - narrator_message = NarratorMessage(narration, meta=self.action_to_meta(action_name, kwargs)) + narrator_message = NarratorMessage( + narration, meta=self.action_to_meta(action_name, kwargs) + ) self.scene.push_history(narrator_message) if emit_message: @@ -720,20 +732,22 @@ class NarratorAgent( kind=kind, agent_function_name=agent_function_name, ) - + # depending on conversation format in the context, stopping strings # for character names may change format conversation_agent = get_agent("conversation") - + if conversation_agent.conversation_format == "movie_script": - character_names = [f"\n{c.name.upper()}\n" for c in self.scene.get_characters()] - else: + character_names = [ + f"\n{c.name.upper()}\n" for c in self.scene.get_characters() + ] + else: character_names = [f"\n{c.name}:" for c in self.scene.get_characters()] - + if prompt_param.get("extra_stopping_strings") is None: prompt_param["extra_stopping_strings"] = [] prompt_param["extra_stopping_strings"] += character_names - + self.set_generation_overrides(prompt_param) def allow_repetition_break( @@ -748,7 +762,9 @@ class NarratorAgent( if not self.actions["generation_override"].enabled: return - prompt_param["max_tokens"] = min(prompt_param.get("max_tokens", 256), self.max_generation_length) + prompt_param["max_tokens"] = min( + prompt_param.get("max_tokens", 256), self.max_generation_length + ) if self.jiggle > 0.0: nuke_repetition = client_context_attribute("nuke_repetition") @@ -756,4 +772,4 @@ class NarratorAgent( # we only apply the agent override if some other mechanism isn't already # setting the nuke_repetition value nuke_repetition = self.jiggle - set_client_context_attribute("nuke_repetition", nuke_repetition) \ No newline at end of file + set_client_context_attribute("nuke_repetition", nuke_repetition) diff --git a/src/talemate/agents/narrator/nodes.py b/src/talemate/agents/narrator/nodes.py index c4306813..36d954c3 100644 --- a/src/talemate/agents/narrator/nodes.py +++ b/src/talemate/agents/narrator/nodes.py @@ -8,15 +8,16 @@ from talemate.util import iso8601_duration_to_human log = structlog.get_logger("talemate.game.engine.nodes.agents.narrator") + class GenerateNarrationBase(AgentNode): """ Generate a narration message """ - - _agent_name:ClassVar[str] = "narrator" - _action_name:ClassVar[str] = "" - _title:ClassVar[str] = "Generate Narration" - + + _agent_name: ClassVar[str] = "narrator" + _action_name: ClassVar[str] = "" + _title: ClassVar[str] = "Generate Narration" + class Fields: narrative_direction = PropertyField( name="narrative_direction", @@ -24,151 +25,170 @@ class GenerateNarrationBase(AgentNode): default="", type="str", ) - + def __init__(self, **kwargs): - if "title" not in kwargs: kwargs["title"] = self._title - + super().__init__(**kwargs) - + def setup(self): self.add_input("state") self.add_input("narrative_direction", socket_type="str", optional=True) - + self.add_output("generated", socket_type="str") self.add_output("message", socket_type="message_object") - + async def prepare_input_values(self) -> dict: input_values = self.get_input_values() input_values.pop("state", None) return input_values - + async def run(self, state: GraphState): input_values = await self.prepare_input_values() try: agent_fn = getattr(self.agent, self._action_name) except AttributeError: - raise InputValueError(self, "_action_name", f"Agent does not have a function named {self._action_name}") - + raise InputValueError( + self, + "_action_name", + f"Agent does not have a function named {self._action_name}", + ) + narration = await agent_fn(**input_values) - + message = NarratorMessage( message=narration, meta=self.agent.action_to_meta(self._action_name, input_values), ) - - self.set_output_values({ - "generated": narration, - "message": message - }) - - + + self.set_output_values({"generated": narration, "message": message}) + + @register("agents/narrator/GenerateProgress") class GenerateProgressNarration(GenerateNarrationBase): """ Generate a progress narration message - """ - _action_name:ClassVar[str] = "progress_story" - _title:ClassVar[str] = "Generate Progress Narration" + """ + + _action_name: ClassVar[str] = "progress_story" + _title: ClassVar[str] = "Generate Progress Narration" @register("agents/narrator/GenerateSceneNarration") class GenerateSceneNarration(GenerateNarrationBase): """ Generate a scene narration message - """ - _action_name:ClassVar[str] = "narrate_scene" - _title:ClassVar[str] = "Generate Scene Narration" + """ -@register("agents/narrator/GenerateAfterDialogNarration") + _action_name: ClassVar[str] = "narrate_scene" + _title: ClassVar[str] = "Generate Scene Narration" + + +@register("agents/narrator/GenerateAfterDialogNarration") class GenerateAfterDialogNarration(GenerateNarrationBase): """ Generate an after dialog narration message - """ - _action_name:ClassVar[str] = "narrate_after_dialogue" - _title:ClassVar[str] = "Generate After Dialog Narration" - + """ + + _action_name: ClassVar[str] = "narrate_after_dialogue" + _title: ClassVar[str] = "Generate After Dialog Narration" + def setup(self): super().setup() self.add_input("character", socket_type="character") + @register("agents/narrator/GenerateEnvironmentNarration") class GenerateEnvironmentNarration(GenerateNarrationBase): """ Generate an environment narration message - """ - _action_name:ClassVar[str] = "narrate_environment" - _title:ClassVar[str] = "Generate Environment Narration" - + """ + + _action_name: ClassVar[str] = "narrate_environment" + _title: ClassVar[str] = "Generate Environment Narration" + + @register("agents/narrator/GenerateQueryNarration") class GenerateQueryNarration(GenerateNarrationBase): """ Generate a query narration message - """ - _action_name:ClassVar[str] = "narrate_query" - _title:ClassVar[str] = "Generate Query Narration" - + """ + + _action_name: ClassVar[str] = "narrate_query" + _title: ClassVar[str] = "Generate Query Narration" + def setup(self): super().setup() self.add_input("query", socket_type="str") self.add_input("extra_context", socket_type="str", optional=True) self.remove_input("narrative_direction") - + + @register("agents/narrator/GenerateCharacterNarration") class GenerateCharacterNarration(GenerateNarrationBase): """ Generate a character narration message - """ - _action_name:ClassVar[str] = "narrate_character" - _title:ClassVar[str] = "Generate Character Narration" - + """ + + _action_name: ClassVar[str] = "narrate_character" + _title: ClassVar[str] = "Generate Character Narration" + def setup(self): super().setup() self.add_input("character", socket_type="character") - + + @register("agents/narrator/GenerateTimeNarration") class GenerateTimeNarration(GenerateNarrationBase): """ Generate a time narration message - """ - _action_name:ClassVar[str] = "narrate_time_passage" - _title:ClassVar[str] = "Generate Time Narration" - + """ + + _action_name: ClassVar[str] = "narrate_time_passage" + _title: ClassVar[str] = "Generate Time Narration" + def setup(self): super().setup() self.add_input("duration", socket_type="str") self.set_property("duration", "P0T1S") - + async def prepare_input_values(self) -> dict: input_values = await super().prepare_input_values() - input_values["time_passed"] = iso8601_duration_to_human(input_values["duration"]) + input_values["time_passed"] = iso8601_duration_to_human( + input_values["duration"] + ) return input_values - + + @register("agents/narrator/GenerateCharacterEntryNarration") class GenerateCharacterEntryNarration(GenerateNarrationBase): """ Generate a character entry narration message - """ - _action_name:ClassVar[str] = "narrate_character_entry" - _title:ClassVar[str] = "Generate Character Entry Narration" - + """ + + _action_name: ClassVar[str] = "narrate_character_entry" + _title: ClassVar[str] = "Generate Character Entry Narration" + def setup(self): super().setup() self.add_input("character", socket_type="character") - + + @register("agents/narrator/GenerateCharacterExitNarration") class GenerateCharacterExitNarration(GenerateNarrationBase): """ Generate a character exit narration message - """ - _action_name:ClassVar[str] = "narrate_character_exit" - _title:ClassVar[str] = "Generate Character Exit Narration" - + """ + + _action_name: ClassVar[str] = "narrate_character_exit" + _title: ClassVar[str] = "Generate Character Exit Narration" + def setup(self): super().setup() self.add_input("character", socket_type="character") - + + @register("agents/narrator/UnpackSource") class UnpackSource(AgentNode): """ @@ -176,25 +196,19 @@ class UnpackSource(AgentNode): into action name and arguments DEPRECATED """ - - _agent_name:ClassVar[str] = "narrator" - + + _agent_name: ClassVar[str] = "narrator" + def __init__(self, title="Unpack Source", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("source", socket_type="str") self.add_output("action_name", socket_type="str") self.add_output("arguments", socket_type="dict") - + async def run(self, state: GraphState): - source = self.get_input_value("source") action_name = "" arguments = {} - - self.set_output_values({ - "action_name": action_name, - "arguments": arguments - }) - - \ No newline at end of file + + self.set_output_values({"action_name": action_name, "arguments": arguments}) diff --git a/src/talemate/agents/narrator/websocket_handler.py b/src/talemate/agents/narrator/websocket_handler.py index 54f6b69e..a1753131 100644 --- a/src/talemate/agents/narrator/websocket_handler.py +++ b/src/talemate/agents/narrator/websocket_handler.py @@ -15,27 +15,31 @@ __all__ = [ log = structlog.get_logger("talemate.server.narrator") + class QueryPayload(pydantic.BaseModel): - query:str - at_the_end:bool=True - + query: str + at_the_end: bool = True + + class NarrativeDirectionPayload(pydantic.BaseModel): - narrative_direction:str = "" + narrative_direction: str = "" + class CharacterPayload(NarrativeDirectionPayload): - character:str = "" + character: str = "" + class NarratorWebsocketHandler(Plugin): """ Handles narrator actions """ - + router = "narrator" - + @property def narrator(self): return get_agent("narrator") - + @set_loading("Progressing the story", cancellable=True, as_async=True) async def handle_progress(self, data: dict): """ @@ -47,7 +51,7 @@ class NarratorWebsocketHandler(Plugin): narrative_direction=payload.narrative_direction, emit_message=True, ) - + @set_loading("Narrating the environment", cancellable=True, as_async=True) async def handle_narrate_environment(self, data: dict): """ @@ -59,8 +63,7 @@ class NarratorWebsocketHandler(Plugin): narrative_direction=payload.narrative_direction, emit_message=True, ) - - + @set_loading("Working on a query", cancellable=True, as_async=True) async def handle_query(self, data: dict): """ @@ -68,56 +71,55 @@ class NarratorWebsocketHandler(Plugin): message. """ payload = QueryPayload(**data) - + narration = await self.narrator.narrate_query(**payload.model_dump()) message: ContextInvestigationMessage = ContextInvestigationMessage( - narration, sub_type="query" + narration, sub_type="query" ) message.set_source("narrator", "narrate_query", **payload.model_dump()) - - + emit("context_investigation", message=message) self.scene.push_history(message) - + @set_loading("Looking at the scene", cancellable=True, as_async=True) async def handle_look_at_scene(self, data: dict): """ Look at the scene (optionally to a specific direction) - + This will result in a context investigation message. """ payload = NarrativeDirectionPayload(**data) - - narration = await self.narrator.narrate_scene(narrative_direction=payload.narrative_direction) - + + narration = await self.narrator.narrate_scene( + narrative_direction=payload.narrative_direction + ) + message: ContextInvestigationMessage = ContextInvestigationMessage( narration, sub_type="visual-scene" ) message.set_source("narrator", "narrate_scene", **payload.model_dump()) - + emit("context_investigation", message=message) self.scene.push_history(message) - + @set_loading("Looking at a character", cancellable=True, as_async=True) async def handle_look_at_character(self, data: dict): """ Look at a character (optionally to a specific direction) - + This will result in a context investigation message. """ payload = CharacterPayload(**data) - - + narration = await self.narrator.narrate_character( - character = self.scene.get_character(payload.character), + character=self.scene.get_character(payload.character), narrative_direction=payload.narrative_direction, ) - + message: ContextInvestigationMessage = ContextInvestigationMessage( narration, sub_type="visual-character" ) message.set_source("narrator", "narrate_character", **payload.model_dump()) - + emit("context_investigation", message=message) self.scene.push_history(message) - \ No newline at end of file diff --git a/src/talemate/agents/registry.py b/src/talemate/agents/registry.py index b195a0b6..a2e56ec8 100644 --- a/src/talemate/agents/registry.py +++ b/src/talemate/agents/registry.py @@ -24,4 +24,4 @@ def get_agent_class(name): def get_agent_types() -> list[str]: - return list(AGENT_CLASSES.keys()) \ No newline at end of file + return list(AGENT_CLASSES.keys()) diff --git a/src/talemate/agents/summarize/__init__.py b/src/talemate/agents/summarize/__init__.py index 68851e0c..f111e475 100644 --- a/src/talemate/agents/summarize/__init__.py +++ b/src/talemate/agents/summarize/__init__.py @@ -7,26 +7,21 @@ import structlog from typing import TYPE_CHECKING, Literal import talemate.emit.async_signals import talemate.util as util -from talemate.emit import emit from talemate.events import GameLoopEvent from talemate.prompts import Prompt from talemate.scene_message import ( - DirectorMessage, - TimePassageMessage, - ContextInvestigationMessage, + DirectorMessage, + TimePassageMessage, + ContextInvestigationMessage, ReinforcementMessage, ) from talemate.world_state.templates import GenerationOptions -from talemate.instance import get_agent -from talemate.exceptions import GenerationCancelled -import talemate.game.focal as focal -import talemate.emit.async_signals from talemate.agents.base import ( - Agent, - AgentAction, - AgentActionConfig, - set_processing, + Agent, + AgentAction, + AgentActionConfig, + set_processing, AgentEmission, AgentTemplateEmission, RagBuildSubInstructionEmission, @@ -53,10 +48,12 @@ talemate.emit.async_signals.register( "agent.summarization.summarize.after", ) + @dataclasses.dataclass class BuildArchiveEmission(AgentEmission): generation_options: GenerationOptions | None = None - + + @dataclasses.dataclass class SummarizeEmission(AgentTemplateEmission): text: str = "" @@ -66,6 +63,7 @@ class SummarizeEmission(AgentTemplateEmission): summarization_history: list[str] | None = None summarization_type: Literal["dialogue", "events"] = "dialogue" + @register() class SummarizeAgent( MemoryRAGMixin, @@ -73,7 +71,7 @@ class SummarizeAgent( ContextInvestigationMixin, # Needs to be after ContextInvestigationMixin so signals are connected in the right order SceneAnalyzationMixin, - Agent + Agent, ): """ An agent that can be used to summarize text @@ -82,7 +80,7 @@ class SummarizeAgent( agent_type = "summarizer" verbose_name = "Summarizer" auto_squish = False - + @classmethod def init_actions(cls) -> dict[str, AgentAction]: actions = { @@ -148,11 +146,11 @@ class SummarizeAgent( @property def archive_threshold(self): return self.actions["archive"].config["threshold"].value - + @property def archive_method(self): return self.actions["archive"].config["method"].value - + @property def archive_include_previous(self): return self.actions["archive"].config["include_previous"].value @@ -179,7 +177,7 @@ class SummarizeAgent( return result # RAG HELPERS - + async def rag_build_sub_instruction(self): # Fire event to get the sub instruction from mixins emission = RagBuildSubInstructionEmission( @@ -188,37 +186,42 @@ class SummarizeAgent( await talemate.emit.async_signals.get( "agent.summarization.rag_build_sub_instruction" ).send(emission) - + return emission.sub_instruction - # SUMMARIZATION HELPERS - + async def previous_summaries(self, entry: ArchiveEntry) -> list[str]: - num_previous = self.archive_include_previous - + # find entry by .id - entry_index = next((i for i, e in enumerate(self.scene.archived_history) if e["id"] == entry.id), None) + entry_index = next( + ( + i + for i, e in enumerate(self.scene.archived_history) + if e["id"] == entry.id + ), + None, + ) if entry_index is None: raise ValueError("Entry not found") end = entry_index - 1 previous_summaries = [] - + if entry and num_previous > 0: if self.layered_history_available: previous_summaries = self.compile_layered_history( - include_base_layer=True, - base_layer_end_id=entry.id + include_base_layer=True, base_layer_end_id=entry.id )[-num_previous:] else: previous_summaries = [ - entry.text for entry in self.scene.archived_history[end-num_previous:end] + entry.text + for entry in self.scene.archived_history[end - num_previous : end] ] return previous_summaries - + # SUMMARIZE @set_processing @@ -226,13 +229,15 @@ class SummarizeAgent( self, scene, generation_options: GenerationOptions | None = None ): end = None - + emission = BuildArchiveEmission( agent=self, generation_options=generation_options, ) - - await talemate.emit.async_signals.get("agent.summarization.before_build_archive").send(emission) + + await talemate.emit.async_signals.get( + "agent.summarization.before_build_archive" + ).send(emission) if not self.actions["archive"].enabled: return @@ -260,7 +265,7 @@ class SummarizeAgent( extra_context = [ entry["text"] for entry in scene.archived_history[-num_previous:] ] - + else: extra_context = None @@ -283,7 +288,10 @@ class SummarizeAgent( # log.debug("build_archive", idx=i, content=str(dialogue)[:64]+"...") - if isinstance(dialogue, (DirectorMessage, ContextInvestigationMessage, ReinforcementMessage)): + if isinstance( + dialogue, + (DirectorMessage, ContextInvestigationMessage, ReinforcementMessage), + ): # these messages are not part of the dialogue and should not be summarized if i == start: start += 1 @@ -292,7 +300,7 @@ class SummarizeAgent( if isinstance(dialogue, TimePassageMessage): log.debug("build_archive", time_passage_message=dialogue) ts = util.iso8601_add(ts, dialogue.ts) - + if i == start: log.debug( "build_archive", @@ -347,23 +355,28 @@ class SummarizeAgent( if str(line) in terminating_line: break adjusted_dialogue.append(line) - + # if difference start and end is less than 4, ignore the termination if len(adjusted_dialogue) > 4: dialogue_entries = adjusted_dialogue end = start + len(dialogue_entries) - 1 else: - log.debug("build_archive", message="Ignoring termination", start=start, end=end, adjusted_dialogue=adjusted_dialogue) + log.debug( + "build_archive", + message="Ignoring termination", + start=start, + end=end, + adjusted_dialogue=adjusted_dialogue, + ) if dialogue_entries: - if not extra_context: # prepend scene intro to dialogue dialogue_entries.insert(0, scene.intro) - + summarized = None retries = 5 - + while not summarized and retries > 0: summarized = await self.summarize( "\n".join(map(str, dialogue_entries)), @@ -371,7 +384,7 @@ class SummarizeAgent( generation_options=generation_options, ) retries -= 1 - + if not summarized: raise IOError("Failed to summarize dialogue", dialogue=dialogue_entries) @@ -382,13 +395,17 @@ class SummarizeAgent( # determine the appropariate timestamp for the summarization - await scene.push_archive(ArchiveEntry(text=summarized, start=start, end=end, ts=ts)) - - scene.ts=ts + await scene.push_archive( + ArchiveEntry(text=summarized, start=start, end=end, ts=ts) + ) + + scene.ts = ts scene.emit_status() - - await talemate.emit.async_signals.get("agent.summarization.after_build_archive").send(emission) - + + await talemate.emit.async_signals.get( + "agent.summarization.after_build_archive" + ).send(emission) + return True @set_processing @@ -408,23 +425,23 @@ class SummarizeAgent( return response @set_processing - async def find_natural_scene_termination(self, event_chunks:list[str]) -> list[list[str]]: + async def find_natural_scene_termination( + self, event_chunks: list[str] + ) -> list[list[str]]: """ Will analyze a list of events and return a list of events that has been separated at a natural scene termination points. """ - + # scan through event chunks and split into paragraphs rebuilt_chunks = [] - + for chunk in event_chunks: - paragraphs = [ - p.strip() for p in chunk.split("\n") if p.strip() - ] + paragraphs = [p.strip() for p in chunk.split("\n") if p.strip()] rebuilt_chunks.extend(paragraphs) - + event_chunks = rebuilt_chunks - + response = await Prompt.request( "summarizer.find-natural-scene-termination-events", self.client, @@ -436,38 +453,42 @@ class SummarizeAgent( }, ) response = response.strip() - + items = util.extract_list(response) - - # will be a list of + + # will be a list of # ["Progress 1", "Progress 12", "Progress 323", ...] - # convert to a list of just numbers - + # convert to a list of just numbers + numbers = [] - + for item in items: match = re.match(r"Progress (\d+)", item.strip()) if match: numbers.append(int(match.group(1))) - + # make sure its unique and sorted numbers = sorted(list(set(numbers))) - + result = [] prev_number = 0 for number in numbers: - result.append(event_chunks[prev_number:number+1]) - prev_number = number+1 - - #result = { + result.append(event_chunks[prev_number : number + 1]) + prev_number = number + 1 + + # result = { # "selected": event_chunks[:number+1], # "remaining": event_chunks[number+1:] - #} - - log.debug("find_natural_scene_termination", response=response, result=result, numbers=numbers) - + # } + + log.debug( + "find_natural_scene_termination", + response=response, + result=result, + numbers=numbers, + ) + return result - @set_processing async def summarize( @@ -481,9 +502,9 @@ class SummarizeAgent( """ Summarize the given text """ - + response_length = 1024 - + template_vars = { "dialogue": text, "scene": self.scene, @@ -500,7 +521,7 @@ class SummarizeAgent( "analyze_chunks": self.layered_history_analyze_chunks, "response_length": response_length, } - + emission = SummarizeEmission( agent=self, text=text, @@ -511,43 +532,46 @@ class SummarizeAgent( summarization_history=extra_context or [], summarization_type="dialogue", ) - - await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission) - + + await talemate.emit.async_signals.get( + "agent.summarization.summarize.before" + ).send(emission) + template_vars["dynamic_instructions"] = emission.dynamic_instructions - + response = await Prompt.request( - f"summarizer.summarize-dialogue", + "summarizer.summarize-dialogue", self.client, f"summarize_{response_length}", vars=template_vars, - dedupe_enabled=False + dedupe_enabled=False, ) log.debug( "summarize", dialogue_length=len(text), summarized_length=len(response) ) - + try: summary = response.split("SUMMARY:")[1].strip() except Exception as e: log.error("summarize failed", response=response, exc=e) return "" - + # capitalize first letter try: summary = summary[0].upper() + summary[1:] except IndexError: pass - + emission.response = self.clean_result(summary) - - await talemate.emit.async_signals.get("agent.summarization.summarize.after").send(emission) - + + await talemate.emit.async_signals.get( + "agent.summarization.summarize.after" + ).send(emission) + summary = emission.response - + return self.clean_result(summary) - @set_processing async def summarize_events( @@ -563,15 +587,14 @@ class SummarizeAgent( """ Summarize the given text """ - + if not extra_context: extra_context = "" - + mentioned_characters: list["Character"] = self.scene.parse_characters_from_text( - text + extra_context, - exclude_active=True + text + extra_context, exclude_active=True ) - + template_vars = { "dialogue": text, "scene": self.scene, @@ -585,7 +608,7 @@ class SummarizeAgent( "response_length": response_length, "mentioned_characters": mentioned_characters, } - + emission = SummarizeEmission( agent=self, text=text, @@ -596,46 +619,62 @@ class SummarizeAgent( summarization_history=[extra_context] if extra_context else [], summarization_type="events", ) - - await talemate.emit.async_signals.get("agent.summarization.summarize.before").send(emission) - + + await talemate.emit.async_signals.get( + "agent.summarization.summarize.before" + ).send(emission) + template_vars["dynamic_instructions"] = emission.dynamic_instructions - + response = await Prompt.request( - f"summarizer.summarize-events", + "summarizer.summarize-events", self.client, f"summarize_{response_length}", vars=template_vars, - dedupe_enabled=False + dedupe_enabled=False, ) - + response = response.strip() response = response.replace('"', "") - + log.debug( - "layered_history_summarize", original_length=len(text), summarized_length=len(response) + "layered_history_summarize", + original_length=len(text), + summarized_length=len(response), ) - + # clean up analyzation (remove analyzation text) if self.layered_history_analyze_chunks: # remove all lines that begin with "ANALYSIS OF CHUNK \d+:" - response = "\n".join([line for line in response.split("\n") if not line.startswith("ANALYSIS OF CHUNK")]) - + response = "\n".join( + [ + line + for line in response.split("\n") + if not line.startswith("ANALYSIS OF CHUNK") + ] + ) + # strip all occurences of "CHUNK \d+: " from the summary response = re.sub(r"(CHUNK|CHAPTER) \d+:\s+", "", response) - + # capitalize first letter try: response = response[0].upper() + response[1:] except IndexError: pass - + emission.response = self.clean_result(response) - - await talemate.emit.async_signals.get("agent.summarization.summarize.after").send(emission) - + + await talemate.emit.async_signals.get( + "agent.summarization.summarize.after" + ).send(emission) + response = emission.response - - log.debug("summarize_events", original_length=len(text), summarized_length=len(response)) - + + log.debug( + "summarize_events", + original_length=len(text), + summarized_length=len(response), + ) + return self.clean_result(response) diff --git a/src/talemate/agents/summarize/analyze_scene.py b/src/talemate/agents/summarize/analyze_scene.py index d1e3e6e8..f9012d8d 100644 --- a/src/talemate/agents/summarize/analyze_scene.py +++ b/src/talemate/agents/summarize/analyze_scene.py @@ -28,13 +28,15 @@ talemate.emit.async_signals.register( "agent.summarization.scene_analysis.after", "agent.summarization.scene_analysis.cached", "agent.summarization.scene_analysis.before_deep_analysis", - "agent.summarization.scene_analysis.after_deep_analysis", + "agent.summarization.scene_analysis.after_deep_analysis", ) + @dataclasses.dataclass class SceneAnalysisEmission(AgentTemplateEmission): analysis_type: str | None = None + @dataclasses.dataclass class SceneAnalysisDeepAnalysisEmission(AgentEmission): analysis: str @@ -42,15 +44,16 @@ class SceneAnalysisDeepAnalysisEmission(AgentEmission): analysis_sub_type: str | None = None max_content_investigations: int = 1 character: "Character" = None - dynamic_instructions: list[DynamicInstruction] = dataclasses.field(default_factory=list) + dynamic_instructions: list[DynamicInstruction] = dataclasses.field( + default_factory=list + ) class SceneAnalyzationMixin: - """ Summarizer agent mixin that provides functionality for scene analyzation. """ - + @classmethod def add_actions(cls, actions: dict[str, AgentAction]): actions["analyze_scene"] = AgentAction( @@ -71,8 +74,8 @@ class SceneAnalyzationMixin: choices=[ {"label": "Short (256)", "value": "256"}, {"label": "Medium (512)", "value": "512"}, - {"label": "Long (1024)", "value": "1024"} - ] + {"label": "Long (1024)", "value": "1024"}, + ], ), "for_conversation": AgentActionConfig( type="bool", @@ -109,196 +112,204 @@ class SceneAnalyzationMixin: value=True, quick_toggle=True, ), - } + }, ) - + # config property helpers - + @property def analyze_scene(self) -> bool: return self.actions["analyze_scene"].enabled - + @property def analysis_length(self) -> int: return int(self.actions["analyze_scene"].config["analysis_length"].value) - + @property def cache_analysis(self) -> bool: return self.actions["analyze_scene"].config["cache_analysis"].value - + @property def deep_analysis(self) -> bool: return self.actions["analyze_scene"].config["deep_analysis"].value - + @property def deep_analysis_max_context_investigations(self) -> int: - return self.actions["analyze_scene"].config["deep_analysis_max_context_investigations"].value - + return ( + self.actions["analyze_scene"] + .config["deep_analysis_max_context_investigations"] + .value + ) + @property def analyze_scene_for_conversation(self) -> bool: return self.actions["analyze_scene"].config["for_conversation"].value - + @property def analyze_scene_for_narration(self) -> bool: return self.actions["analyze_scene"].config["for_narration"].value - + # signal connect - + def connect(self, scene): super().connect(scene) - talemate.emit.async_signals.get("agent.conversation.inject_instructions").connect( - self.on_inject_instructions - ) + talemate.emit.async_signals.get( + "agent.conversation.inject_instructions" + ).connect(self.on_inject_instructions) talemate.emit.async_signals.get("agent.narrator.inject_instructions").connect( self.on_inject_instructions ) - talemate.emit.async_signals.get("agent.summarization.rag_build_sub_instruction").connect( - self.on_rag_build_sub_instruction - ) - talemate.emit.async_signals.get("agent.editor.revision-analysis.before").connect( - self.on_editor_revision_analysis_before - ) - + talemate.emit.async_signals.get( + "agent.summarization.rag_build_sub_instruction" + ).connect(self.on_rag_build_sub_instruction) + talemate.emit.async_signals.get( + "agent.editor.revision-analysis.before" + ).connect(self.on_editor_revision_analysis_before) + async def on_inject_instructions( - self, - emission:ConversationAgentEmission | NarratorAgentEmission, + self, + emission: ConversationAgentEmission | NarratorAgentEmission, ): """ Injects instructions into the conversation. """ - + if isinstance(emission, ConversationAgentEmission): emission_type = "conversation" elif isinstance(emission, NarratorAgentEmission): emission_type = "narration" else: raise ValueError("Invalid emission type.") - + if not self.analyze_scene: return - + analyze_scene_for_type = getattr(self, f"analyze_scene_for_{emission_type}") - + if not analyze_scene_for_type: return - + analysis = None - + # self.set_scene_states and self.get_scene_state to store # cached analysis in scene states - + if self.cache_analysis: analysis = await self.get_cached_analysis(emission_type) if analysis: - await talemate.emit.async_signals.get("agent.summarization.scene_analysis.cached").send( + await talemate.emit.async_signals.get( + "agent.summarization.scene_analysis.cached" + ).send( SceneAnalysisEmission( - agent=self, - analysis_type=emission_type, - response=analysis, + agent=self, + analysis_type=emission_type, + response=analysis, template_vars={ - "character": emission.character if hasattr(emission, "character") else None, + "character": emission.character + if hasattr(emission, "character") + else None, }, - dynamic_instructions=emission.dynamic_instructions + dynamic_instructions=emission.dynamic_instructions, ) ) - + if not analysis and self.analyze_scene: # analyze the scene for the next action analysis = await self.analyze_scene_for_next_action( emission_type, emission.character if hasattr(emission, "character") else None, - self.analysis_length + self.analysis_length, ) - + await self.set_cached_analysis(emission_type, analysis) - + if not analysis: return emission.dynamic_instructions.append( - DynamicInstruction( - title="Scene Analysis", - content=analysis - ) + DynamicInstruction(title="Scene Analysis", content=analysis) ) - - async def on_rag_build_sub_instruction(self, emission:"RagBuildSubInstructionEmission"): + + async def on_rag_build_sub_instruction( + self, emission: "RagBuildSubInstructionEmission" + ): """ Injects the sub instruction into the analysis. """ sub_instruction = await self.analyze_scene_rag_build_sub_instruction() - + if sub_instruction: emission.sub_instruction = sub_instruction - + async def on_editor_revision_analysis_before(self, emission: AgentTemplateEmission): last_analysis = self.get_scene_state("scene_analysis") if last_analysis: - emission.dynamic_instructions.append(DynamicInstruction( - title="Scene Analysis", - content=last_analysis - )) - + emission.dynamic_instructions.append( + DynamicInstruction(title="Scene Analysis", content=last_analysis) + ) + # helpers - - async def get_cached_analysis(self, typ:str) -> str | None: + + async def get_cached_analysis(self, typ: str) -> str | None: """ Returns the cached analysis for the given type. """ - + cached_analysis = self.get_scene_state(f"cached_analysis_{typ}") - + if not cached_analysis: return None - + fingerprint = self.context_fingerpint() - + if cached_analysis.get("fp") == fingerprint: return cached_analysis["guidance"] - + return None - - async def set_cached_analysis(self, typ:str, analysis:str): + + async def set_cached_analysis(self, typ: str, analysis: str): """ Sets the cached analysis for the given type. """ - + fingerprint = self.context_fingerpint() - + self.set_scene_states( - **{f"cached_analysis_{typ}": { - "fp": fingerprint, - "guidance": analysis, - }} + **{ + f"cached_analysis_{typ}": { + "fp": fingerprint, + "guidance": analysis, + } + } ) - - async def analyze_scene_sub_type(self, analysis_type:str) -> str: + + async def analyze_scene_sub_type(self, analysis_type: str) -> str: """ Analyzes the active agent context to figure out the appropriate sub type """ - + fn = getattr(self, f"analyze_scene_{analysis_type}_sub_type", None) - + if fn: return await fn() - + return "" - + async def analyze_scene_narration_sub_type(self) -> str: """ Analyzes the active agent context to figure out the appropriate sub type for narration analysis. (progress, query etc.) """ - + active_agent_context = active_agent.get() - + if not active_agent_context: return "progress" - + state = active_agent_context.state - + if state.get("narrator__query_narration"): return "query" - + if state.get("narrator__sensory_narration"): return "sensory" @@ -306,70 +317,85 @@ class SceneAnalyzationMixin: if state.get("narrator__character"): return "visual-character" return "visual" - + if state.get("narrator__fn_narrate_character_entry"): return "progress-character-entry" - + if state.get("narrator__fn_narrate_character_exit"): return "progress-character-exit" - + return "progress" - + async def analyze_scene_rag_build_sub_instruction(self): """ Analyzes the active agent context to figure out the appropriate sub type for rag build sub instruction. """ - + active_agent_context = active_agent.get() - + if not active_agent_context: return "" - + state = active_agent_context.state - + if state.get("narrator__query_narration"): query = state["narrator__query"] if query.endswith("?"): return "Answer the following question: " + query else: return query - + narrative_direction = state.get("narrator__narrative_direction") - + if state.get("narrator__sensory_narration") and narrative_direction: - return "Collect information that aids in describing the following sensory experience: " + narrative_direction - + return ( + "Collect information that aids in describing the following sensory experience: " + + narrative_direction + ) + if state.get("narrator__visual_narration") and narrative_direction: - return "Collect information that aids in describing the following visual experience: " + narrative_direction - + return ( + "Collect information that aids in describing the following visual experience: " + + narrative_direction + ) + if state.get("narrator__fn_narrate_character_entry") and narrative_direction: - return "Collect information that aids in describing the following character entry: " + narrative_direction - + return ( + "Collect information that aids in describing the following character entry: " + + narrative_direction + ) + if state.get("narrator__fn_narrate_character_exit") and narrative_direction: - return "Collect information that aids in describing the following character exit: " + narrative_direction - + return ( + "Collect information that aids in describing the following character exit: " + + narrative_direction + ) + if state.get("narrator__fn_narrate_progress"): - return "Collect information that aids in progressing the story: " + narrative_direction - + return ( + "Collect information that aids in progressing the story: " + + narrative_direction + ) + return "" - # actions @set_processing - async def analyze_scene_for_next_action(self, typ:str, character:"Character"=None, length:int=1024) -> str: - + async def analyze_scene_for_next_action( + self, typ: str, character: "Character" = None, length: int = 1024 + ) -> str: """ Analyzes the current scene progress and gives a suggestion for the next action. taken by the given actor. """ - + # deep analysis is only available if the scene has a layered history # and context investigation is enabled - deep_analysis = (self.deep_analysis and self.context_investigation_available) + deep_analysis = self.deep_analysis and self.context_investigation_available analysis_sub_type = await self.analyze_scene_sub_type(typ) - + template_vars = { "max_tokens": self.client.max_token_length, "scene": self.scene, @@ -381,58 +407,60 @@ class SceneAnalyzationMixin: "analysis_type": typ, "analysis_sub_type": analysis_sub_type, } - - emission = SceneAnalysisEmission(agent=self, template_vars=template_vars, analysis_type=typ) - - await talemate.emit.async_signals.get("agent.summarization.scene_analysis.before").send( - emission + + emission = SceneAnalysisEmission( + agent=self, template_vars=template_vars, analysis_type=typ ) - + + await talemate.emit.async_signals.get( + "agent.summarization.scene_analysis.before" + ).send(emission) + template_vars["dynamic_instructions"] = emission.dynamic_instructions - + response = await Prompt.request( f"summarizer.analyze-scene-for-next-{typ}", self.client, f"investigate_{length}", vars=template_vars, ) - + response = strip_partial_sentences(response) - + if not response.strip(): return response - + if deep_analysis: - emission = SceneAnalysisDeepAnalysisEmission( - agent=self, + agent=self, analysis=response, analysis_type=typ, analysis_sub_type=analysis_sub_type, character=character, - max_content_investigations=self.deep_analysis_max_context_investigations + max_content_investigations=self.deep_analysis_max_context_investigations, ) - - await talemate.emit.async_signals.get("agent.summarization.scene_analysis.before_deep_analysis").send( - emission - ) - - await talemate.emit.async_signals.get("agent.summarization.scene_analysis.after_deep_analysis").send( - emission - ) - - - await talemate.emit.async_signals.get("agent.summarization.scene_analysis.after").send( + + await talemate.emit.async_signals.get( + "agent.summarization.scene_analysis.before_deep_analysis" + ).send(emission) + + await talemate.emit.async_signals.get( + "agent.summarization.scene_analysis.after_deep_analysis" + ).send(emission) + + await talemate.emit.async_signals.get( + "agent.summarization.scene_analysis.after" + ).send( SceneAnalysisEmission( - agent=self, - template_vars=template_vars, - response=response, + agent=self, + template_vars=template_vars, + response=response, analysis_type=typ, - dynamic_instructions=emission.dynamic_instructions + dynamic_instructions=emission.dynamic_instructions, ) ) - + self.set_context_states(scene_analysis=response) self.set_scene_states(scene_analysis=response) - - return response \ No newline at end of file + + return response diff --git a/src/talemate/agents/summarize/context_investigation.py b/src/talemate/agents/summarize/context_investigation.py index 8748d9b0..6dc04455 100644 --- a/src/talemate/agents/summarize/context_investigation.py +++ b/src/talemate/agents/summarize/context_investigation.py @@ -22,14 +22,12 @@ if TYPE_CHECKING: log = structlog.get_logger() - class ContextInvestigationMixin: - """ Summarizer agent mixin that provides functionality for context investigation through the layered history of the scene. """ - + @classmethod def add_actions(cls, actions: dict[str, AgentAction]): actions["context_investigation"] = AgentAction( @@ -52,7 +50,7 @@ class ContextInvestigationMixin: {"label": "Short (256)", "value": "256"}, {"label": "Medium (512)", "value": "512"}, {"label": "Long (1024)", "value": "1024"}, - ] + ], ), "update_method": AgentActionConfig( type="text", @@ -62,57 +60,56 @@ class ContextInvestigationMixin: choices=[ {"label": "Replace", "value": "replace"}, {"label": "Smart Merge", "value": "merge"}, - ] - ) - } + ], + ), + }, ) - + # config property helpers - + @property def context_investigation_enabled(self): return self.actions["context_investigation"].enabled - + @property def context_investigation_available(self): - return ( - self.context_investigation_enabled and - self.layered_history_available - ) - + return self.context_investigation_enabled and self.layered_history_available + @property def context_investigation_answer_length(self) -> int: return int(self.actions["context_investigation"].config["answer_length"].value) - + @property def context_investigation_update_method(self) -> str: return self.actions["context_investigation"].config["update_method"].value - + # signal connect - + def connect(self, scene): super().connect(scene) - talemate.emit.async_signals.get("agent.conversation.inject_instructions").connect( - self.on_inject_context_investigation - ) + talemate.emit.async_signals.get( + "agent.conversation.inject_instructions" + ).connect(self.on_inject_context_investigation) talemate.emit.async_signals.get("agent.narrator.inject_instructions").connect( self.on_inject_context_investigation ) - talemate.emit.async_signals.get("agent.director.guide.inject_instructions").connect( - self.on_inject_context_investigation - ) - talemate.emit.async_signals.get("agent.summarization.scene_analysis.before_deep_analysis").connect( - self.on_summarization_scene_analysis_before_deep_analysis - ) - - async def on_summarization_scene_analysis_before_deep_analysis(self, emission:SceneAnalysisDeepAnalysisEmission): + talemate.emit.async_signals.get( + "agent.director.guide.inject_instructions" + ).connect(self.on_inject_context_investigation) + talemate.emit.async_signals.get( + "agent.summarization.scene_analysis.before_deep_analysis" + ).connect(self.on_summarization_scene_analysis_before_deep_analysis) + + async def on_summarization_scene_analysis_before_deep_analysis( + self, emission: SceneAnalysisDeepAnalysisEmission + ): """ Handles context investigation for deep scene analysis. """ - + if not self.context_investigation_enabled: return - + suggested_investigations = await self.suggest_context_investigations( emission.analysis, emission.analysis_type, @@ -120,67 +117,72 @@ class ContextInvestigationMixin: max_calls=emission.max_content_investigations, character=emission.character, ) - + response = emission.analysis - - ci_calls:list[focal.Call] = await self.request_context_investigations( - suggested_investigations, - max_calls=emission.max_content_investigations + + ci_calls: list[focal.Call] = await self.request_context_investigations( + suggested_investigations, max_calls=emission.max_content_investigations ) - + log.debug("analyze_scene_for_next_action", ci_calls=ci_calls) - + # append call queries and answers to the response ci_text = [] for ci_call in ci_calls: try: ci_text.append(f"{ci_call.arguments['query']}\n{ci_call.result}") - except KeyError as e: - log.error("analyze_scene_for_next_action", error="Missing key in call", ci_call=ci_call) - - context_investigation="\n\n".join(ci_text if ci_text else []) + except KeyError: + log.error( + "analyze_scene_for_next_action", + error="Missing key in call", + ci_call=ci_call, + ) + + context_investigation = "\n\n".join(ci_text if ci_text else []) current_context_investigation = self.get_scene_state("context_investigation") if current_context_investigation and context_investigation: if self.context_investigation_update_method == "merge": context_investigation = await self.update_context_investigation( current_context_investigation, context_investigation, response ) - + self.set_scene_states(context_investigation=context_investigation) self.set_context_states(context_investigation=context_investigation) - - - - async def on_inject_context_investigation(self, emission:ConversationAgentEmission | NarratorAgentEmission): + + async def on_inject_context_investigation( + self, emission: ConversationAgentEmission | NarratorAgentEmission + ): """ Injects context investigation into the conversation. """ - + if not self.context_investigation_enabled: return - + context_investigation = self.get_scene_state("context_investigation") - log.debug("summarizer.on_inject_context_investigation", context_investigation=context_investigation, emission=emission) + log.debug( + "summarizer.on_inject_context_investigation", + context_investigation=context_investigation, + emission=emission, + ) if context_investigation: emission.dynamic_instructions.append( DynamicInstruction( - title="Context Investigation", - content=context_investigation + title="Context Investigation", content=context_investigation ) ) - + # methods - + @set_processing async def suggest_context_investigations( self, - analysis:str, - analysis_type:str, - analysis_sub_type:str="", - max_calls:int=3, - character:"Character"=None, + analysis: str, + analysis_type: str, + analysis_sub_type: str = "", + max_calls: int = 3, + character: "Character" = None, ) -> str: - template_vars = { "max_tokens": self.client.max_token_length, "scene": self.scene, @@ -192,111 +194,119 @@ class ContextInvestigationMixin: "analysis_type": analysis_type, "analysis_sub_type": analysis_sub_type, } - + if not analysis_sub_type: template = f"summarizer.suggest-context-investigations-for-{analysis_type}" else: template = f"summarizer.suggest-context-investigations-for-{analysis_type}-{analysis_sub_type}" - - log.debug("summarizer.suggest_context_investigations", template=template, template_vars=template_vars) - + + log.debug( + "summarizer.suggest_context_investigations", + template=template, + template_vars=template_vars, + ) + response = await Prompt.request( template, self.client, "investigate_512", vars=template_vars, ) - + return response.strip() - @set_processing async def investigate_context( - self, - layer:int, - index:int, - query:str, - analysis:str="", - max_calls:int=3, - pad_entries:int=5, + self, + layer: int, + index: int, + query: str, + analysis: str = "", + max_calls: int = 3, + pad_entries: int = 5, ) -> str: """ Processes a context investigation. - + Arguments: - + - layer: The layer to investigate - index: The index in the layer to investigate - query: The query to investigate - analysis: Scene analysis text - pad_entries: if > 0 will pad the entries with the given number of entries before and after the start and end index """ - - log.debug("summarizer.investigate_context", layer=layer, index=index, query=query) + + log.debug( + "summarizer.investigate_context", layer=layer, index=index, query=query + ) entry = self.scene.layered_history[layer][index] - + layer_to_investigate = layer - 1 - + start = max(entry["start"] - pad_entries, 0) end = entry["end"] + pad_entries + 1 - + if layer_to_investigate == -1: entries = self.scene.archived_history[start:end] else: entries = self.scene.layered_history[layer_to_investigate][start:end] - - async def answer(query:str, instructions:str) -> str: - log.debug("Answering context investigation", query=query, instructions=answer) - + + async def answer(query: str, instructions: str) -> str: + log.debug( + "Answering context investigation", query=query, instructions=answer + ) + world_state = get_agent("world_state") - + return await world_state.analyze_history_and_follow_instructions( entries, f"{query}\n{instructions}", analysis=analysis, - response_length=self.context_investigation_answer_length + response_length=self.context_investigation_answer_length, ) - - async def investigate_context(chapter_number:str, query:str) -> str: + async def investigate_context(chapter_number: str, query: str) -> str: # look for \d.\d in the chapter number, extract as layer and index match = re.match(r"(\d+)\.(\d+)", chapter_number) if not match: - log.error("summarizer.investigate_context", error="Invalid chapter number", chapter_number=chapter_number) + log.error( + "summarizer.investigate_context", + error="Invalid chapter number", + chapter_number=chapter_number, + ) return "" - + layer = int(match.group(1)) index = int(match.group(2)) - - return await self.investigate_context(layer-1, index-1, query, analysis=analysis, max_calls=max_calls) - + + return await self.investigate_context( + layer - 1, index - 1, query, analysis=analysis, max_calls=max_calls + ) async def abort(): log.debug("Aborting context investigation") - + focal_handler: focal.Focal = focal.Focal( self.client, callbacks=[ focal.Callback( name="investigate_context", - arguments = [ + arguments=[ focal.Argument(name="chapter_number", type="str"), - focal.Argument(name="query", type="str") + focal.Argument(name="query", type="str"), ], - fn=investigate_context + fn=investigate_context, ), focal.Callback( name="answer", - arguments = [ + arguments=[ focal.Argument(name="instructions", type="str"), - focal.Argument(name="query", type="str") + focal.Argument(name="query", type="str"), ], - fn=answer + fn=answer, ), - focal.Callback( - name="abort", - fn=abort - ) + focal.Callback(name="abort", fn=abort), ], max_calls=max_calls, scene=self.scene, @@ -307,84 +317,86 @@ class ContextInvestigationMixin: entries=entries, analysis=analysis, ) - + await focal_handler.request( "summarizer.investigate-context", ) - + log.debug("summarizer.investigate_context", calls=focal_handler.state.calls) - - return focal_handler.state.calls + + return focal_handler.state.calls @set_processing async def request_context_investigations( - self, - analysis:str, - max_calls:int=3, + self, + analysis: str, + max_calls: int = 3, ) -> list[focal.Call]: - """ Requests context investigations for the given analysis. """ - + async def abort(): log.debug("Aborting context investigations") - - async def investigate_context(chapter_number:str, query:str) -> str: + + async def investigate_context(chapter_number: str, query: str) -> str: # look for \d.\d in the chapter number, extract as layer and index match = re.match(r"(\d+)\.(\d+)", chapter_number) if not match: - log.error("summarizer.request_context_investigations.investigate_context", error="Invalid chapter number", chapter_number=chapter_number) + log.error( + "summarizer.request_context_investigations.investigate_context", + error="Invalid chapter number", + chapter_number=chapter_number, + ) return "" layer = int(match.group(1)) index = int(match.group(2)) - + num_layers = len(self.scene.layered_history) - - return await self.investigate_context(num_layers - layer, index-1, query, analysis, max_calls=max_calls) - + + return await self.investigate_context( + num_layers - layer, index - 1, query, analysis, max_calls=max_calls + ) + focal_handler: focal.Focal = focal.Focal( self.client, callbacks=[ focal.Callback( name="investigate_context", - arguments = [ + arguments=[ focal.Argument(name="chapter_number", type="str"), - focal.Argument(name="query", type="str") + focal.Argument(name="query", type="str"), ], - fn=investigate_context + fn=investigate_context, ), - focal.Callback( - name="abort", - fn=abort - ) + focal.Callback(name="abort", fn=abort), ], max_calls=max_calls, scene=self.scene, - text=analysis + text=analysis, ) - + await focal_handler.request( "summarizer.request-context-investigation", ) - - log.debug("summarizer.request_context_investigations", calls=focal_handler.state.calls) - - return focal.collect_calls( - focal_handler.state.calls, - nested=True, - filter=lambda c: c.name == "answer" + + log.debug( + "summarizer.request_context_investigations", calls=focal_handler.state.calls ) - - # return focal_handler.state.calls - + + return focal.collect_calls( + focal_handler.state.calls, nested=True, filter=lambda c: c.name == "answer" + ) + + # return focal_handler.state.calls + @set_processing async def update_context_investigation( self, - current_context_investigation:str, - new_context_investigation:str, - analysis:str, + current_context_investigation: str, + new_context_investigation: str, + analysis: str, ): response = await Prompt.request( "summarizer.update-context-investigation", @@ -398,5 +410,5 @@ class ContextInvestigationMixin: "max_tokens": self.client.max_token_length, }, ) - - return response.strip() \ No newline at end of file + + return response.strip() diff --git a/src/talemate/agents/summarize/layered_history.py b/src/talemate/agents/summarize/layered_history.py index 7e5ed4e9..1dd3f5ce 100644 --- a/src/talemate/agents/summarize/layered_history.py +++ b/src/talemate/agents/summarize/layered_history.py @@ -1,5 +1,5 @@ import structlog -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from talemate.agents.base import ( set_processing, AgentAction, @@ -24,11 +24,12 @@ talemate.emit.async_signals.register( "agent.summarization.layered_history.finalize", ) + @dataclasses.dataclass class LayeredHistoryFinalizeEmission(AgentEmission): entry: LayeredArchiveEntry | None = None summarization_history: list[str] = dataclasses.field(default_factory=lambda: []) - + @property def response(self) -> str | None: return self.entry.text if self.entry else None @@ -38,22 +39,23 @@ class LayeredHistoryFinalizeEmission(AgentEmission): if self.entry: self.entry.text = value + class SummaryLongerThanOriginalError(ValueError): - def __init__(self, original_length:int, summarized_length:int): + def __init__(self, original_length: int, summarized_length: int): self.original_length = original_length self.summarized_length = summarized_length - super().__init__(f"Summarized text is longer than original text: {summarized_length} > {original_length}") + super().__init__( + f"Summarized text is longer than original text: {summarized_length} > {original_length}" + ) class LayeredHistoryMixin: - """ Summarizer agent mixin that provides functionality for maintaining a layered history. """ - + @classmethod def add_actions(cls, actions: dict[str, AgentAction]): - actions["layered_history"] = AgentAction( enabled=True, container=True, @@ -80,7 +82,7 @@ class LayeredHistoryMixin: max=5, step=1, value=3, - ), + ), "max_process_tokens": AgentActionConfig( type="number", label="Maximum tokens to process", @@ -116,69 +118,71 @@ class LayeredHistoryMixin: {"label": "Medium (512)", "value": "512"}, {"label": "Long (1024)", "value": "1024"}, {"label": "Exhaustive (2048)", "value": "2048"}, - ] + ], ), }, ) - + # config property helpers - + @property def layered_history_enabled(self): return self.actions["layered_history"].enabled - + @property def layered_history_threshold(self): return self.actions["layered_history"].config["threshold"].value - + @property def layered_history_max_process_tokens(self): return self.actions["layered_history"].config["max_process_tokens"].value - + @property def layered_history_max_layers(self): return self.actions["layered_history"].config["max_layers"].value - + @property def layered_history_chunk_size(self) -> int: return self.actions["layered_history"].config["chunk_size"].value - + @property def layered_history_analyze_chunks(self) -> bool: return self.actions["layered_history"].config["analyze_chunks"].value - + @property def layered_history_response_length(self) -> int: return int(self.actions["layered_history"].config["response_length"].value) - + @property def layered_history_available(self): - return self.layered_history_enabled and self.scene.layered_history and self.scene.layered_history[0] - - + return ( + self.layered_history_enabled + and self.scene.layered_history + and self.scene.layered_history[0] + ) + # signals - + def connect(self, scene): super().connect(scene) - talemate.emit.async_signals.get("agent.summarization.after_build_archive").connect( - self.on_after_build_archive - ) - - - async def on_after_build_archive(self, emission:"BuildArchiveEmission"): + talemate.emit.async_signals.get( + "agent.summarization.after_build_archive" + ).connect(self.on_after_build_archive) + + async def on_after_build_archive(self, emission: "BuildArchiveEmission"): """ After the archive has been built, we will update the layered history. """ - + if self.layered_history_enabled: await self.summarize_to_layered_history( generation_options=emission.generation_options ) - + # helpers - + async def _lh_split_and_summarize_chunks( - self, + self, chunks: list[dict], extra_context: str, generation_options: GenerationOptions | None = None, @@ -189,21 +193,29 @@ class LayeredHistoryMixin: """ summaries = [] current_chunk = chunks.copy() - + while current_chunk: partial_chunk = [] max_process_tokens = self.layered_history_max_process_tokens - + # Build partial chunk up to max_process_tokens - while current_chunk and util.count_tokens("\n\n".join(chunk['text'] for chunk in partial_chunk)) < max_process_tokens: + while ( + current_chunk + and util.count_tokens( + "\n\n".join(chunk["text"] for chunk in partial_chunk) + ) + < max_process_tokens + ): partial_chunk.append(current_chunk.pop(0)) - - text_to_summarize = "\n\n".join(chunk['text'] for chunk in partial_chunk) - - log.debug("_split_and_summarize_chunks", - tokens_in_chunk=util.count_tokens(text_to_summarize), - max_process_tokens=max_process_tokens) - + + text_to_summarize = "\n\n".join(chunk["text"] for chunk in partial_chunk) + + log.debug( + "_split_and_summarize_chunks", + tokens_in_chunk=util.count_tokens(text_to_summarize), + max_process_tokens=max_process_tokens, + ) + summary_text = await self.summarize_events( text_to_summarize, extra_context=extra_context + "\n\n".join(summaries), @@ -213,9 +225,9 @@ class LayeredHistoryMixin: chunk_size=self.layered_history_chunk_size, ) summaries.append(summary_text) - + return summaries - + def _lh_validate_summary_length(self, summaries: list[str], original_length: int): """ Validates that the summarized text is not longer than the original. @@ -224,17 +236,19 @@ class LayeredHistoryMixin: summarized_length = util.count_tokens(summaries) if summarized_length > original_length: raise SummaryLongerThanOriginalError(original_length, summarized_length) - - log.debug("_validate_summary_length", - original_length=original_length, - summarized_length=summarized_length) - + + log.debug( + "_validate_summary_length", + original_length=original_length, + summarized_length=summarized_length, + ) + def _lh_build_extra_context(self, layer_index: int) -> str: """ Builds extra context from compiled layered history for the given layer. """ return "\n\n".join(self.compile_layered_history(layer_index)) - + def _lh_extract_timestamps(self, chunk: list[dict]) -> tuple[str, str, str]: """ Extracts timestamps from a chunk of entries. @@ -242,144 +256,156 @@ class LayeredHistoryMixin: """ if not chunk: return "PT1S", "PT1S", "PT1S" - - ts = chunk[0].get('ts', 'PT1S') - ts_start = chunk[0].get('ts_start', ts) - ts_end = chunk[-1].get('ts_end', chunk[-1].get('ts', ts)) - + + ts = chunk[0].get("ts", "PT1S") + ts_start = chunk[0].get("ts_start", ts) + ts_end = chunk[-1].get("ts_end", chunk[-1].get("ts", ts)) + return ts, ts_start, ts_end - async def _lh_finalize_archive_entry( - self, + self, entry: LayeredArchiveEntry, summarization_history: list[str] | None = None, ) -> LayeredArchiveEntry: """ Finalizes an archive entry by summarizing it and adding it to the layered history. """ - + emission = LayeredHistoryFinalizeEmission( agent=self, entry=entry, summarization_history=summarization_history, ) - - await talemate.emit.async_signals.get("agent.summarization.layered_history.finalize").send(emission) - + + await talemate.emit.async_signals.get( + "agent.summarization.layered_history.finalize" + ).send(emission) + return emission.entry - # methods def compile_layered_history( - self, - for_layer_index:int = None, - as_objects:bool=False, - include_base_layer:bool=False, - max:int = None, + self, + for_layer_index: int = None, + as_objects: bool = False, + include_base_layer: bool = False, + max: int = None, base_layer_end_id: str | None = None, ) -> list[str]: """ Starts at the last layer and compiles the layered history into a single list of events. - + We are iterating backwards, so the last layer will be the most granular. - + Each preceeding layer starts from the end of the the next layer. """ - + layered_history = self.scene.layered_history compiled = [] next_layer_start = None - + len_layered_history = len(layered_history) - + for i in range(len_layered_history - 1, -1, -1): - if for_layer_index is not None: if i < for_layer_index: break - + log.debug("compilelayered history", i=i, next_layer_start=next_layer_start) - + if not layered_history[i]: continue - + entry_num = 1 - - for layered_history_entry in layered_history[i][next_layer_start if next_layer_start is not None else 0:]: - + + for layered_history_entry in layered_history[i][ + next_layer_start if next_layer_start is not None else 0 : + ]: if base_layer_end_id: - contained = entry_contained(self.scene, base_layer_end_id, HistoryEntry( - index=0, - layer=i+1, - **layered_history_entry) + contained = entry_contained( + self.scene, + base_layer_end_id, + HistoryEntry(index=0, layer=i + 1, **layered_history_entry), ) if contained: - log.debug("compile_layered_history", contained=True, base_layer_end_id=base_layer_end_id) + log.debug( + "compile_layered_history", + contained=True, + base_layer_end_id=base_layer_end_id, + ) break - + text = f"{layered_history_entry['text']}" - - if for_layer_index == i and max is not None and max <= layered_history_entry["end"]: + + if ( + for_layer_index == i + and max is not None + and max <= layered_history_entry["end"] + ): break - + if as_objects: - compiled.append({ - "text": text, - "start": layered_history_entry["start"], - "end": layered_history_entry["end"], - "layer": i, - "layer_r": len_layered_history - i, - "ts_start": layered_history_entry["ts_start"], - "index": entry_num, - }) + compiled.append( + { + "text": text, + "start": layered_history_entry["start"], + "end": layered_history_entry["end"], + "layer": i, + "layer_r": len_layered_history - i, + "ts_start": layered_history_entry["ts_start"], + "index": entry_num, + } + ) entry_num += 1 else: compiled.append(text) - + next_layer_start = layered_history_entry["end"] + 1 - + if i == 0 and include_base_layer: # we are are at layered history layer zero and inclusion of base layer (archived history) is requested # so we append the base layer to the compiled list, starting from # index `next_layer_start` - + entry_num = 1 - - for ah in self.scene.archived_history[next_layer_start or 0:]: - + + for ah in self.scene.archived_history[next_layer_start or 0 :]: if base_layer_end_id and ah["id"] == base_layer_end_id: break - + text = f"{ah['text']}" if as_objects: - compiled.append({ - "text": text, - "start": ah["start"], - "end": ah["end"], - "layer": -1, - "layer_r": 1, - "ts": ah["ts"], - "index": entry_num, - }) + compiled.append( + { + "text": text, + "start": ah["start"], + "end": ah["end"], + "layer": -1, + "layer_r": 1, + "ts": ah["ts"], + "index": entry_num, + } + ) entry_num += 1 else: compiled.append(text) - + return compiled - + @set_processing - async def summarize_to_layered_history(self, generation_options: GenerationOptions | None = None): - + async def summarize_to_layered_history( + self, generation_options: GenerationOptions | None = None + ): """ The layered history is a summarized archive with dynamic layers that will get less and less granular as the scene progresses. - + The most granular is still self.scene.archived_history, which holds all the base layer summarizations. - + self.scene.layered_history = [ # first layer after archived_history [ @@ -391,7 +417,7 @@ class LayeredHistoryMixin: }, ... ], - + # second layer [ { @@ -402,29 +428,29 @@ class LayeredHistoryMixin: }, ... ], - + # additional layers ... ] - + The same token threshold as for the base layer will be used for the layers. - + The same summarization function will be used for the layers. - + The next level layer will be generated automatically when the token threshold is reached. """ - + if not self.scene.archived_history: return # No base layer summaries to work with - + token_threshold = self.layered_history_threshold max_layers = self.layered_history_max_layers - if not hasattr(self.scene, 'layered_history'): + if not hasattr(self.scene, "layered_history"): self.scene.layered_history = [] - + layered_history = self.scene.layered_history async def summarize_layer(source_layer, next_layer_index, start_from) -> bool: @@ -432,147 +458,192 @@ class LayeredHistoryMixin: current_tokens = 0 start_index = start_from noop = True - - total_tokens_in_previous_layer = util.count_tokens([ - entry['text'] for entry in source_layer - ]) + + total_tokens_in_previous_layer = util.count_tokens( + [entry["text"] for entry in source_layer] + ) estimated_entries = total_tokens_in_previous_layer // token_threshold for i in range(start_from, len(source_layer)): entry = source_layer[i] - entry_tokens = util.count_tokens(entry['text']) - - log.debug("summarize_to_layered_history", entry=entry["text"][:100]+"...", tokens=entry_tokens, current_layer=next_layer_index-1) - + entry_tokens = util.count_tokens(entry["text"]) + + log.debug( + "summarize_to_layered_history", + entry=entry["text"][:100] + "...", + tokens=entry_tokens, + current_layer=next_layer_index - 1, + ) + if current_tokens + entry_tokens > token_threshold: if current_chunk: - try: # check if the next layer exists next_layer = layered_history[next_layer_index] except IndexError: # create the next layer layered_history.append([]) - log.debug("summarize_to_layered_history", created_layer=next_layer_index) + log.debug( + "summarize_to_layered_history", + created_layer=next_layer_index, + ) next_layer = layered_history[next_layer_index] - - ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk) - + + ts, ts_start, ts_end = self._lh_extract_timestamps( + current_chunk + ) + extra_context = self._lh_build_extra_context(next_layer_index) - text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk)) + text_length = util.count_tokens( + "\n\n".join(chunk["text"] for chunk in current_chunk) + ) num_entries_in_layer = len(layered_history[next_layer_index]) - emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}", data={"cancellable": True}) - + emit( + "status", + status="busy", + message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer} / {estimated_entries}", + data={"cancellable": True}, + ) + summaries = await self._lh_split_and_summarize_chunks( current_chunk, extra_context, generation_options=generation_options, ) noop = False - + # validate summary length self._lh_validate_summary_length(summaries, text_length) - - next_layer.append(LayeredArchiveEntry(**{ - "start": start_index, - "end": i, - "ts": ts, - "ts_start": ts_start, - "ts_end": ts_end, - "text": "\n\n".join(summaries), - }).model_dump(exclude_none=True)) - - emit("status", status="busy", message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer+1} / {estimated_entries}") - + + next_layer.append( + LayeredArchiveEntry( + **{ + "start": start_index, + "end": i, + "ts": ts, + "ts_start": ts_start, + "ts_end": ts_end, + "text": "\n\n".join(summaries), + } + ).model_dump(exclude_none=True) + ) + + emit( + "status", + status="busy", + message=f"Updating layered history - layer {next_layer_index} - {num_entries_in_layer + 1} / {estimated_entries}", + ) + current_chunk = [] current_tokens = 0 start_index = i current_chunk.append(entry) current_tokens += entry_tokens - - log.debug("summarize_to_layered_history", tokens=current_tokens, threshold=token_threshold, next_layer=next_layer_index) - + + log.debug( + "summarize_to_layered_history", + tokens=current_tokens, + threshold=token_threshold, + next_layer=next_layer_index, + ) + return not noop - - + # First layer (always the base layer) has_been_updated = False - + try: - if not layered_history: layered_history.append([]) log.debug("summarize_to_layered_history", layer="base", new_layer=True) - has_been_updated = await summarize_layer(self.scene.archived_history, 0, 0) + has_been_updated = await summarize_layer( + self.scene.archived_history, 0, 0 + ) elif layered_history[0]: # determine starting point by checking for `end` in the last entry last_entry = layered_history[0][-1] end = last_entry["end"] log.debug("summarize_to_layered_history", layer="base", start=end) - has_been_updated = await summarize_layer(self.scene.archived_history, 0, end) + has_been_updated = await summarize_layer( + self.scene.archived_history, 0, end + ) else: log.debug("summarize_to_layered_history", layer="base", empty=True) - has_been_updated = await summarize_layer(self.scene.archived_history, 0, 0) - + has_been_updated = await summarize_layer( + self.scene.archived_history, 0, 0 + ) + except SummaryLongerThanOriginalError as exc: log.error("summarize_to_layered_history", error=exc, layer="base") emit("status", status="error", message="Layered history update failed.") return except GenerationCancelled as e: - log.info("Generation cancelled, stopping rebuild of historical layered history") - emit("status", message="Rebuilding of layered history cancelled", status="info") + log.info( + "Generation cancelled, stopping rebuild of historical layered history" + ) + emit( + "status", + message="Rebuilding of layered history cancelled", + status="info", + ) handle_generation_cancelled(e) return - + # process layers async def update_layers() -> bool: noop = True for index in range(0, len(layered_history)): - # check against max layers if index + 1 > max_layers: return False - + try: # check if the next layer exists next_layer = layered_history[index + 1] except IndexError: next_layer = None - + end = next_layer[-1]["end"] if next_layer else 0 - + log.debug("summarize_to_layered_history", layer=index, start=end) - summarized = await summarize_layer(layered_history[index], index + 1, end if end else 0) - + summarized = await summarize_layer( + layered_history[index], index + 1, end if end else 0 + ) + if summarized: noop = False - + return not noop - + try: while await update_layers(): has_been_updated = True if has_been_updated: emit("status", status="success", message="Layered history updated.") - + except SummaryLongerThanOriginalError as exc: log.error("summarize_to_layered_history", error=exc, layer="subsequent") emit("status", status="error", message="Layered history update failed.") return except GenerationCancelled as e: - log.info("Generation cancelled, stopping rebuild of historical layered history") - emit("status", message="Rebuilding of layered history cancelled", status="info") + log.info( + "Generation cancelled, stopping rebuild of historical layered history" + ) + emit( + "status", + message="Rebuilding of layered history cancelled", + status="info", + ) handle_generation_cancelled(e) return - - + async def summarize_entries_to_layered_history( - self, - entries: list[dict], + self, + entries: list[dict], next_layer_index: int, start_index: int, end_index: int, @@ -580,11 +651,11 @@ class LayeredHistoryMixin: ) -> list[LayeredArchiveEntry]: """ Summarizes a list of entries into layered history entries. - + This method is used for regenerating specific history entries by processing their source entries. It chunks the entries based on the token threshold and summarizes each chunk into a LayeredArchiveEntry. - + Args: entries: List of dictionaries containing the text entries to summarize. Each entry should have at least a 'text' field and optionally @@ -597,12 +668,12 @@ class LayeredHistoryMixin: correspond to. generation_options: Optional generation options to pass to the summarization process. - + Returns: List of LayeredArchiveEntry objects containing the summarized text along with timestamp and index information. Currently returns a list with a single entry, but the structure supports multiple entries if needed. - + Notes: - The method respects the layered_history_threshold for chunking - Uses helper methods for timestamp extraction, context building, and @@ -611,63 +682,73 @@ class LayeredHistoryMixin: - The last entry is always included in the final chunk if it doesn't exceed the token threshold """ - + token_threshold = self.layered_history_threshold - + archive_entries = [] summaries = [] current_chunk = [] current_tokens = 0 - + ts = "PT1S" ts_start = "PT1S" ts_end = "PT1S" - - + for entry_index, entry in enumerate(entries): is_last_entry = entry_index == len(entries) - 1 - entry_tokens = util.count_tokens(entry['text']) - - log.debug("summarize_entries_to_layered_history", entry=entry["text"][:100]+"...", entry_tokens=entry_tokens, current_layer=next_layer_index-1, current_tokens=current_tokens) - + entry_tokens = util.count_tokens(entry["text"]) + + log.debug( + "summarize_entries_to_layered_history", + entry=entry["text"][:100] + "...", + entry_tokens=entry_tokens, + current_layer=next_layer_index - 1, + current_tokens=current_tokens, + ) + if current_tokens + entry_tokens > token_threshold or is_last_entry: - if is_last_entry and current_tokens + entry_tokens <= token_threshold: # if we are here because this is the last entry and adding it to # the current chunk would not exceed the token threshold, we will # add it to the current chunk current_chunk.append(entry) - + if current_chunk: ts, ts_start, ts_end = self._lh_extract_timestamps(current_chunk) extra_context = self._lh_build_extra_context(next_layer_index) - text_length = util.count_tokens("\n\n".join(chunk['text'] for chunk in current_chunk)) + text_length = util.count_tokens( + "\n\n".join(chunk["text"] for chunk in current_chunk) + ) summaries = await self._lh_split_and_summarize_chunks( current_chunk, extra_context, generation_options=generation_options, ) - + # validate summary length self._lh_validate_summary_length(summaries, text_length) - - archive_entry = LayeredArchiveEntry(**{ - "start": start_index, - "end": end_index, - "ts": ts, - "ts_start": ts_start, - "ts_end": ts_end, - "text": "\n\n".join(summaries), - }) - - archive_entry = await self._lh_finalize_archive_entry(archive_entry, extra_context.split("\n\n")) - + + archive_entry = LayeredArchiveEntry( + **{ + "start": start_index, + "end": end_index, + "ts": ts, + "ts_start": ts_start, + "ts_end": ts_end, + "text": "\n\n".join(summaries), + } + ) + + archive_entry = await self._lh_finalize_archive_entry( + archive_entry, extra_context.split("\n\n") + ) + archive_entries.append(archive_entry) current_chunk.append(entry) current_tokens += entry_tokens - + return archive_entries diff --git a/src/talemate/agents/tts.py b/src/talemate/agents/tts.py index 715d2e64..dc4575aa 100644 --- a/src/talemate/agents/tts.py +++ b/src/talemate/agents/tts.py @@ -51,11 +51,10 @@ if not TTS: def parse_chunks(text: str) -> list[str]: - """ Takes a string and splits it into chunks based on punctuation. - - In case of an error it will return the original text as a single chunk and + + In case of an error it will return the original text as a single chunk and the error will be logged. """ @@ -278,7 +277,6 @@ class TTSAgent(Agent): @property def agent_details(self): - details = { "api": AgentDetail( icon="mdi-server-outline", @@ -645,7 +643,6 @@ class TTSAgent(Agent): # OPENAI async def _generate_openai(self, text: str, chunk_size: int = 1024): - client = AsyncOpenAI(api_key=self.openai_api_key) model = self.actions["openai"].config["model"].value diff --git a/src/talemate/agents/visual/__init__.py b/src/talemate/agents/visual/__init__.py index 5e6e485c..30c5f07d 100644 --- a/src/talemate/agents/visual/__init__.py +++ b/src/talemate/agents/visual/__init__.py @@ -1,15 +1,13 @@ import asyncio -import traceback import structlog -import talemate.agents.visual.automatic1111 -import talemate.agents.visual.comfyui -import talemate.agents.visual.openai_image +import talemate.agents.visual.automatic1111 # noqa: F401 +import talemate.agents.visual.comfyui # noqa: F401 +import talemate.agents.visual.openai_image # noqa: F401 from talemate.agents.base import ( Agent, AgentAction, - AgentActionConditional, AgentActionConfig, AgentDetail, set_processing, @@ -25,11 +23,11 @@ from talemate.prompts.base import Prompt from .commands import * # noqa from .context import VIS_TYPES, VisualContext, VisualContextState, visual_context from .handlers import HANDLERS -from .schema import RESOLUTION_MAP, RenderSettings -from .style import MAJOR_STYLES, STYLE_MAP, Style, combine_styles +from .schema import RESOLUTION_MAP +from .style import MAJOR_STYLES, STYLE_MAP, Style from .websocket_handler import VisualWebsocketHandler -import talemate.agents.visual.nodes +import talemate.agents.visual.nodes # noqa: F401 __all__ = [ "VisualAgent", @@ -114,7 +112,6 @@ class VisualBase(Agent): label="Process in Background", description="Process renders in the background", ), - "_prompts": AgentAction( enabled=True, container=True, @@ -280,7 +277,6 @@ class VisualBase(Agent): await super().ready_check(task) async def setup_check(self): - if not self.actions["automatic_setup"].enabled: return @@ -289,7 +285,6 @@ class VisualBase(Agent): await getattr(self.client, f"visual_{backend.lower()}_setup")(self) async def apply_config(self, *args, **kwargs): - try: backend = kwargs["actions"]["_config"]["config"]["backend"]["value"] except (KeyError, TypeError): @@ -312,7 +307,6 @@ class VisualBase(Agent): backend_fn = getattr(self, f"{self.backend.lower()}_apply_config", None) if backend_fn: - if not backend_changed and was_disabled and self.enabled: # If the backend has not changed, but the agent was previously disabled # and is now enabled, we need to trigger the backend apply_config function @@ -336,7 +330,6 @@ class VisualBase(Agent): ) def prepare_prompt(self, prompt: str, styles: list[Style] = None) -> Style: - prompt_style = Style() prompt_style.load(prompt) @@ -432,9 +425,8 @@ class VisualBase(Agent): async def generate( self, format: str = "portrait", prompt: str = None, automatic: bool = False ): + context: VisualContextState = visual_context.get() - context:VisualContextState = visual_context.get() - log.debug("visual generate", context=context) if automatic and not self.allow_automatic_generation: @@ -466,7 +458,7 @@ class VisualBase(Agent): thematic_style = self.default_style vis_type_styles = self.vis_type_styles(context.vis_type) - prompt:Style = self.prepare_prompt(prompt, [vis_type_styles, thematic_style]) + prompt: Style = self.prepare_prompt(prompt, [vis_type_styles, thematic_style]) if context.vis_type == VIS_TYPES.CHARACTER: prompt.keywords.append("character portrait") @@ -488,12 +480,16 @@ class VisualBase(Agent): format = "portrait" context.format = format - + can_generate_image = self.enabled and self.backend_ready - + if not context.prompt_only and not can_generate_image: - emit("status", "Visual agent is not ready for image generation, will output prompt instead.", status="warning") - + emit( + "status", + "Visual agent is not ready for image generation, will output prompt instead.", + status="warning", + ) + # if prompt_only, we don't need to generate an image # instead we emit a system message with the prompt if context.prompt_only or not can_generate_image: @@ -509,20 +505,24 @@ class VisualBase(Agent): "title": f"Visual Prompt - {context.title}", "display": "tonal", "as_markdown": True, - } + }, ) return - + if not can_generate_image: return - + # Call the backend specific generate function backend = self.backend fn = f"{backend.lower()}_generate" log.info( - "visual generate", backend=backend, prompt=prompt, format=format, context=context + "visual generate", + backend=backend, + prompt=prompt, + format=format, + context=context, ) if not hasattr(self, fn): @@ -538,7 +538,6 @@ class VisualBase(Agent): @set_processing async def generate_environment_prompt(self, instructions: str = None): - with RevisionDisabled(): response = await Prompt.request( "visual.generate-environment-prompt", @@ -556,7 +555,6 @@ class VisualBase(Agent): async def generate_character_prompt( self, character_name: str, instructions: str = None ): - character = self.scene.get_character(character_name) with RevisionDisabled(): diff --git a/src/talemate/agents/visual/automatic1111.py b/src/talemate/agents/visual/automatic1111.py index 22079a3c..bc530f27 100644 --- a/src/talemate/agents/visual/automatic1111.py +++ b/src/talemate/agents/visual/automatic1111.py @@ -1,22 +1,15 @@ -import base64 -import io - import httpx import structlog -from PIL import Image from talemate.agents.base import ( - Agent, AgentAction, AgentActionConditional, AgentActionConfig, - AgentDetail, - set_processing, ) from .handlers import register -from .schema import RenderSettings, Resolution -from .style import STYLE_MAP, Style +from .schema import RenderSettings +from .style import Style log = structlog.get_logger("talemate.agents.visual.automatic1111") @@ -58,9 +51,9 @@ SAMPLING_SCHEDULES = [ SAMPLING_SCHEDULES = sorted(SAMPLING_SCHEDULES, key=lambda x: x["label"]) + @register(backend_name="automatic1111", label="AUTOMATIC1111") class Automatic1111Mixin: - automatic1111_default_render_settings = RenderSettings() EXTEND_ACTIONS = { @@ -139,14 +132,14 @@ class Automatic1111Mixin: @property def automatic1111_sampling_method(self): return self.actions["automatic1111"].config["sampling_method"].value - + @property def automatic1111_schedule_type(self): return self.actions["automatic1111"].config["schedule_type"].value - + @property def automatic1111_cfg(self): - return self.actions["automatic1111"].config["cfg"].value + return self.actions["automatic1111"].config["cfg"].value async def automatic1111_generate(self, prompt: Style, format: str): url = self.api_url @@ -162,14 +155,16 @@ class Automatic1111Mixin: "height": resolution.height, "cfg_scale": self.automatic1111_cfg, "sampler_name": self.automatic1111_sampling_method, - "scheduler": self.automatic1111_schedule_type + "scheduler": self.automatic1111_schedule_type, } log.info("automatic1111_generate", payload=payload, url=url) async with httpx.AsyncClient() as client: response = await client.post( - url=f"{url}/sdapi/v1/txt2img", json=payload, timeout=self.generate_timeout + url=f"{url}/sdapi/v1/txt2img", + json=payload, + timeout=self.generate_timeout, ) r = response.json() diff --git a/src/talemate/agents/visual/comfyui.py b/src/talemate/agents/visual/comfyui.py index da2d9ca0..63be0185 100644 --- a/src/talemate/agents/visual/comfyui.py +++ b/src/talemate/agents/visual/comfyui.py @@ -1,6 +1,5 @@ import asyncio import base64 -import io import json import os import random @@ -10,13 +9,12 @@ import urllib.parse import httpx import pydantic import structlog -from PIL import Image from talemate.agents.base import AgentAction, AgentActionConditional, AgentActionConfig from .handlers import register from .schema import RenderSettings, Resolution -from .style import STYLE_MAP, Style +from .style import Style log = structlog.get_logger("talemate.agents.visual.comfyui") @@ -25,7 +23,6 @@ class Workflow(pydantic.BaseModel): nodes: dict def set_resolution(self, resolution: Resolution): - # will collect all latent image nodes # if there is multiple will look for the one with the # title "Talemate Resolution" @@ -55,7 +52,6 @@ class Workflow(pydantic.BaseModel): latent_image_node["inputs"]["height"] = resolution.height def set_prompt(self, prompt: str, negative_prompt: str = None): - # will collect all CLIPTextEncode nodes # if there is multiple will look for the one with the @@ -79,7 +75,6 @@ class Workflow(pydantic.BaseModel): negative_prompt_node = None for node_id, node in self.nodes.items(): - if node["class_type"] == "CLIPTextEncode": if not positive_prompt_node: positive_prompt_node = node @@ -102,7 +97,6 @@ class Workflow(pydantic.BaseModel): negative_prompt_node["inputs"]["text"] = negative_prompt def set_checkpoint(self, checkpoint: str): - # will collect all CheckpointLoaderSimple nodes # if there is multiple will look for the one with the # title "Talemate Load Checkpoint" @@ -139,7 +133,6 @@ class Workflow(pydantic.BaseModel): @register(backend_name="comfyui", label="ComfyUI") class ComfyUIMixin: - comfyui_default_render_settings = RenderSettings() EXTEND_ACTIONS = { @@ -287,7 +280,9 @@ class ComfyUIMixin: log.info("comfyui_generate", payload=payload, url=url) async with httpx.AsyncClient() as client: - response = await client.post(url=f"{url}/prompt", json=payload, timeout=self.generate_timeout) + response = await client.post( + url=f"{url}/prompt", json=payload, timeout=self.generate_timeout + ) log.info("comfyui_generate", response=response.text) diff --git a/src/talemate/agents/visual/context.py b/src/talemate/agents/visual/context.py index 83f769dd..58fc1b87 100644 --- a/src/talemate/agents/visual/context.py +++ b/src/talemate/agents/visual/context.py @@ -30,7 +30,7 @@ class VisualContextState(pydantic.BaseModel): format: Union[str, None] = None replace: bool = False prompt_only: bool = False - + @property def title(self) -> str: if self.vis_type == VIS_TYPES.ENVIRONMENT: diff --git a/src/talemate/agents/visual/handlers.py b/src/talemate/agents/visual/handlers.py index d7f4ed33..0b86d38b 100644 --- a/src/talemate/agents/visual/handlers.py +++ b/src/talemate/agents/visual/handlers.py @@ -7,7 +7,6 @@ HANDLERS = {} class register: - def __init__(self, backend_name: str, label: str): self.backend_name = backend_name self.label = label diff --git a/src/talemate/agents/visual/nodes.py b/src/talemate/agents/visual/nodes.py index 44373664..66254c1e 100644 --- a/src/talemate/agents/visual/nodes.py +++ b/src/talemate/agents/visual/nodes.py @@ -6,23 +6,27 @@ from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode log = structlog.get_logger("talemate.game.engine.nodes.agents.visual") + @register("agents/visual/Settings") class VisualSettings(AgentSettingsNode): """ Base node to render visual agent settings. """ - _agent_name:ClassVar[str] = "visual" - + + _agent_name: ClassVar[str] = "visual" + def __init__(self, title="Visual Settings", **kwargs): super().__init__(title=title, **kwargs) - + + @register("agents/visual/GenerateCharacterPortrait") class GenerateCharacterPortrait(AgentNode): """ Generates a portrait for a character """ - _agent_name:ClassVar[str] = "visual" - + + _agent_name: ClassVar[str] = "visual" + class Fields: instructions = PropertyField( name="instructions", @@ -30,33 +34,34 @@ class GenerateCharacterPortrait(AgentNode): description="instructions for the portrait", default=UNRESOLVED, ) - + def __init__(self, title="Generate Character Portrait", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("character", socket_type="character") self.add_input("instructions", socket_type="str", optional=True) - + self.set_property("instructions", UNRESOLVED) - + self.add_output("state") self.add_output("character", socket_type="character") self.add_output("portrait", socket_type="image") - + async def run(self, state: GraphState): character = self.get_input_value("character") instructions = self.normalized_input_value("instructions") - + portrait = await self.agent.generate_character_portrait( character_name=character.name, instructions=instructions, ) - - self.set_output_values({ - "state": self.get_input_value("state"), - "character": character, - "portrait": portrait - }) - \ No newline at end of file + + self.set_output_values( + { + "state": self.get_input_value("state"), + "character": character, + "portrait": portrait, + } + ) diff --git a/src/talemate/agents/visual/openai_image.py b/src/talemate/agents/visual/openai_image.py index ac951059..c09bfc0d 100644 --- a/src/talemate/agents/visual/openai_image.py +++ b/src/talemate/agents/visual/openai_image.py @@ -1,31 +1,21 @@ -import base64 -import io -from urllib.parse import parse_qs, unquote, urlparse - -import httpx import structlog from openai import AsyncOpenAI -from PIL import Image from talemate.agents.base import ( - Agent, AgentAction, AgentActionConditional, AgentActionConfig, - AgentDetail, - set_processing, ) from .handlers import register from .schema import RenderSettings, Resolution -from .style import STYLE_MAP, Style +from .style import Style log = structlog.get_logger("talemate.agents.visual.openai_image") @register(backend_name="openai_image", label="OpenAI") class OpenAIImageMixin: - openai_image_default_render_settings = RenderSettings() EXTEND_ACTIONS = { diff --git a/src/talemate/agents/visual/style.py b/src/talemate/agents/visual/style.py index 0a26b315..66772535 100644 --- a/src/talemate/agents/visual/style.py +++ b/src/talemate/agents/visual/style.py @@ -83,7 +83,9 @@ class Style(pydantic.BaseModel): # Almost taken straight from some of the fooocus style presets, credit goes to the original author STYLE_MAP["digital_art"] = Style( - keywords="in the style of a digital artwork, masterpiece, best quality, high detail".split(", "), + keywords="in the style of a digital artwork, masterpiece, best quality, high detail".split( + ", " + ), negative_keywords="text, watermark, low quality, blurry, photo".split(", "), ) @@ -95,12 +97,16 @@ STYLE_MAP["concept_art"] = Style( ) STYLE_MAP["ink_illustration"] = Style( - keywords="in the style of ink illustration, painting, masterpiece, best quality".split(", "), + keywords="in the style of ink illustration, painting, masterpiece, best quality".split( + ", " + ), negative_keywords="text, watermark, low quality, blurry, photo".split(", "), ) STYLE_MAP["anime"] = Style( - keywords="in the style of anime, masterpiece, best quality, illustration".split(", "), + keywords="in the style of anime, masterpiece, best quality, illustration".split( + ", " + ), negative_keywords="text, watermark, low quality, blurry, photo, 3d".split(", "), ) diff --git a/src/talemate/agents/visual/websocket_handler.py b/src/talemate/agents/visual/websocket_handler.py index 1162c2e9..8694ffdf 100644 --- a/src/talemate/agents/visual/websocket_handler.py +++ b/src/talemate/agents/visual/websocket_handler.py @@ -56,7 +56,6 @@ class VisualWebsocketHandler(Plugin): scene = self.scene if context and context.character_name: - character = scene.get_character(context.character_name) if not character: diff --git a/src/talemate/agents/world_state/__init__.py b/src/talemate/agents/world_state/__init__.py index 9b8f1edf..acb0f8da 100644 --- a/src/talemate/agents/world_state/__init__.py +++ b/src/talemate/agents/world_state/__init__.py @@ -21,7 +21,13 @@ from talemate.scene_message import ( from talemate.util.response import extract_list -from talemate.agents.base import Agent, AgentAction, AgentActionConfig, AgentEmission, set_processing +from talemate.agents.base import ( + Agent, + AgentAction, + AgentActionConfig, + AgentEmission, + set_processing, +) from talemate.agents.registry import register @@ -57,10 +63,7 @@ class TimePassageEmission(WorldStateAgentEmission): @register() -class WorldStateAgent( - CharacterProgressionMixin, - Agent -): +class WorldStateAgent(CharacterProgressionMixin, Agent): """ An agent that handles world state related tasks. """ @@ -91,7 +94,7 @@ class WorldStateAgent( min=1, max=100, step=1, - ) + ), }, ), "update_reinforcements": AgentAction( @@ -140,7 +143,7 @@ class WorldStateAgent( @property def experimental(self): return True - + @property def initial_update(self): return self.actions["update_world_state"].config["initial"].value @@ -148,7 +151,9 @@ class WorldStateAgent( def connect(self, scene): super().connect(scene) talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop) - talemate.emit.async_signals.get("scene_loop_init_after").connect(self.on_scene_loop_init_after) + talemate.emit.async_signals.get("scene_loop_init_after").connect( + self.on_scene_loop_init_after + ) async def advance_time(self, duration: str, narrative: str = None): """ @@ -183,13 +188,13 @@ class WorldStateAgent( if not self.initial_update: return - + if self.get_scene_state("inital_update_done"): return await self.scene.world_state.request_update() self.set_scene_states(inital_update_done=True) - + async def on_game_loop(self, emission: GameLoopEvent): """ Called when a conversation is generated @@ -339,13 +344,13 @@ class WorldStateAgent( context = await memory_agent.multi_query(queries, iterate=3) - #log.debug( + # log.debug( # "analyze_text_and_extract_context_via_queries", # goal=goal, # text=text, # queries=queries, # context=context, - #) + # ) return context @@ -356,7 +361,6 @@ class WorldStateAgent( instruction: str, short: bool = False, ): - kind = "analyze_freeform_short" if short else "analyze_freeform" response = await Prompt.request( @@ -408,21 +412,20 @@ class WorldStateAgent( ) return response - + @set_processing async def analyze_history_and_follow_instructions( self, entries: list[dict], instructions: str, analysis: str = "", - response_length: int = 512 + response_length: int = 512, ) -> str: - """ Takes a list of archived_history or layered_history entries and follows the instructions to generate a response. """ - + response = await Prompt.request( "world_state.analyze-history-and-follow-instructions", self.client, @@ -436,7 +439,7 @@ class WorldStateAgent( "response_length": response_length, }, ) - + return response.strip() @set_processing @@ -480,7 +483,7 @@ class WorldStateAgent( for line in response.split("\n"): if not line.strip(): continue - if not ":" in line: + if ":" not in line: break name, value = line.split(":", 1) data[name.strip()] = value.strip() @@ -542,24 +545,33 @@ class WorldStateAgent( """ Queries a single re-inforcement """ - + if isinstance(character, self.scene.Character): character = character.name - + message = None idx, reinforcement = await self.scene.world_state.find_reinforcement( question, character ) if not reinforcement: - log.warning(f"Reinforcement not found", question=question, character=character) + log.warning( + "Reinforcement not found", question=question, character=character + ) return message = ReinforcementMessage(message="") - message.set_source("world_state", "update_reinforcement", question=question, character=character) - + message.set_source( + "world_state", + "update_reinforcement", + question=question, + character=character, + ) + if reset and reinforcement.insert == "sequential": - self.scene.pop_history(typ="reinforcement", meta_hash=message.meta_hash, all=True) + self.scene.pop_history( + typ="reinforcement", meta_hash=message.meta_hash, all=True + ) if reinforcement.insert == "sequential": kind = "analyze_freeform_medium_short" @@ -594,7 +606,7 @@ class WorldStateAgent( reinforcement.answer = answer reinforcement.due = reinforcement.interval - + # remove any recent previous reinforcement message with same question # to avoid overloading the near history with reinforcement messages if not reset: @@ -818,4 +830,4 @@ class WorldStateAgent( kwargs=kwargs, error=e, ) - raise \ No newline at end of file + raise diff --git a/src/talemate/agents/world_state/character_progression.py b/src/talemate/agents/world_state/character_progression.py index 99ca77a5..dc4f0be4 100644 --- a/src/talemate/agents/world_state/character_progression.py +++ b/src/talemate/agents/world_state/character_progression.py @@ -1,16 +1,9 @@ from typing import TYPE_CHECKING import structlog -import re -from talemate.agents.base import ( - set_processing, - AgentAction, - AgentActionConfig -) -from talemate.prompts import Prompt +from talemate.agents.base import set_processing, AgentAction, AgentActionConfig from talemate.instance import get_agent from talemate.events import GameLoopEvent from talemate.status import set_loading -from talemate.emit import emit import talemate.emit.async_signals import talemate.game.focal as focal @@ -23,8 +16,8 @@ if TYPE_CHECKING: log = structlog.get_logger() + class CharacterProgressionMixin: - """ World-state manager agent mixin that handles tracking of character progression and proposal of updates to character profiles. @@ -55,13 +48,13 @@ class CharacterProgressionMixin: type="bool", label="Propose as suggestions", description="Propose changes as suggestions that need to be manually accepted.", - value=True + value=True, ), "player_character": AgentActionConfig( type="bool", label="Player character", description="Track the player character's progression.", - value=True + value=True, ), "max_changes": AgentActionConfig( type="number", @@ -70,16 +63,16 @@ class CharacterProgressionMixin: value=1, min=1, max=5, - ) - } + ), + }, ) - - # config property helpers - + + # config property helpers + @property def character_progression_enabled(self) -> bool: return self.actions["character_progression"].enabled - + @property def character_progression_frequency(self) -> int: return self.actions["character_progression"].config["frequency"].value @@ -91,7 +84,7 @@ class CharacterProgressionMixin: @property def character_progression_max_changes(self) -> int: return self.actions["character_progression"].config["max_changes"].value - + @property def character_progression_as_suggestions(self) -> bool: return self.actions["character_progression"].config["as_suggestions"].value @@ -100,8 +93,9 @@ class CharacterProgressionMixin: def connect(self, scene): super().connect(scene) - talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop_track_character_progression) - + talemate.emit.async_signals.get("game_loop").connect( + self.on_game_loop_track_character_progression + ) async def on_game_loop_track_character_progression(self, emission: GameLoopEvent): """ @@ -110,58 +104,69 @@ class CharacterProgressionMixin: if not self.enabled or not self.character_progression_enabled: return - + log.debug("on_game_loop_track_character_progression", scene=self.scene) - - rounds_since_last_check = self.get_scene_state("rounds_since_last_character_progression_check", 0) - + + rounds_since_last_check = self.get_scene_state( + "rounds_since_last_character_progression_check", 0 + ) + if rounds_since_last_check < self.character_progression_frequency: rounds_since_last_check += 1 - self.set_scene_states(rounds_since_last_character_progression_check=rounds_since_last_check) + self.set_scene_states( + rounds_since_last_character_progression_check=rounds_since_last_check + ) return - + self.set_scene_states(rounds_since_last_character_progression_check=0) - + for character in self.scene.characters: - if character.is_player and not self.character_progression_player_character: continue - - calls:list[focal.Call] = await self.determine_character_development(character) - await self.character_progression_process_calls( - character = character, - calls = calls, - as_suggestions = self.character_progression_as_suggestions, + + calls: list[focal.Call] = await self.determine_character_development( + character ) - + await self.character_progression_process_calls( + character=character, + calls=calls, + as_suggestions=self.character_progression_as_suggestions, + ) + # methods - + @set_processing - async def character_progression_process_calls(self, character:"Character", calls:list[focal.Call], as_suggestions:bool=True): - - world_state_manager:WorldStateManager = self.scene.world_state_manager + async def character_progression_process_calls( + self, + character: "Character", + calls: list[focal.Call], + as_suggestions: bool = True, + ): + world_state_manager: WorldStateManager = self.scene.world_state_manager if as_suggestions: await world_state_manager.add_suggestion( Suggestion( name=character.name, type="character", id=f"character-{character.name}", - proposals=calls + proposals=calls, ) ) else: for call in calls: # changes will be applied directly to the character if call.name in ["add_attribute", "update_attribute"]: - await character.set_base_attribute(call.arguments["name"], call.result) + await character.set_base_attribute( + call.arguments["name"], call.result + ) elif call.name == "remove_attribute": await character.set_base_attribute(call.arguments["name"], None) elif call.name == "update_description": await character.set_description(call.result) - + @set_processing async def determine_character_development( - self, + self, character: "Character", generation_options: world_state_templates.GenerationOptions | None = None, instructions: str = None, @@ -169,95 +174,96 @@ class CharacterProgressionMixin: """ Determine character development """ - - log.debug("determine_character_development", character=character, generation_options=generation_options) - + + log.debug( + "determine_character_development", + character=character, + generation_options=generation_options, + ) + creator = get_agent("creator") - + @set_loading("Generating character attribute", cancellable=True) async def add_attribute(name: str, instructions: str) -> str: return await creator.generate_character_attribute( character, - attribute_name = name, - instructions = instructions, - generation_options = generation_options, + attribute_name=name, + instructions=instructions, + generation_options=generation_options, ) - + @set_loading("Generating character attribute", cancellable=True) async def update_attribute(name: str, instructions: str) -> str: return await creator.generate_character_attribute( character, - attribute_name = name, - instructions = instructions, - original = character.base_attributes.get(name), - generation_options = generation_options, + attribute_name=name, + instructions=instructions, + original=character.base_attributes.get(name), + generation_options=generation_options, ) - - async def remove_attribute(name: str, reason:str) -> str: + + async def remove_attribute(name: str, reason: str) -> str: return None - + @set_loading("Generating character description", cancellable=True) async def update_description(instructions: str) -> str: return await creator.generate_character_detail( character, - detail_name = "description", - instructions = instructions, - original = character.description, + detail_name="description", + instructions=instructions, + original=character.description, length=1024, - generation_options = generation_options, + generation_options=generation_options, ) - + focal_handler = focal.Focal( self.client, - # callbacks - callbacks = [ + callbacks=[ focal.Callback( - name = "add_attribute", - arguments = [ + name="add_attribute", + arguments=[ focal.Argument(name="name", type="str"), focal.Argument(name="instructions", type="str"), ], - fn = add_attribute + fn=add_attribute, ), focal.Callback( - name = "update_attribute", - arguments = [ + name="update_attribute", + arguments=[ focal.Argument(name="name", type="str"), focal.Argument(name="instructions", type="str"), ], - fn = update_attribute + fn=update_attribute, ), focal.Callback( - name = "remove_attribute", - arguments = [ + name="remove_attribute", + arguments=[ focal.Argument(name="name", type="str"), focal.Argument(name="reason", type="str"), ], - fn = remove_attribute + fn=remove_attribute, ), focal.Callback( - name = "update_description", - arguments = [ + name="update_description", + arguments=[ focal.Argument(name="instructions", type="str"), ], - fn = update_description, - multiple=False + fn=update_description, + multiple=False, ), ], - - max_calls = self.character_progression_max_changes, - + max_calls=self.character_progression_max_changes, # context - character = character, - scene = self.scene, - instructions = instructions, + character=character, + scene=self.scene, + instructions=instructions, ) - + await focal_handler.request( "world_state.determine-character-development", ) - + log.debug("determine_character_development", calls=focal_handler.state.calls) - - return focal_handler.state.calls \ No newline at end of file + + return focal_handler.state.calls diff --git a/src/talemate/agents/world_state/nodes.py b/src/talemate/agents/world_state/nodes.py index 45ad331c..b59fd39c 100644 --- a/src/talemate/agents/world_state/nodes.py +++ b/src/talemate/agents/world_state/nodes.py @@ -1,7 +1,12 @@ import structlog from typing import ClassVar, TYPE_CHECKING from talemate.context import active_scene -from talemate.game.engine.nodes.core import Node, GraphState, PropertyField, UNRESOLVED, InputValueError, TYPE_CHOICES +from talemate.game.engine.nodes.core import ( + GraphState, + PropertyField, + UNRESOLVED, + TYPE_CHOICES, +) from talemate.game.engine.nodes.registry import register from talemate.game.engine.nodes.agent import AgentSettingsNode, AgentNode from talemate.world_state import InsertionMode @@ -10,140 +15,140 @@ from talemate.world_state.manager import WorldStateManager if TYPE_CHECKING: from talemate.tale_mate import Scene, Character -TYPE_CHOICES.extend([ - "world_state/reinforcement", -]) +TYPE_CHOICES.extend( + [ + "world_state/reinforcement", + ] +) log = structlog.get_logger("talemate.game.engine.nodes.agents.world_state") + @register("agents/world_state/Settings") class WorldstateSettings(AgentSettingsNode): """ Base node to render world_state agent settings. """ - _agent_name:ClassVar[str] = "world_state" - + + _agent_name: ClassVar[str] = "world_state" + def __init__(self, title="Worldstate Settings", **kwargs): super().__init__(title=title, **kwargs) - + + @register("agents/world_state/ExtractCharacterSheet") class ExtractCharacterSheet(AgentNode): """ Attempts to extract an attribute based character sheet from a given context for a specific character. - + Additionally alteration instructions can be given to modify the character's existing sheet. - + Inputs: - + - state: The current state of the graph - character_name: The name of the character to extract the sheet for - context: The context to extract the sheet from - alteration_instructions: Instructions to alter the character's sheet - + Outputs: - + - character_sheet: The extracted character sheet (dict) """ - _agent_name:ClassVar[str] = "world_state" - + + _agent_name: ClassVar[str] = "world_state" + def __init__(self, title="Extract Character Sheet", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("character_name", socket_type="str") self.add_input("context", socket_type="str") self.add_input("alteration_instructions", socket_type="str", optional=True) - + self.add_output("character_sheet", socket_type="dict") - + async def run(self, state: GraphState): - context = self.require_input("context") character_name = self.require_input("character_name") alteration_instructions = self.get_input_value("alteration_instructions") - + sheet = await self.agent.extract_character_sheet( - name = character_name, - text = context, - alteration_instructions = alteration_instructions + name=character_name, + text=context, + alteration_instructions=alteration_instructions, ) - - self.set_output_values({ - "character_sheet": sheet - }) - + + self.set_output_values({"character_sheet": sheet}) + + @register("agents/world_state/StateReinforcement") class StateReinforcement(AgentNode): """ Reinforces the a tracked state of a character or the world in general. - + Inputs: - + - state: The current state of the graph - query_or_detail: The query or instruction to reinforce - character: The character to reinforce the state for (optional) - + Properties - + - reset: If the state should be reset - + Outputs: - + - state: graph state - message: state reinforcement message """ - - _agent_name:ClassVar[str] = "world_state" - - class Fields: + _agent_name: ClassVar[str] = "world_state" + + class Fields: query_or_detail = PropertyField( name="query_or_detail", description="Query or detail to reinforce", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + instructions = PropertyField( name="instructions", description="Instructions for the reinforcement", type="text", - default="" + default="", ) - + interval = PropertyField( name="interval", description="Interval for reinforcement", type="int", default=10, min=1, - step=1 + step=1, ) - + insert_method = PropertyField( name="insert_method", description="Method to insert reinforcement", type="str", default="sequential", - choices=[ - mode.value for mode in InsertionMode - ] + choices=[mode.value for mode in InsertionMode], ) - + reset = PropertyField( name="reset", description="If the state should be reset", type="bool", - default=False + default=False, ) - - + def __init__(self, title="State Reinforcement", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("query_or_detail", socket_type="str") @@ -155,20 +160,20 @@ class StateReinforcement(AgentNode): self.set_property("interval", 10) self.set_property("insert_method", "sequential") self.set_property("reset", False) - + self.add_output("state") self.add_output("message", socket_type="message") self.add_output("reinforcement", socket_type="world_state/reinforcement") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() query_or_detail = self.require_input("query_or_detail") character = self.normalized_input_value("character") reset = self.get_property("reset") interval = self.require_number_input("interval") instructions = self.normalized_input_value("instructions") insert_method = self.get_property("insert_method") - + await scene.world_state.add_reinforcement( question=query_or_detail, character=character.name if character else None, @@ -176,142 +181,128 @@ class StateReinforcement(AgentNode): interval=interval, insert=insert_method, ) - + message = await self.agent.update_reinforcement( - question=query_or_detail, - character=character, - reset=reset + question=query_or_detail, character=character, reset=reset ) - - self.set_output_values({ - "state": state, - "message": message - }) + + self.set_output_values({"state": state, "message": message}) + @register("agents/world_state/DeactivateCharacter") class DeactivateCharacter(AgentNode): """ Deactivates a character from the world state. - + Inputs: - + - state: The current state of the graph - character: The character to deactivate - + Outputs: - + - state: The updated state """ - - _agent_name:ClassVar[str] = "world_state" - + + _agent_name: ClassVar[str] = "world_state" + def __init__(self, title="Deactivate Character", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("character", socket_type="character") - + self.add_output("state") - + async def run(self, state: GraphState): - character:"Character" = self.require_input("character") - scene:"Scene" = active_scene.get() - - manager:WorldStateManager = scene.world_state_manager - + character: "Character" = self.require_input("character") + scene: "Scene" = active_scene.get() + + manager: WorldStateManager = scene.world_state_manager + manager.deactivate_character(character.name) - - self.set_output_values({ - "state": state - }) - + + self.set_output_values({"state": state}) + + @register("agents/world_state/EvaluateQuery") class EvaluateQuery(AgentNode): """ Evaluates a query on the world state. - + Inputs: - + - state: The current state of the graph - query: The query to evaluate - context: The context to evaluate the query in - + Outputs: - + - state: The current state - result: The result of the query """ - - _agent_name:ClassVar[str] = "world_state" - + + _agent_name: ClassVar[str] = "world_state" + class Fields: - query = PropertyField( name="query", description="The query to evaluate", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + def __init__(self, title="Evaluate Query", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("query", socket_type="str") self.add_input("context", socket_type="str") - + self.set_property("query", UNRESOLVED) - + self.add_output("state") self.add_output("result", socket_type="bool") - + async def run(self, state: GraphState): query = self.require_input("query") context = self.require_input("context") - - result = await self.agent.answer_query_true_or_false( - query=query, - text=context - ) - - self.set_output_values({ - "state": state, - "result": result - }) - + + result = await self.agent.answer_query_true_or_false(query=query, text=context) + + self.set_output_values({"state": state, "result": result}) + + @register("agents/world_state/RequestWorldState") class RequestWorldState(AgentNode): """ Requests the current world state. - + Inputs: - + - state: The current state of the graph - + Outputs: - + - state: The current state - world_state: The current world state """ - - _agent_name:ClassVar[str] = "world_state" - + + _agent_name: ClassVar[str] = "world_state" + def __init__(self, title="Request World State", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") - + self.add_output("state") self.add_output("world_state", socket_type="dict") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() world_state = await scene.world_state.request_update() - - self.set_output_values({ - "state": state, - "world_state": world_state - }) \ No newline at end of file + + self.set_output_values({"state": state, "world_state": world_state}) diff --git a/src/talemate/character.py b/src/talemate/character.py index c1ef9f1a..9091e80b 100644 --- a/src/talemate/character.py +++ b/src/talemate/character.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Union from talemate.instance import get_agent if TYPE_CHECKING: - from talemate.tale_mate import Actor, Character, Scene + from talemate.tale_mate import Character, Scene __all__ = [ @@ -46,7 +46,6 @@ async def activate_character(scene: "Scene", character: Union[str, "Character"]) if isinstance(character, str): character = scene.get_character(character) - if character.name not in scene.inactive_characters: # already activated return False diff --git a/src/talemate/client/__init__.py b/src/talemate/client/__init__.py index 4875e20e..bb1126be 100644 --- a/src/talemate/client/__init__.py +++ b/src/talemate/client/__init__.py @@ -1,19 +1,17 @@ -import os - -import talemate.client.runpod -from talemate.client.anthropic import AnthropicClient -from talemate.client.base import ClientBase, ClientDisabledError -from talemate.client.cohere import CohereClient -from talemate.client.deepseek import DeepSeekClient -from talemate.client.google import GoogleClient -from talemate.client.groq import GroqClient -from talemate.client.koboldcpp import KoboldCppClient -from talemate.client.lmstudio import LMStudioClient -from talemate.client.mistral import MistralAIClient -from talemate.client.ollama import OllamaClient -from talemate.client.openai import OpenAIClient -from talemate.client.openrouter import OpenRouterClient -from talemate.client.openai_compat import OpenAICompatibleClient -from talemate.client.registry import CLIENT_CLASSES, get_client_class, register -from talemate.client.tabbyapi import TabbyAPIClient -from talemate.client.textgenwebui import TextGeneratorWebuiClient +import talemate.client.runpod # noqa: F401 +from talemate.client.anthropic import AnthropicClient # noqa: F401 +from talemate.client.base import ClientBase, ClientDisabledError # noqa: F401 +from talemate.client.cohere import CohereClient # noqa: F401 +from talemate.client.deepseek import DeepSeekClient # noqa: F401 +from talemate.client.google import GoogleClient # noqa: F401 +from talemate.client.groq import GroqClient # noqa: F401 +from talemate.client.koboldcpp import KoboldCppClient # noqa: F401 +from talemate.client.lmstudio import LMStudioClient # noqa: F401 +from talemate.client.mistral import MistralAIClient # noqa: F401 +from talemate.client.ollama import OllamaClient # noqa: F401 +from talemate.client.openai import OpenAIClient # noqa: F401 +from talemate.client.openrouter import OpenRouterClient # noqa: F401 +from talemate.client.openai_compat import OpenAICompatibleClient # noqa: F401 +from talemate.client.registry import CLIENT_CLASSES, get_client_class, register # noqa: F401 +from talemate.client.tabbyapi import TabbyAPIClient # noqa: F401 +from talemate.client.textgenwebui import TextGeneratorWebuiClient # noqa: F401 diff --git a/src/talemate/client/anthropic.py b/src/talemate/client/anthropic.py index 17cdd660..6b6546c8 100644 --- a/src/talemate/client/anthropic.py +++ b/src/talemate/client/anthropic.py @@ -39,6 +39,7 @@ class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel): model: str = "claude-3-5-sonnet-latest" double_coercion: str = None + class ClientConfig(EndpointOverride, BaseClientConfig): pass @@ -118,13 +119,13 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase): self.current_status = status - data={ + data = { "error_action": error_action.model_dump() if error_action else None, "double_coercion": self.double_coercion, "meta": self.Meta().model_dump(), "enabled": self.enabled, } - data.update(self._common_status_data()) + data.update(self._common_status_data()) emit( "client_status", message=self.client_type, @@ -135,7 +136,10 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase): ) def set_client(self, max_token_length: int = None): - if not self.anthropic_api_key and not self.endpoint_override_base_url_configured: + if ( + not self.anthropic_api_key + and not self.endpoint_override_base_url_configured + ): self.client = AsyncAnthropic(api_key="sk-1111") log.error("No anthropic API key set") if self.api_key_status: @@ -175,10 +179,10 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase): if "enabled" in kwargs: self.enabled = bool(kwargs["enabled"]) - + if "double_coercion" in kwargs: self.double_coercion = kwargs["double_coercion"] - + self._reconfigure_common_parameters(**kwargs) self._reconfigure_endpoint_override(**kwargs) @@ -208,17 +212,18 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase): Generates text from the given prompt and parameters. """ - if not self.anthropic_api_key and not self.endpoint_override_base_url_configured: + if ( + not self.anthropic_api_key + and not self.endpoint_override_base_url_configured + ): raise Exception("No anthropic API key set") - + prompt, coercion_prompt = self.split_prompt_for_coercion(prompt) - + system_message = self.get_system_message(kind) - - messages = [ - {"role": "user", "content": prompt.strip()} - ] - + + messages = [{"role": "user", "content": prompt.strip()}] + if coercion_prompt: messages.append({"role": "assistant", "content": coercion_prompt.strip()}) @@ -228,7 +233,7 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase): parameters=parameters, system_message=system_message, ) - + completion_tokens = 0 prompt_tokens = 0 @@ -240,22 +245,20 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase): stream=True, **parameters, ) - + response = "" - + async for event in stream: - if event.type == "content_block_delta": content = event.delta.text response += content self.update_request_tokens(self.count_tokens(content)) - + elif event.type == "message_start": prompt_tokens = event.message.usage.input_tokens - + elif event.type == "message_delta": completion_tokens += event.usage.output_tokens - self._returned_prompt_tokens = prompt_tokens self._returned_response_tokens = completion_tokens @@ -267,5 +270,5 @@ class AnthropicClient(EndpointOverrideMixin, ClientBase): self.log.error("generate error", e=e) emit("status", message="anthropic API: Permission Denied", status="error") return "" - except Exception as e: + except Exception: raise diff --git a/src/talemate/client/base.py b/src/talemate/client/base.py index 55f78098..ff1f1afb 100644 --- a/src/talemate/client/base.py +++ b/src/talemate/client/base.py @@ -42,6 +42,7 @@ STOPPING_STRINGS = ["<|im_end|>", ""] # disable smart quotes until text rendering is refactored REPLACE_SMART_QUOTES = True + class ClientDisabledError(OSError): def __init__(self, client: "ClientBase"): self.client = client @@ -76,6 +77,7 @@ class CommonDefaults(pydantic.BaseModel): data_format: Literal["yaml", "json"] | None = None preset_group: str | None = None + class Defaults(CommonDefaults, pydantic.BaseModel): api_url: str = "http://localhost:5000" max_token_length: int = 8192 @@ -88,6 +90,7 @@ class FieldGroup(pydantic.BaseModel): description: str icon: str = "mdi-cog" + class ExtraField(pydantic.BaseModel): name: str type: str @@ -97,6 +100,7 @@ class ExtraField(pydantic.BaseModel): group: FieldGroup | None = None note: ux_schema.Note | None = None + class ParameterReroute(pydantic.BaseModel): talemate_parameter: str client_parameter: str @@ -117,23 +121,23 @@ class RequestInformation(pydantic.BaseModel): start_time: float = pydantic.Field(default_factory=time.time) end_time: float | None = None tokens: int = 0 - + @pydantic.computed_field(description="Duration") @property def duration(self) -> float: end_time = self.end_time or time.time() return end_time - self.start_time - + @pydantic.computed_field(description="Tokens per second") @property def rate(self) -> float: try: end_time = self.end_time or time.time() return self.tokens / (end_time - self.start_time) - except: + except Exception: pass return 0 - + @pydantic.computed_field(description="Status") @property def status(self) -> str: @@ -145,7 +149,7 @@ class RequestInformation(pydantic.BaseModel): return "in progress" else: return "pending" - + @pydantic.computed_field(description="Age") @property def age(self) -> float: @@ -159,10 +163,12 @@ class ClientEmbeddingsStatus: client: "ClientBase | None" = None embedding_name: str | None = None + async_signals.register( "client.embeddings_available", ) + class ClientBase: api_url: str model_name: str @@ -183,10 +189,10 @@ class ClientBase: rate_limit: int | None = None client_type = "base" request_information: RequestInformation | None = None - - status_request_timeout:int = 2 - - system_prompts:SystemPrompts = SystemPrompts() + + status_request_timeout: int = 2 + + system_prompts: SystemPrompts = SystemPrompts() preset_group: str | None = "" rate_limit_counter: CounterRateLimiter = None @@ -216,7 +222,7 @@ class ClientBase: self.max_token_length = ( int(kwargs["max_token_length"]) if kwargs["max_token_length"] else 8192 ) - + self.set_client(max_token_length=self.max_token_length) def __str__(self): @@ -252,23 +258,23 @@ class ClientBase: "temperature", "max_tokens", ] - + @property def supports_embeddings(self) -> bool: return False - + @property def embeddings_function(self): return None - + @property def embeddings_status(self) -> bool: return getattr(self, "_embeddings_status", False) - + @property def embeddings_model_name(self) -> str | None: return getattr(self, "_embeddings_model_name", None) - + @property def embeddings_url(self) -> str: return None @@ -277,16 +283,16 @@ class ClientBase: def embeddings_identifier(self) -> str: return f"client-api/{self.name}/{self.embeddings_model_name}" - async def destroy(self, config:dict): + async def destroy(self, config: dict): """ This is called before the client is removed from talemate.instance.clients - + Use this to perform any cleanup that is necessary. - + If a subclass overrides this method, it should call super().destroy(config) in the end of the method. """ - + if self.supports_embeddings: self.remove_embeddings(config) @@ -296,25 +302,28 @@ class ClientBase: def set_client(self, **kwargs): self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111") - + def set_embeddings(self): - - log.debug("setting embeddings", client=self.name, supports_embeddings=self.supports_embeddings, embeddings_status=self.embeddings_status) - + log.debug( + "setting embeddings", + client=self.name, + supports_embeddings=self.supports_embeddings, + embeddings_status=self.embeddings_status, + ) + if not self.supports_embeddings or not self.embeddings_status: return - + config = load_config(as_model=True) - + key = self.embeddings_identifier - + if key in config.presets.embeddings: log.debug("embeddings already set", client=self.name, key=key) return config.presets.embeddings[key] - - + log.debug("setting embeddings", client=self.name, key=key) - + config.presets.embeddings[key] = EmbeddingFunctionPreset( embeddings="client-api", client=self.name, @@ -324,10 +333,10 @@ class ClientBase: local=False, custom=True, ) - + save_config(config) - - def remove_embeddings(self, config:dict | None = None): + + def remove_embeddings(self, config: dict | None = None): # remove all embeddings for this client for key, value in list(config["presets"]["embeddings"].items()): if value["client"] == self.name and value["embeddings"] == "client-api": @@ -338,7 +347,9 @@ class ClientBase: if isinstance(system_prompts, dict): self.system_prompts = SystemPrompts(**system_prompts) elif not isinstance(system_prompts, SystemPrompts): - raise ValueError("system_prompts must be a `dict` or `SystemPrompts` instance") + raise ValueError( + "system_prompts must be a `dict` or `SystemPrompts` instance" + ) else: self.system_prompts = system_prompts @@ -376,10 +387,10 @@ class ClientBase: """ if "<|BOT|>" in prompt: _, right = prompt.split("<|BOT|>", 1) - + if self.double_coercion: right = f"{self.double_coercion}\n\n{right}" - + return prompt, right return prompt, None @@ -407,7 +418,7 @@ class ClientBase: if "double_coercion" in kwargs: self.double_coercion = kwargs["double_coercion"] - + self._reconfigure_common_parameters(**kwargs) def _reconfigure_common_parameters(self, **kwargs): @@ -415,12 +426,14 @@ class ClientBase: self.rate_limit = kwargs["rate_limit"] if self.rate_limit: if not self.rate_limit_counter: - self.rate_limit_counter = CounterRateLimiter(rate_per_minute=self.rate_limit) + self.rate_limit_counter = CounterRateLimiter( + rate_per_minute=self.rate_limit + ) else: self.rate_limit_counter.update_rate_limit(self.rate_limit) else: self.rate_limit_counter = None - + if "data_format" in kwargs: self.data_format = kwargs["data_format"] @@ -478,19 +491,21 @@ class ClientBase: - kind: the kind of generation """ - - app_config_system_prompts = client_context_attribute("app_config_system_prompts") - + + app_config_system_prompts = client_context_attribute( + "app_config_system_prompts" + ) + if app_config_system_prompts: self.system_prompts.parent = SystemPrompts(**app_config_system_prompts) - + return self.system_prompts.get(kind, self.decensor_enabled) def emit_status(self, processing: bool = None): """ Sets and emits the client status. """ - + if processing is not None: self.processing = processing @@ -516,7 +531,6 @@ class ClientBase: ) if not has_prompt_template and self.auto_determine_prompt_template: - # only attempt to determine the prompt template once per model and # only if the model does not already have a prompt template @@ -581,15 +595,17 @@ class ClientBase: "supports_embeddings": self.supports_embeddings, "embeddings_status": self.embeddings_status, "embeddings_model_name": self.embeddings_model_name, - "request_information": self.request_information.model_dump() if self.request_information else None, + "request_information": self.request_information.model_dump() + if self.request_information + else None, } - + extra_fields = getattr(self.Meta(), "extra_fields", {}) for field_name in extra_fields.keys(): common_data[field_name] = getattr(self, field_name, None) - + return common_data - + def populate_extra_fields(self, data: dict): """ Updates data with the extra fields from the client's Meta @@ -665,7 +681,7 @@ class ClientBase: agent_context.agent.inject_prompt_paramters( parameters, kind, agent_context.action ) - + if client_context_attribute( "nuke_repetition" ) > 0.0 and self.jiggle_enabled_for(kind): @@ -707,7 +723,6 @@ class ClientBase: del parameters[key] def finalize(self, parameters: dict, prompt: str): - prompt = util.replace_special_tokens(prompt) for finalizer in self.finalizers: @@ -740,22 +755,20 @@ class ClientBase: "status", message="Error during generation (check logs)", status="error" ) return "" - - + def _generate_task(self, prompt: str, parameters: dict, kind: str): """ Creates an asyncio task to generate text from the given prompt and parameters. """ - + return asyncio.create_task(self.generate(prompt, parameters, kind)) - def _poll_interrupt(self): """ Creatates a task that continiously checks active_scene.cancel_requested and will complete the task if it is requested. """ - + async def poll(): while True: scene = active_scene.get() @@ -763,58 +776,58 @@ class ClientBase: break await asyncio.sleep(0.3) return GenerationCancelled("Generation cancelled") - + return asyncio.create_task(poll()) - - async def _cancelable_generate(self, prompt: str, parameters: dict, kind: str) -> str | GenerationCancelled: - + + async def _cancelable_generate( + self, prompt: str, parameters: dict, kind: str + ) -> str | GenerationCancelled: """ Queues the generation task and the poll task to be run concurrently. - + If the poll task completes before the generation task, the generation task will be cancelled. - + If the generation task completes before the poll task, the poll task will be cancelled. """ - + task_poll = self._poll_interrupt() task_generate = self._generate_task(prompt, parameters, kind) - + done, pending = await asyncio.wait( - [task_poll, task_generate], - return_when=asyncio.FIRST_COMPLETED + [task_poll, task_generate], return_when=asyncio.FIRST_COMPLETED ) - + # cancel the remaining task for task in pending: task.cancel() - + # return the result of the completed task return done.pop().result() - + async def abort_generation(self): """ This function can be overwritten to trigger an abortion at the other side of the client. - + So a generation is cancelled here, this can be used to cancel a generation at the other side of the client. """ pass - + def new_request(self): """ Creates a new request information object. """ self.request_information = RequestInformation() - + def end_request(self): """ Ends the request information object. """ self.request_information.end_time = time.time() - + def update_request_tokens(self, tokens: int, replace: bool = False): """ Updates the request information object with the number of tokens received. @@ -824,7 +837,7 @@ class ClientBase: self.request_information.tokens = tokens else: self.request_information.tokens += tokens - + async def send_prompt( self, prompt: str, @@ -837,13 +850,13 @@ class ClientBase: :param prompt: The text prompt to send. :return: The AI's response text. """ - + try: return await self._send_prompt(prompt, kind, finalize, retries) except GenerationCancelled: await self.abort_generation() raise - + async def _send_prompt( self, prompt: str, @@ -856,47 +869,49 @@ class ClientBase: :param prompt: The text prompt to send. :return: The AI's response text. """ - + try: if self.rate_limit_counter: - aborted:bool = False + aborted: bool = False while not self.rate_limit_counter.increment(): log.warn("Rate limit exceeded", client=self.name) emit( "rate_limited", - message="Rate limit exceeded", - status="error", + message="Rate limit exceeded", + status="error", websocket_passthrough=True, data={ "client": self.name, "rate_limit": self.rate_limit, "reset_time": self.rate_limit_counter.reset_time(), - } + }, ) - + scene = active_scene.get() if not scene or not scene.active or scene.cancel_requested: - log.info("Rate limit exceeded, generation cancelled", client=self.name) + log.info( + "Rate limit exceeded, generation cancelled", + client=self.name, + ) aborted = True break - + await asyncio.sleep(1) - + emit( "rate_limit_reset", message="Rate limit reset", status="info", websocket_passthrough=True, - data={"client": self.name} + data={"client": self.name}, ) - + if aborted: raise GenerationCancelled("Generation cancelled") except GenerationCancelled: raise - except Exception as e: + except Exception: log.error("Error during rate limit check", e=traceback.format_exc()) - if not active_scene.get(): log.error("SceneInactiveError", scene=active_scene.get()) @@ -940,25 +955,25 @@ class ClientBase: parameters=prompt_param, ) prompt_sent = self.repetition_adjustment(finalized_prompt) - + self.new_request() - + response = await self._cancelable_generate(prompt_sent, prompt_param, kind) - + self.end_request() - + if isinstance(response, GenerationCancelled): # generation was cancelled raise response - - #response = await self.generate(prompt_sent, prompt_param, kind) - + + # response = await self.generate(prompt_sent, prompt_param, kind) + response, finalized_prompt = await self.auto_break_repetition( finalized_prompt, prompt_param, response, kind, retries ) - + if REPLACE_SMART_QUOTES: - response = response.replace('“', '"').replace('”', '"') + response = response.replace("“", '"').replace("”", '"') time_end = time.time() @@ -992,9 +1007,9 @@ class ClientBase: ) return response - except GenerationCancelled as e: + except GenerationCancelled: raise - except Exception as e: + except Exception: self.log.error("send_prompt error", e=traceback.format_exc()) emit( "status", message="Error during generation (check logs)", status="error" @@ -1004,7 +1019,7 @@ class ClientBase: self.emit_status(processing=False) self._returned_prompt_tokens = None self._returned_response_tokens = None - + if self.rate_limit_counter: self.rate_limit_counter.increment() diff --git a/src/talemate/client/cohere.py b/src/talemate/client/cohere.py index 83872dd9..06fe26c4 100644 --- a/src/talemate/client/cohere.py +++ b/src/talemate/client/cohere.py @@ -2,7 +2,13 @@ import pydantic import structlog from cohere import AsyncClientV2 -from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults, ExtraField +from talemate.client.base import ( + ClientBase, + ErrorAction, + ParameterReroute, + CommonDefaults, + ExtraField, +) from talemate.client.registry import register from talemate.client.remote import ( EndpointOverride, @@ -51,7 +57,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase): auto_break_repetition_enabled = False decensor_enabled = True config_cls = ClientConfig - + class Meta(ClientBase.Meta): name_prefix: str = "Cohere" title: str = "Cohere" @@ -115,12 +121,12 @@ class CohereClient(EndpointOverrideMixin, ClientBase): self.current_status = status - data={ + data = { "error_action": error_action.model_dump() if error_action else None, "meta": self.Meta().model_dump(), "enabled": self.enabled, } - data.update(self._common_status_data()) + data.update(self._common_status_data()) emit( "client_status", message=self.client_type, @@ -171,7 +177,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase): if "enabled" in kwargs: self.enabled = bool(kwargs["enabled"]) - + self._reconfigure_common_parameters(**kwargs) self._reconfigure_endpoint_override(**kwargs) @@ -200,7 +206,6 @@ class CohereClient(EndpointOverrideMixin, ClientBase): return prompt def clean_prompt_parameters(self, parameters: dict): - super().clean_prompt_parameters(parameters) # if temperature is set, it needs to be clamped between 0 and 1.0 @@ -240,7 +245,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase): parameters=parameters, system_message=system_message, ) - + messages = [ { "role": "system", @@ -249,7 +254,7 @@ class CohereClient(EndpointOverrideMixin, ClientBase): { "role": "user", "content": human_message, - } + }, ] try: @@ -290,5 +295,5 @@ class CohereClient(EndpointOverrideMixin, ClientBase): # self.log.error("generate error", e=e) # emit("status", message="cohere API: Permission Denied", status="error") # return "" - except Exception as e: + except Exception: raise diff --git a/src/talemate/client/context.py b/src/talemate/client/context.py index 5d407c78..7d4ea447 100644 --- a/src/talemate/client/context.py +++ b/src/talemate/client/context.py @@ -23,7 +23,10 @@ def model_to_dict_without_defaults(model_instance): if field.default == model_dict.get(field_name): del model_dict[field_name] # special case for conversation context, dont copy if talking_character is None - if field_name == "conversation" and model_dict.get(field_name).get("talking_character") is None: + if ( + field_name == "conversation" + and model_dict.get(field_name).get("talking_character") is None + ): del model_dict[field_name] return model_dict @@ -102,7 +105,7 @@ class ClientContext: # Update the context data self.token = context_data.set(data) - + def __exit__(self, exc_type, exc_val, exc_tb): """ Reset the context variable `context_data` to its previous values when exiting the context. diff --git a/src/talemate/client/deepseek.py b/src/talemate/client/deepseek.py index b8523a6a..dfa536a5 100644 --- a/src/talemate/client/deepseek.py +++ b/src/talemate/client/deepseek.py @@ -1,8 +1,5 @@ -import json - import pydantic import structlog -import tiktoken from openai import AsyncOpenAI, PermissionDeniedError from talemate.client.base import ClientBase, ErrorAction, CommonDefaults @@ -103,12 +100,12 @@ class DeepSeekClient(ClientBase): self.current_status = status - data={ + data = { "error_action": error_action.model_dump() if error_action else None, "meta": self.Meta().model_dump(), "enabled": self.enabled, } - data.update(self._common_status_data()) + data.update(self._common_status_data()) emit( "client_status", message=self.client_type, @@ -273,5 +270,5 @@ class DeepSeekClient(ClientBase): self.log.error("generate error", e=e) emit("status", message="DeepSeek API: Permission Denied", status="error") return "" - except Exception as e: + except Exception: raise diff --git a/src/talemate/client/google.py b/src/talemate/client/google.py index 7175eaac..faebddc1 100644 --- a/src/talemate/client/google.py +++ b/src/talemate/client/google.py @@ -7,7 +7,13 @@ from google import genai import google.genai.types as genai_types from google.genai.errors import APIError -from talemate.client.base import ClientBase, ErrorAction, ExtraField, ParameterReroute, CommonDefaults +from talemate.client.base import ( + ClientBase, + ErrorAction, + ExtraField, + ParameterReroute, + CommonDefaults, +) from talemate.client.registry import register from talemate.client.remote import ( RemoteServiceMixin, @@ -41,12 +47,14 @@ SUPPORTED_MODELS = [ "gemini-2.5-pro-preview-06-05", ] + class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel): max_token_length: int = 16384 model: str = "gemini-2.0-flash" disable_safety_settings: bool = False double_coercion: str = None + class ClientConfig(EndpointOverride, BaseClientConfig): disable_safety_settings: bool = False @@ -80,7 +88,6 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): ), } extra_fields.update(endpoint_override_extra_fields()) - def __init__(self, model="gemini-2.0-flash", **kwargs): self.model_name = model @@ -114,24 +121,28 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): @property def google_location(self): return self.config.get("google").get("gcloud_location") - + @property def google_api_key(self): return self.config.get("google").get("api_key") @property def vertexai_ready(self) -> bool: - return all([ - self.google_credentials_path, - self.google_location, - ]) + return all( + [ + self.google_credentials_path, + self.google_location, + ] + ) @property def developer_api_ready(self) -> bool: - return all([ - self.google_api_key, - ]) - + return all( + [ + self.google_api_key, + ] + ) + @property def using(self) -> str: if self.developer_api_ready: @@ -143,7 +154,11 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): @property def ready(self): # all google settings must be set - return self.vertexai_ready or self.developer_api_ready or self.endpoint_override_base_url_configured + return ( + self.vertexai_ready + or self.developer_api_ready + or self.endpoint_override_base_url_configured + ) @property def safety_settings(self): @@ -179,10 +194,8 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): def http_options(self) -> genai_types.HttpOptions | None: if not self.endpoint_override_base_url_configured: return None - - return genai_types.HttpOptions( - base_url=self.base_url - ) + + return genai_types.HttpOptions(base_url=self.base_url) @property def supported_parameters(self): @@ -230,7 +243,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): "meta": self.Meta().model_dump(), "enabled": self.enabled, } - data.update(self._common_status_data()) + data.update(self._common_status_data()) self.populate_extra_fields(data) if self.using == "VertexAI": @@ -252,7 +265,9 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): try: self.client.http_options.base_url = base_url except Exception as e: - log.error("Error setting client base URL", error=e, client=self.client_type) + log.error( + "Error setting client base URL", error=e, client=self.client_type + ) def set_client(self, max_token_length: int = None, **kwargs): if not self.ready: @@ -283,7 +298,9 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): location=self.google_location, ) else: - self.client = genai.Client(api_key=self.api_key or None, http_options=self.http_options) + self.client = genai.Client( + api_key=self.api_key or None, http_options=self.http_options + ) log.info( "google set client", @@ -292,7 +309,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): model=model, ) - def response_tokens(self, response:str): + def response_tokens(self, response: str): """Return token count for a response which may be a string or SDK object.""" return count_tokens(response) @@ -309,10 +326,10 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): if "enabled" in kwargs: self.enabled = bool(kwargs["enabled"]) - + if "double_coercion" in kwargs: self.double_coercion = kwargs["double_coercion"] - + self._reconfigure_common_parameters(**kwargs) def clean_prompt_parameters(self, parameters: dict): @@ -329,7 +346,6 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): """ return prompt - async def generate(self, prompt: str, parameters: dict, kind: str): """ Generates text from the given prompt and parameters. @@ -342,7 +358,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): human_message = prompt.strip() system_message = self.get_system_message(kind) - + contents = [ genai_types.Content( role="user", @@ -350,10 +366,10 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): genai_types.Part.from_text( text=human_message, ) - ] + ], ) ] - + if coercion_prompt: contents.append( genai_types.Content( @@ -362,7 +378,7 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): genai_types.Part.from_text( text=coercion_prompt, ) - ] + ], ) ) @@ -378,24 +394,23 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): try: # Use streaming so we can update_Request_tokens incrementally - #stream = await chat.send_message_async( + # stream = await chat.send_message_async( # human_message, # safety_settings=self.safety_settings, # generation_config=parameters, # stream=True - #) + # ) - stream = await self.client.aio.models.generate_content_stream( model=self.model_name, contents=contents, config=genai_types.GenerateContentConfig( safety_settings=self.safety_settings, http_options=self.http_options, - **parameters + **parameters, ), ) - + response = "" async for chunk in stream: @@ -425,5 +440,5 @@ class GoogleClient(EndpointOverrideMixin, RemoteServiceMixin, ClientBase): self.log.error("generate error", e=e) emit("status", message="google API: API Error", status="error") return "" - except Exception as e: + except Exception: raise diff --git a/src/talemate/client/groq.py b/src/talemate/client/groq.py index 2db7cab6..b9e52569 100644 --- a/src/talemate/client/groq.py +++ b/src/talemate/client/groq.py @@ -260,5 +260,5 @@ class GroqClient(EndpointOverrideMixin, ClientBase): self.log.error("generate error", e=e) emit("status", message="OpenAI API: Permission Denied", status="error") return "" - except Exception as e: + except Exception: raise diff --git a/src/talemate/client/koboldcpp.py b/src/talemate/client/koboldcpp.py index 714f6dcd..32b1e539 100644 --- a/src/talemate/client/koboldcpp.py +++ b/src/talemate/client/koboldcpp.py @@ -17,7 +17,7 @@ from talemate.client.base import ( ClientBase, Defaults, ParameterReroute, - ClientEmbeddingsStatus + ClientEmbeddingsStatus, ) from talemate.client.registry import register import talemate.emit.async_signals as async_signals @@ -46,9 +46,13 @@ class KoboldEmbeddingFunction(EmbeddingFunction): """ Embed a list of input texts using the KoboldCPP embeddings endpoint. """ - - log.debug("KoboldCppEmbeddingFunction", api_url=self.api_url, model_name=self.model_name) - + + log.debug( + "KoboldCppEmbeddingFunction", + api_url=self.api_url, + model_name=self.model_name, + ) + # Prepare the request payload for KoboldCPP. Include model name if required. payload = {"input": texts} if self.model_name is not None: @@ -65,6 +69,8 @@ class KoboldEmbeddingFunction(EmbeddingFunction): embeddings = [item["embedding"] for item in embedding_results] return embeddings + + @register() class KoboldCppClient(ClientBase): auto_determine_prompt_template: bool = True @@ -147,7 +153,6 @@ class KoboldCppClient(ClientBase): talemate_parameter="stopping_strings", client_parameter="stop_sequence", ), - "xtc_threshold", "xtc_probability", "dry_multiplier", @@ -155,7 +160,6 @@ class KoboldCppClient(ClientBase): "dry_allowed_length", "dry_sequence_breakers", "smoothing_factor", - "temperature", ] @@ -172,18 +176,18 @@ class KoboldCppClient(ClientBase): @property def supports_embeddings(self) -> bool: return True - + @property def embeddings_url(self) -> str: if self.is_openai: return urljoin(self.api_url, "embeddings") else: return urljoin(self.api_url, "api/extra/embeddings") - + @property def embeddings_function(self): return KoboldEmbeddingFunction(self.embeddings_url, self.embeddings_model_name) - + def api_endpoint_specified(self, url: str) -> bool: return "/v1" in self.api_url @@ -208,10 +212,10 @@ class KoboldCppClient(ClientBase): # if self._embeddings_model_name is set, return it if self.embeddings_model_name: return self.embeddings_model_name - + # otherwise, get the model name by doing a request to # the embeddings endpoint with a single character - + async with httpx.AsyncClient() as client: response = await client.post( self.embeddings_url, @@ -219,37 +223,40 @@ class KoboldCppClient(ClientBase): timeout=2, headers=self.request_headers, ) - + response_data = response.json() self._embeddings_model_name = response_data.get("model") return self._embeddings_model_name - + async def get_embeddings_status(self): url_version = urljoin(self.api_url, "api/extra/version") async with httpx.AsyncClient() as client: response = await client.get(url_version, timeout=2) response_data = response.json() self._embeddings_status = response_data.get("embeddings", False) - + if not self.embeddings_status or self.embeddings_model_name: return - + await self.get_embeddings_model_name() - - log.debug("KoboldCpp embeddings are enabled, suggesting embeddings", model_name=self.embeddings_model_name) - + + log.debug( + "KoboldCpp embeddings are enabled, suggesting embeddings", + model_name=self.embeddings_model_name, + ) + self.set_embeddings() - + await async_signals.get("client.embeddings_available").send( ClientEmbeddingsStatus( client=self, embedding_name=self.embeddings_model_name, ) ) - + async def get_model_name(self): self.ensure_api_endpoint_specified() - + try: async with httpx.AsyncClient() as client: response = await client.get( @@ -275,7 +282,7 @@ class KoboldCppClient(ClientBase): # split by "/" and take last if model_name: model_name = model_name.split("/")[-1] - + await self.get_embeddings_status() return model_name @@ -309,7 +316,6 @@ class KoboldCppClient(ClientBase): tokencount = len(response.json().get("ids", [])) return tokencount - async def abort_generation(self): """ Trigger the stop generation endpoint @@ -317,7 +323,7 @@ class KoboldCppClient(ClientBase): if self.is_openai: # openai api endpoint doesn't support abort return - + parts = urlparse(self.api_url) url_abort = f"{parts.scheme}://{parts.netloc}/api/extra/abort" async with httpx.AsyncClient() as client: @@ -325,7 +331,7 @@ class KoboldCppClient(ClientBase): url_abort, headers=self.request_headers, ) - + async def generate(self, prompt: str, parameters: dict, kind: str): """ Generates text from the given prompt and parameters. @@ -334,14 +340,16 @@ class KoboldCppClient(ClientBase): return await self._generate_openai(prompt, parameters, kind) else: loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, self._generate_kcpp_stream, prompt, parameters, kind) - + return await loop.run_in_executor( + None, self._generate_kcpp_stream, prompt, parameters, kind + ) + def _generate_kcpp_stream(self, prompt: str, parameters: dict, kind: str): """ Generates text from the given prompt and parameters. """ parameters["prompt"] = prompt.strip(" ") - + response = "" parameters["stream"] = True stream_response = requests.post( @@ -352,15 +360,15 @@ class KoboldCppClient(ClientBase): stream=True, ) stream_response.raise_for_status() - + sse = sseclient.SSEClient(stream_response) - + for event in sse.events(): payload = json.loads(event.data) - chunk = payload['token'] + chunk = payload["token"] response += chunk self.update_request_tokens(self.count_tokens(chunk)) - + return response async def _generate_openai(self, prompt: str, parameters: dict, kind: str): @@ -447,7 +455,6 @@ class KoboldCppClient(ClientBase): sd_models_url = urljoin(self.url, "/sdapi/v1/sd-models") async with httpx.AsyncClient() as client: - try: response = await client.get(url=sd_models_url, timeout=2) except Exception as exc: diff --git a/src/talemate/client/lmstudio.py b/src/talemate/client/lmstudio.py index df2b450d..76682519 100644 --- a/src/talemate/client/lmstudio.py +++ b/src/talemate/client/lmstudio.py @@ -37,13 +37,13 @@ class LMStudioClient(ClientBase): def reconfigure(self, **kwargs): super().reconfigure(**kwargs) - + if self.client and self.client.base_url != self.api_url: self.set_client() async def get_model_name(self): model_name = await super().get_model_name() - + # model name comes back as a file path, so we need to extract the model name # the path could be windows or linux so it needs to handle both backslash and forward slash diff --git a/src/talemate/client/mistral.py b/src/talemate/client/mistral.py index f1f1f14c..11276db8 100644 --- a/src/talemate/client/mistral.py +++ b/src/talemate/client/mistral.py @@ -1,10 +1,14 @@ import pydantic import structlog -from typing import Literal from mistralai import Mistral from mistralai.models.sdkerror import SDKError -from talemate.client.base import ClientBase, ErrorAction, ParameterReroute, CommonDefaults, ExtraField +from talemate.client.base import ( + ClientBase, + ErrorAction, + CommonDefaults, + ExtraField, +) from talemate.client.registry import register from talemate.client.remote import ( EndpointOverride, @@ -44,9 +48,11 @@ class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel): max_token_length: int = 16384 model: str = "open-mixtral-8x22b" + class ClientConfig(EndpointOverride, BaseClientConfig): pass + @register() class MistralAIClient(EndpointOverrideMixin, ClientBase): """ @@ -68,7 +74,7 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase): requires_prompt_template: bool = False defaults: Defaults = Defaults() extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields() - + def __init__(self, model="open-mixtral-8x22b", **kwargs): self.model_name = model self.api_key_status = None @@ -115,7 +121,7 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase): model_name = "No model loaded" self.current_status = status - data={ + data = { "error_action": error_action.model_dump() if error_action else None, "meta": self.Meta().model_dump(), "enabled": self.enabled, @@ -167,10 +173,10 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase): def reconfigure(self, **kwargs): if "enabled" in kwargs: self.enabled = bool(kwargs["enabled"]) - + self._reconfigure_common_parameters(**kwargs) self._reconfigure_endpoint_override(**kwargs) - + if kwargs.get("model"): self.model_name = kwargs["model"] self.set_client(kwargs.get("max_token_length")) @@ -248,14 +254,16 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase): ) response = "" - + completion_tokens = 0 prompt_tokens = 0 async for event in event_stream: if event.data.choices: response += event.data.choices[0].delta.content - self.update_request_tokens(self.count_tokens(event.data.choices[0].delta.content)) + self.update_request_tokens( + self.count_tokens(event.data.choices[0].delta.content) + ) if event.data.usage: completion_tokens += event.data.usage.completion_tokens prompt_tokens += event.data.usage.prompt_tokens @@ -263,7 +271,7 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase): self._returned_prompt_tokens = prompt_tokens self._returned_response_tokens = completion_tokens - #response = response.choices[0].message.content + # response = response.choices[0].message.content # older models don't support json_object response coersion # and often like to return the response wrapped in ```json @@ -282,12 +290,12 @@ class MistralAIClient(EndpointOverrideMixin, ClientBase): return response except SDKError as e: self.log.error("generate error", e=e) - if hasattr(e, 'status_code') and e.status_code in [403, 401]: + if hasattr(e, "status_code") and e.status_code in [403, 401]: emit( "status", message="mistral.ai API: Permission Denied", status="error", ) return "" - except Exception as e: + except Exception: raise diff --git a/src/talemate/client/model_prompts.py b/src/talemate/client/model_prompts.py index edbd4a4c..042f1c9a 100644 --- a/src/talemate/client/model_prompts.py +++ b/src/talemate/client/model_prompts.py @@ -117,7 +117,6 @@ class ModelPrompt: prompt = f"{prompt}<|BOT|>" if "<|BOT|>" in prompt: - response_str = f"{double_coercion}{response_str}" if "\n<|BOT|>" in prompt: @@ -264,7 +263,11 @@ class Mistralv7TekkenIdentifier(TemplateIdentifier): template_str = "MistralV7Tekken" def __call__(self, content: str): - return "[SYSTEM_PROMPT]" in content and "[INST]" in content and "[/INST]" in content + return ( + "[SYSTEM_PROMPT]" in content + and "[INST]" in content + and "[/INST]" in content + ) @register_template_identifier diff --git a/src/talemate/client/ollama.py b/src/talemate/client/ollama.py index b104233c..3836c19f 100644 --- a/src/talemate/client/ollama.py +++ b/src/talemate/client/ollama.py @@ -1,11 +1,16 @@ -import asyncio import structlog import httpx import ollama import time -from typing import Union -from talemate.client.base import STOPPING_STRINGS, ClientBase, CommonDefaults, ErrorAction, ParameterReroute, ExtraField +from talemate.client.base import ( + STOPPING_STRINGS, + ClientBase, + CommonDefaults, + ErrorAction, + ParameterReroute, + ExtraField, +) from talemate.client.registry import register from talemate.config import Client as BaseClientConfig @@ -14,27 +19,30 @@ log = structlog.get_logger("talemate.client.ollama") FETCH_MODELS_INTERVAL = 15 + class OllamaClientDefaults(CommonDefaults): api_url: str = "http://localhost:11434" # Default Ollama URL model: str = "" # Allow empty default, will fetch from Ollama api_handles_prompt_template: bool = False allow_thinking: bool = False + class ClientConfig(BaseClientConfig): api_handles_prompt_template: bool = False allow_thinking: bool = False + @register() class OllamaClient(ClientBase): """ Ollama client for generating text using locally hosted models. """ - + auto_determine_prompt_template: bool = True client_type = "ollama" conversation_retries = 0 config_cls = ClientConfig - + class Meta(ClientBase.Meta): name_prefix: str = "Ollama" title: str = "Ollama" @@ -56,27 +64,26 @@ class OllamaClient(ClientBase): label="Allow thinking", required=False, description="Allow the model to think before responding. Talemate does not have a good way to deal with this yet, so it's recommended to leave this off.", - ) + ), } - + @property def supported_parameters(self): # Parameters supported by Ollama's generate endpoint # Based on the API documentation return [ "temperature", - "top_p", + "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", ParameterReroute( talemate_parameter="repetition_penalty", - client_parameter="repeat_penalty" + client_parameter="repeat_penalty", ), ParameterReroute( - talemate_parameter="max_tokens", - client_parameter="num_predict" + talemate_parameter="max_tokens", client_parameter="num_predict" ), "stopping_strings", # internal parameters that will be removed before sending @@ -98,7 +105,13 @@ class OllamaClient(ClientBase): """ return self.allow_thinking - def __init__(self, model=None, api_handles_prompt_template=False, allow_thinking=False, **kwargs): + def __init__( + self, + model=None, + api_handles_prompt_template=False, + allow_thinking=False, + **kwargs, + ): self.model_name = model self.api_handles_prompt_template = api_handles_prompt_template self.allow_thinking = allow_thinking @@ -114,16 +127,14 @@ class OllamaClient(ClientBase): # Update model if provided if kwargs.get("model"): self.model_name = kwargs["model"] - + # Create async client with the configured API URL # Ollama's AsyncClient expects just the base URL without any path self.client = ollama.AsyncClient(host=self.api_url) self.api_handles_prompt_template = kwargs.get( "api_handles_prompt_template", self.api_handles_prompt_template - ) - self.allow_thinking = kwargs.get( - "allow_thinking", self.allow_thinking ) + self.allow_thinking = kwargs.get("allow_thinking", self.allow_thinking) async def status(self): """ @@ -131,7 +142,7 @@ class OllamaClient(ClientBase): Raises an error if no model name is returned. :return: None """ - + if self.processing: self.emit_status() return @@ -140,7 +151,7 @@ class OllamaClient(ClientBase): self.connected = False self.emit_status() return - + try: # instead of using the client (which apparently cannot set a timeout per endpoint) # we use httpx to check {api_url}/api/version to see if the server is running @@ -148,7 +159,7 @@ class OllamaClient(ClientBase): async with httpx.AsyncClient() as client: response = await client.get(f"{self.api_url}/api/version", timeout=2) response.raise_for_status() - + # if the server is running, fetch the available models await self.fetch_available_models() except Exception as e: @@ -156,7 +167,7 @@ class OllamaClient(ClientBase): self.connected = False self.emit_status() return - + await super().status() async def fetch_available_models(self): @@ -165,7 +176,7 @@ class OllamaClient(ClientBase): """ if time.time() - self._models_last_fetched < FETCH_MODELS_INTERVAL: return self._available_models - + response = await self.client.list() models = response.get("models", []) model_names = [model.model for model in models] @@ -182,7 +193,7 @@ class OllamaClient(ClientBase): async def get_model_name(self): return self.model_name - + def prompt_template(self, system_message: str, prompt: str): if not self.api_handles_prompt_template: return super().prompt_template(system_message, prompt) @@ -201,10 +212,12 @@ class OllamaClient(ClientBase): Tune parameters for Ollama's generate endpoint. """ super().tune_prompt_parameters(parameters, kind) - + # Build stopping strings list - parameters["stop"] = STOPPING_STRINGS + parameters.get("extra_stopping_strings", []) - + parameters["stop"] = STOPPING_STRINGS + parameters.get( + "extra_stopping_strings", [] + ) + # Ollama uses num_predict instead of max_tokens if "max_tokens" in parameters: parameters["num_predict"] = parameters["max_tokens"] @@ -215,7 +228,7 @@ class OllamaClient(ClientBase): """ # First let parent class handle parameter reroutes and cleanup super().clean_prompt_parameters(parameters) - + # Remove our internal parameters if "extra_stopping_strings" in parameters: del parameters["extra_stopping_strings"] @@ -223,7 +236,7 @@ class OllamaClient(ClientBase): del parameters["stopping_strings"] if "stream" in parameters: del parameters["stream"] - + # Remove max_tokens as we've already converted it to num_predict if "max_tokens" in parameters: del parameters["max_tokens"] @@ -237,12 +250,12 @@ class OllamaClient(ClientBase): await self.get_model_name() if not self.model_name: raise Exception("No model specified or available in Ollama") - + # Prepare options for Ollama options = parameters - + options["num_ctx"] = self.max_token_length - + try: # Use generate endpoint for completion stream = await self.client.generate( @@ -253,22 +266,21 @@ class OllamaClient(ClientBase): think=self.can_think, stream=True, ) - + response = "" - + async for part in stream: content = part.response response += content self.update_request_tokens(self.count_tokens(content)) - + # Extract the response text return response - + except Exception as e: log.error("Ollama generation error", error=str(e), model=self.model_name) raise ErrorAction( - message=f"Ollama generation failed: {str(e)}", - title="Generation Error" + message=f"Ollama generation failed: {str(e)}", title="Generation Error" ) async def abort_generation(self): @@ -284,7 +296,7 @@ class OllamaClient(ClientBase): Adjusts temperature and repetition_penalty by random values. """ import random - + temp = prompt_config["temperature"] rep_pen = prompt_config.get("repetition_penalty", 1.0) @@ -302,12 +314,12 @@ class OllamaClient(ClientBase): # Handle model update if kwargs.get("model"): self.model_name = kwargs["model"] - + super().reconfigure(**kwargs) - + # Re-initialize client if API URL changed or model changed if "api_url" in kwargs or "model" in kwargs: self.set_client(**kwargs) - + if "api_handles_prompt_template" in kwargs: self.api_handles_prompt_template = kwargs["api_handles_prompt_template"] diff --git a/src/talemate/client/openai.py b/src/talemate/client/openai.py index 1491c6ca..81f28289 100644 --- a/src/talemate/client/openai.py +++ b/src/talemate/client/openai.py @@ -108,9 +108,11 @@ class Defaults(EndpointOverride, CommonDefaults, pydantic.BaseModel): max_token_length: int = 16384 model: str = "gpt-4o" + class ClientConfig(EndpointOverride, BaseClientConfig): pass + @register() class OpenAIClient(EndpointOverrideMixin, ClientBase): """ @@ -123,7 +125,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase): # TODO: make this configurable? decensor_enabled = False config_cls = ClientConfig - + class Meta(ClientBase.Meta): name_prefix: str = "OpenAI" title: str = "OpenAI" @@ -132,6 +134,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase): requires_prompt_template: bool = False defaults: Defaults = Defaults() extra_fields: dict[str, ExtraField] = endpoint_override_extra_fields() + def __init__(self, model="gpt-4o", **kwargs): self.model_name = model self.api_key_status = None @@ -181,12 +184,12 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase): self.current_status = status - data={ + data = { "error_action": error_action.model_dump() if error_action else None, "meta": self.Meta().model_dump(), "enabled": self.enabled, } - data.update(self._common_status_data()) + data.update(self._common_status_data()) emit( "client_status", @@ -305,32 +308,34 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase): human_message = {"role": "user", "content": prompt.strip()} system_message = {"role": "system", "content": self.get_system_message(kind)} - + # o1 and o3 models don't support system_message if "o1" in self.model_name or "o3" in self.model_name: - messages=[human_message] + messages = [human_message] # paramters need to be munged # `max_tokens` becomes `max_completion_tokens` if "max_tokens" in parameters: parameters["max_completion_tokens"] = parameters.pop("max_tokens") - + # temperature forced to 1 if "temperature" in parameters: - log.debug(f"{self.model_name} does not support temperature, forcing to 1") + log.debug( + f"{self.model_name} does not support temperature, forcing to 1" + ) parameters["temperature"] = 1 - + unsupported_params = [ "presence_penalty", "top_p", ] - + for param in unsupported_params: if param in parameters: log.debug(f"{self.model_name} does not support {param}, removing") parameters.pop(param) - + else: - messages=[system_message, human_message] + messages = [system_message, human_message] self.log.debug( "generate", @@ -346,7 +351,7 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase): stream=True, **parameters, ) - + response = "" # Iterate over streamed chunks @@ -359,9 +364,9 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase): response += content_piece # Incrementally track token usage self.update_request_tokens(self.count_tokens(content_piece)) - - #self._returned_prompt_tokens = self.prompt_tokens(prompt) - #self._returned_response_tokens = self.response_tokens(response) + + # self._returned_prompt_tokens = self.prompt_tokens(prompt) + # self._returned_response_tokens = self.response_tokens(response) # older models don't support json_object response coersion # and often like to return the response wrapped in ```json @@ -382,5 +387,5 @@ class OpenAIClient(EndpointOverrideMixin, ClientBase): self.log.error("generate error", e=e) emit("status", message="OpenAI API: Permission Denied", status="error") return "" - except Exception as e: + except Exception: raise diff --git a/src/talemate/client/openai_compat.py b/src/talemate/client/openai_compat.py index 85df8710..d86a0a88 100644 --- a/src/talemate/client/openai_compat.py +++ b/src/talemate/client/openai_compat.py @@ -1,9 +1,8 @@ import random -import urllib import pydantic import structlog -from openai import AsyncOpenAI, NotFoundError, PermissionDeniedError +from openai import AsyncOpenAI, PermissionDeniedError from talemate.client.base import ClientBase, ExtraField from talemate.client.registry import register @@ -93,7 +92,6 @@ class OpenAICompatibleClient(ClientBase): ) def prompt_template(self, system_message: str, prompt: str): - log.debug( "IS API HANDLING PROMPT TEMPLATE", api_handles_prompt_template=self.api_handles_prompt_template, @@ -130,7 +128,10 @@ class OpenAICompatibleClient(ClientBase): ) human_message = {"role": "user", "content": prompt.strip()} response = await self.client.chat.completions.create( - model=self.model_name, messages=[human_message], stream=False, **parameters + model=self.model_name, + messages=[human_message], + stream=False, + **parameters, ) response = response.choices[0].message.content return self.process_response_for_indirect_coercion(prompt, response) @@ -177,7 +178,7 @@ class OpenAICompatibleClient(ClientBase): if "double_coercion" in kwargs: self.double_coercion = kwargs["double_coercion"] - + if "rate_limit" in kwargs: self.rate_limit = kwargs["rate_limit"] diff --git a/src/talemate/client/openrouter.py b/src/talemate/client/openrouter.py index 7cafca58..1878630f 100644 --- a/src/talemate/client/openrouter.py +++ b/src/talemate/client/openrouter.py @@ -21,25 +21,25 @@ AVAILABLE_MODELS = [] DEFAULT_MODEL = "" MODELS_FETCHED = False + async def fetch_available_models(api_key: str = None): """Fetch available models from OpenRouter API""" global AVAILABLE_MODELS, DEFAULT_MODEL, MODELS_FETCHED - + if not api_key: return [] - + if MODELS_FETCHED: return AVAILABLE_MODELS - + # Only fetch if we haven't already or if explicitly requested if AVAILABLE_MODELS and not api_key: return AVAILABLE_MODELS - + try: async with httpx.AsyncClient() as client: response = await client.get( - "https://openrouter.ai/api/v1/models", - timeout=10.0 + "https://openrouter.ai/api/v1/models", timeout=10.0 ) if response.status_code == 200: data = response.json() @@ -51,23 +51,26 @@ async def fetch_available_models(api_key: str = None): AVAILABLE_MODELS = sorted(models) log.debug(f"Fetched {len(AVAILABLE_MODELS)} models from OpenRouter") else: - log.warning(f"Failed to fetch models from OpenRouter: {response.status_code}") + log.warning( + f"Failed to fetch models from OpenRouter: {response.status_code}" + ) except Exception as e: log.error(f"Error fetching models from OpenRouter: {e}") - + MODELS_FETCHED = True return AVAILABLE_MODELS - def fetch_models_sync(event): api_key = event.data.get("openrouter", {}).get("api_key") loop = asyncio.get_event_loop() loop.run_until_complete(fetch_available_models(api_key)) + handlers["config_saved"].connect(fetch_models_sync) handlers["talemate_started"].connect(fetch_models_sync) + class Defaults(CommonDefaults, pydantic.BaseModel): max_token_length: int = 16384 model: str = DEFAULT_MODEL @@ -89,7 +92,9 @@ class OpenRouterClient(ClientBase): name_prefix: str = "OpenRouter" title: str = "OpenRouter" manual_model: bool = True - manual_model_choices: list[str] = pydantic.Field(default_factory=lambda: AVAILABLE_MODELS) + manual_model_choices: list[str] = pydantic.Field( + default_factory=lambda: AVAILABLE_MODELS + ) requires_prompt_template: bool = False defaults: Defaults = Defaults() @@ -169,7 +174,7 @@ class OpenRouterClient(ClientBase): def set_client(self, max_token_length: int = None): # Unlike other clients, we don't need to set up a client instance # We'll use httpx directly in the generate method - + if not self.openrouter_api_key: log.error("No OpenRouter API key set") if self.api_key_status: @@ -183,7 +188,7 @@ class OpenRouterClient(ClientBase): if max_token_length and not isinstance(max_token_length, int): max_token_length = int(max_token_length) - + # Set max token length (default to 16k if not specified) self.max_token_length = max_token_length or 16384 @@ -221,7 +226,7 @@ class OpenRouterClient(ClientBase): self._models_fetched = True # Update the Meta class with new model choices self.Meta.manual_model_choices = AVAILABLE_MODELS - + self.emit_status() def prompt_template(self, system_message: str, prompt: str): @@ -240,13 +245,13 @@ class OpenRouterClient(ClientBase): raise Exception("No OpenRouter API key set") prompt, coercion_prompt = self.split_prompt_for_coercion(prompt) - + # Prepare messages for chat completion messages = [ {"role": "system", "content": self.get_system_message(kind)}, - {"role": "user", "content": prompt.strip()} + {"role": "user", "content": prompt.strip()}, ] - + if coercion_prompt: messages.append({"role": "assistant", "content": coercion_prompt.strip()}) @@ -255,7 +260,7 @@ class OpenRouterClient(ClientBase): "model": self.model_name, "messages": messages, "stream": True, - **parameters + **parameters, } self.log.debug( @@ -264,7 +269,7 @@ class OpenRouterClient(ClientBase): parameters=parameters, model=self.model_name, ) - + response_text = "" buffer = "" completion_tokens = 0 @@ -279,46 +284,52 @@ class OpenRouterClient(ClientBase): "Content-Type": "application/json", }, json=payload, - timeout=120.0 # 2 minute timeout for generation + timeout=120.0, # 2 minute timeout for generation ) as response: async for chunk in response.aiter_text(): buffer += chunk - + while True: # Find the next complete SSE line - line_end = buffer.find('\n') + line_end = buffer.find("\n") if line_end == -1: break - + line = buffer[:line_end].strip() - buffer = buffer[line_end + 1:] - - if line.startswith('data: '): + buffer = buffer[line_end + 1 :] + + if line.startswith("data: "): data = line[6:] - if data == '[DONE]': + if data == "[DONE]": break - + try: data_obj = json.loads(data) - content = data_obj["choices"][0]["delta"].get("content") + content = data_obj["choices"][0]["delta"].get( + "content" + ) usage = data_obj.get("usage", {}) - completion_tokens += usage.get("completion_tokens", 0) + completion_tokens += usage.get( + "completion_tokens", 0 + ) prompt_tokens += usage.get("prompt_tokens", 0) if content: response_text += content # Update tokens as content streams in - self.update_request_tokens(self.count_tokens(content)) - + self.update_request_tokens( + self.count_tokens(content) + ) + except json.JSONDecodeError: pass - + # Extract the response content response_content = response_text self._returned_prompt_tokens = prompt_tokens self._returned_response_tokens = completion_tokens - + return response_content - + except httpx.ConnectTimeout: self.log.error("OpenRouter API timeout") emit("status", message="OpenRouter API: Request timed out", status="error") @@ -326,4 +337,4 @@ class OpenRouterClient(ClientBase): except Exception as e: self.log.error("generate error", e=e) emit("status", message=f"OpenRouter API Error: {str(e)}", status="error") - raise \ No newline at end of file + raise diff --git a/src/talemate/client/presets.py b/src/talemate/client/presets.py index c58273f2..56d5b807 100644 --- a/src/talemate/client/presets.py +++ b/src/talemate/client/presets.py @@ -16,11 +16,6 @@ __all__ = [ "preset_for_kind", "make_kind", "max_tokens_for_kind", - "PRESET_TALEMATE_CONVERSATION", - "PRESET_TALEMATE_CREATOR", - "PRESET_LLAMA_PRECISE", - "PRESET_DIVINE_INTELLECT", - "PRESET_SIMPLE_1", ] log = structlog.get_logger("talemate.client.presets") @@ -42,27 +37,31 @@ def sync_config(event): ) CONFIG["inference_groups"] = { group: InferencePresetGroup(**data) - for group, data in event.data.get("presets", {}).get("inference_groups", {}).items() + for group, data in event.data.get("presets", {}) + .get("inference_groups", {}) + .items() } handlers["config_saved"].connect(sync_config) -def get_inference_parameters(preset_name: str, group:str|None = None) -> dict: +def get_inference_parameters(preset_name: str, group: str | None = None) -> dict: """ Returns the inference parameters for the given preset name. """ - + presets = CONFIG["inference"].model_dump() - + if group: try: group_presets = CONFIG["inference_groups"].get(group).model_dump() presets.update(group_presets["presets"]) except AttributeError: - log.warning(f"Invalid preset group referenced: {group}. Falling back to defaults.") - + log.warning( + f"Invalid preset group referenced: {group}. Falling back to defaults." + ) + if preset_name in presets: return presets[preset_name] @@ -145,7 +144,7 @@ def preset_for_kind(kind: str, client: "ClientBase") -> dict: presets=CONFIG["inference"], ) preset_name = "scene_direction" - + set_client_context_attribute("inference_preset", preset_name) return get_inference_parameters(preset_name, client.preset_group) @@ -197,29 +196,26 @@ def max_tokens_for_kind(kind: str, total_budget: int) -> int: return value if token_value is not None: return token_value - + # finally check if splitting last item off of _ is a number, and then just # return that number kind_split = kind.split("_")[-1] if kind_split.isdigit(): return int(kind_split) - + return 150 # Default value if none of the kinds match - -def make_kind(action_type: str, length: int, expect_json:bool=False) -> str: +def make_kind(action_type: str, length: int, expect_json: bool = False) -> str: """ Creates a kind string based on the preset_arch_type and length. """ if action_type == "analyze" and not expect_json: - kind = f"investigate" + kind = "investigate" else: kind = action_type - + kind = f"{kind}_{length}" - - return kind - - \ No newline at end of file + + return kind diff --git a/src/talemate/client/ratelimit.py b/src/talemate/client/ratelimit.py index 804c07a5..1f7ca904 100644 --- a/src/talemate/client/ratelimit.py +++ b/src/talemate/client/ratelimit.py @@ -1,30 +1,31 @@ from limits.strategies import MovingWindowRateLimiter from limits.storage import MemoryStorage -from limits import parse, RateLimitItemPerMinute +from limits import RateLimitItemPerMinute import time __all__ = ["CounterRateLimiter"] + class CounterRateLimiter: - def __init__(self, rate_per_minute:int=99999, identifier:str="ratelimit"): + def __init__(self, rate_per_minute: int = 99999, identifier: str = "ratelimit"): self.storage = MemoryStorage() self.limiter = MovingWindowRateLimiter(self.storage) self.rate = RateLimitItemPerMinute(rate_per_minute, 1) self.identifier = identifier - - def update_rate_limit(self, rate_per_minute:int): + + def update_rate_limit(self, rate_per_minute: int): """Update the rate limit with a new value""" self.rate = RateLimitItemPerMinute(rate_per_minute, 1) - + def increment(self) -> bool: limiter = self.limiter rate = self.rate return limiter.hit(rate, self.identifier) - + def reset_time(self) -> float: """ Returns the time in seconds until the rate limit is reset """ - + window = self.limiter.get_window_stats(self.rate, self.identifier) - return window.reset_time - time.time() \ No newline at end of file + return window.reset_time - time.time() diff --git a/src/talemate/client/remote.py b/src/talemate/client/remote.py index 2323c69b..ba9ff2ca 100644 --- a/src/talemate/client/remote.py +++ b/src/talemate/client/remote.py @@ -18,56 +18,68 @@ __all__ = [ "EndpointOverrideAPIKeyField", ] + def endpoint_override_extra_fields(): return { "override_base_url": EndpointOverrideBaseURLField(), "override_api_key": EndpointOverrideAPIKeyField(), } + class EndpointOverride(pydantic.BaseModel): override_base_url: str | None = None override_api_key: str | None = None - + + class EndpointOverrideGroup(FieldGroup): name: str = "endpoint_override" label: str = "Endpoint Override" - description: str = ("Override the default base URL used by this client to access the {client_type} service API.\n\n" - "IMPORTANT: Provide an override only if you fully trust the endpoint. When set, the {client_type} API key defined in the global application settings is deliberately ignored to avoid accidental credential leakage. " - "If the override endpoint requires an API key, enter it below.") + description: str = ( + "Override the default base URL used by this client to access the {client_type} service API.\n\n" + "IMPORTANT: Provide an override only if you fully trust the endpoint. When set, the {client_type} API key defined in the global application settings is deliberately ignored to avoid accidental credential leakage. " + "If the override endpoint requires an API key, enter it below." + ) icon: str = "mdi-api" + class EndpointOverrideField(ExtraField): group: EndpointOverrideGroup = pydantic.Field(default_factory=EndpointOverrideGroup) - + + class EndpointOverrideBaseURLField(EndpointOverrideField): name: str = "override_base_url" type: str = "text" label: str = "Base URL" required: bool = False description: str = "Override the base URL for the remote service" - + + class EndpointOverrideAPIKeyField(EndpointOverrideField): name: str = "override_api_key" type: str = "password" label: str = "API Key" required: bool = False description: str = "Override the API key for the remote service" - note: ux_schema.Note = pydantic.Field(default_factory=lambda: ux_schema.Note( - text="This is NOT the API key for the official {client_type} API. It is only used when overriding the base URL. The official {client_type} API key can be configured in the application settings.", - color="warning", - )) + note: ux_schema.Note = pydantic.Field( + default_factory=lambda: ux_schema.Note( + text="This is NOT the API key for the official {client_type} API. It is only used when overriding the base URL. The official {client_type} API key can be configured in the application settings.", + color="warning", + ) + ) class EndpointOverrideMixin: override_base_url: str | None = None override_api_key: str | None = None - + def set_client_api_key(self, api_key: str | None): if getattr(self, "client", None): try: self.client.api_key = api_key except Exception as e: - log.error("Error setting client API key", error=e, client=self.client_type) + log.error( + "Error setting client API key", error=e, client=self.client_type + ) @property def api_key(self) -> str | None: @@ -84,14 +96,17 @@ class EndpointOverrideMixin: @property def endpoint_override_base_url_configured(self) -> bool: return self.override_base_url and self.override_base_url.strip() - + @property def endpoint_override_api_key_configured(self) -> bool: return self.override_api_key and self.override_api_key.strip() - + @property def endpoint_override_fully_configured(self) -> bool: - return self.endpoint_override_base_url_configured and self.endpoint_override_api_key_configured + return ( + self.endpoint_override_base_url_configured + and self.endpoint_override_api_key_configured + ) def _reconfigure_endpoint_override(self, **kwargs): if "override_base_url" in kwargs: @@ -100,13 +115,13 @@ class EndpointOverrideMixin: if getattr(self, "client", None) and orig != self.override_base_url: log.info("Reconfiguring client base URL", new=self.override_base_url) self.set_client(kwargs.get("max_token_length")) - + if "override_api_key" in kwargs: self.override_api_key = kwargs["override_api_key"] self.set_client_api_key(self.override_api_key) -class RemoteServiceMixin: +class RemoteServiceMixin: def prompt_template(self, system_message: str, prompt: str): if "<|BOT|>" in prompt: _, right = prompt.split("<|BOT|>", 1) diff --git a/src/talemate/client/runpod.py b/src/talemate/client/runpod.py index 9cd64a21..b3109f8e 100644 --- a/src/talemate/client/runpod.py +++ b/src/talemate/client/runpod.py @@ -4,8 +4,6 @@ connection for the pod. This is a simple wrapper around the runpod module. """ import asyncio -import json -import os import dotenv import runpod diff --git a/src/talemate/client/system_prompts.py b/src/talemate/client/system_prompts.py index 956866fa..bfee700c 100644 --- a/src/talemate/client/system_prompts.py +++ b/src/talemate/client/system_prompts.py @@ -27,7 +27,6 @@ PROMPT_TEMPLATE_MAP = { "world_state": "world_state.system-analyst-no-decensor", "summarize": "summarizer.system-no-decensor", "visualize": "visual.system-no-decensor", - # contains some minor attempts at keeping the LLM from generating # refusals to generate certain types of content "roleplay_decensor": "conversation.system", @@ -42,32 +41,39 @@ PROMPT_TEMPLATE_MAP = { "visualize_decensor": "visual.system", } + def cache_all() -> dict: for key in PROMPT_TEMPLATE_MAP: render_prompt(key) return RENDER_CACHE.copy() -def render_prompt(kind:str, decensor:bool=False): + +def render_prompt(kind: str, decensor: bool = False): # work around circular import issue # TODO: refactor to avoid circular import from talemate.prompts import Prompt - + if kind not in PROMPT_TEMPLATE_MAP: - log.warning(f"Invalid prompt system prompt identifier: {kind} - decensor: {decensor}") + log.warning( + f"Invalid prompt system prompt identifier: {kind} - decensor: {decensor}" + ) return "" - + if decensor: key = f"{kind}_decensor" else: key = kind - + if key not in PROMPT_TEMPLATE_MAP: - log.warning(f"Invalid prompt system prompt identifier: {kind} - decensor: {decensor}", key=key) + log.warning( + f"Invalid prompt system prompt identifier: {kind} - decensor: {decensor}", + key=key, + ) return "" - + if key in RENDER_CACHE: return RENDER_CACHE[key] - + prompt = str(Prompt.get(PROMPT_TEMPLATE_MAP[key])) RENDER_CACHE[key] = prompt @@ -77,17 +83,17 @@ def render_prompt(kind:str, decensor:bool=False): class SystemPrompts(pydantic.BaseModel): """ System prompts and a normalized the way to access them. - + Allows specification of a parent "SystemPrompts" instance that will be used as a fallback, and if not so specified, will default to the system prompts in the globals via lambda functions that render the templates. - + The globals that exist now will be deprecated in favor of this later. """ - + parent: "SystemPrompts | None" = pydantic.Field(default=None, exclude=True) - + roleplay: str | None = None narrator: str | None = None creator: str | None = None @@ -98,7 +104,7 @@ class SystemPrompts(pydantic.BaseModel): world_state: str | None = None summarize: str | None = None visualize: str | None = None - + roleplay_decensor: str | None = None narrator_decensor: str | None = None creator_decensor: str | None = None @@ -109,64 +115,61 @@ class SystemPrompts(pydantic.BaseModel): world_state_decensor: str | None = None summarize_decensor: str | None = None visualize_decensor: str | None = None - + class Config: exclude_none = True exclude_unset = True - + @property def defaults(self) -> dict: return RENDER_CACHE.copy() - - def alias(self, alias:str) -> str: - + + def alias(self, alias: str) -> str: if alias in PROMPT_TEMPLATE_MAP: return alias if "narrate" in alias: return "narrator" - + if "direction" in alias or "director" in alias: return "director" - + if "create" in alias or "creative" in alias: return "creator" - + if "conversation" in alias or "roleplay" in alias: return "roleplay" - + if "basic" in alias: return "basic" - + if "edit" in alias: return "editor" - + if "world_state" in alias: return "world_state" - + if "analyze_freeform" in alias or "investigate" in alias: return "analyst_freeform" - + if "analyze" in alias or "analyst" in alias or "analytical" in alias: return "analyst" - + if "summarize" in alias or "summarization" in alias: return "summarize" - + if "visual" in alias: return "visualize" - + return alias - - - def get(self, kind:str, decensor:bool=False) -> str: - + + def get(self, kind: str, decensor: bool = False) -> str: kind = self.alias(kind) - + key = f"{kind}_decensor" if decensor else kind - + if getattr(self, key): return getattr(self, key) if self.parent is not None: return self.parent.get(kind, decensor) - return render_prompt(kind, decensor) \ No newline at end of file + return render_prompt(kind, decensor) diff --git a/src/talemate/client/tabbyapi.py b/src/talemate/client/tabbyapi.py index 2a778fc3..65bc910e 100644 --- a/src/talemate/client/tabbyapi.py +++ b/src/talemate/client/tabbyapi.py @@ -1,5 +1,4 @@ import random -from typing import Literal import json import httpx import pydantic @@ -103,7 +102,6 @@ class TabbyAPIClient(ClientBase): ) def prompt_template(self, system_message: str, prompt: str): - log.debug( "IS API HANDLING PROMPT TEMPLATE", api_handles_prompt_template=self.api_handles_prompt_template, @@ -196,22 +194,18 @@ class TabbyAPIClient(ClientBase): async with httpx.AsyncClient() as client: async with client.stream( - "POST", - url, - headers=headers, - json=payload, - timeout=120.0 + "POST", url, headers=headers, json=payload, timeout=120.0 ) as response: async for chunk in response.aiter_text(): buffer += chunk while True: - line_end = buffer.find('\n') + line_end = buffer.find("\n") if line_end == -1: break line = buffer[:line_end].strip() - buffer = buffer[line_end + 1:] + buffer = buffer[line_end + 1 :] if not line: continue @@ -226,7 +220,7 @@ class TabbyAPIClient(ClientBase): choice = data_obj.get("choices", [{}])[0] - # Chat completions use delta -> content. + # Chat completions use delta -> content. delta = choice.get("delta", {}) content = ( delta.get("content") @@ -235,12 +229,16 @@ class TabbyAPIClient(ClientBase): ) usage = data_obj.get("usage", {}) - completion_tokens = usage.get("completion_tokens", 0) + completion_tokens = usage.get( + "completion_tokens", 0 + ) prompt_tokens = usage.get("prompt_tokens", 0) if content: response_text += content - self.update_request_tokens(self.count_tokens(content)) + self.update_request_tokens( + self.count_tokens(content) + ) except json.JSONDecodeError: # ignore malformed json chunks pass @@ -251,7 +249,9 @@ class TabbyAPIClient(ClientBase): if is_chat: # Process indirect coercion - response_text = self.process_response_for_indirect_coercion(prompt, response_text) + response_text = self.process_response_for_indirect_coercion( + prompt, response_text + ) return response_text @@ -265,7 +265,9 @@ class TabbyAPIClient(ClientBase): return "" except Exception as e: self.log.error("generate error", e=e) - emit("status", message="Error during generation (check logs)", status="error") + emit( + "status", message="Error during generation (check logs)", status="error" + ) return "" def reconfigure(self, **kwargs): @@ -287,7 +289,7 @@ class TabbyAPIClient(ClientBase): self.double_coercion = kwargs["double_coercion"] self._reconfigure_common_parameters(**kwargs) - + self.set_client(**kwargs) def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict: diff --git a/src/talemate/client/textgenwebui.py b/src/talemate/client/textgenwebui.py index eb89a681..94c0a053 100644 --- a/src/talemate/client/textgenwebui.py +++ b/src/talemate/client/textgenwebui.py @@ -8,7 +8,7 @@ import httpx import structlog from openai import AsyncOpenAI -from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults, ExtraField +from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults from talemate.client.registry import register log = structlog.get_logger("talemate.client.textgenwebui") @@ -103,7 +103,6 @@ class TextGeneratorWebuiClient(ClientBase): self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111") def finalize_llama3(self, parameters: dict, prompt: str) -> tuple[str, bool]: - if "<|eot_id|>" not in prompt: return prompt, False @@ -122,7 +121,6 @@ class TextGeneratorWebuiClient(ClientBase): return prompt, True def finalize_YI(self, parameters: dict, prompt: str) -> tuple[str, bool]: - if not self.model_name: return prompt, False @@ -141,7 +139,6 @@ class TextGeneratorWebuiClient(ClientBase): return prompt, True async def get_model_name(self): - async with httpx.AsyncClient() as client: response = await client.get( f"{self.api_url}/v1/internal/model/info", @@ -170,14 +167,16 @@ class TextGeneratorWebuiClient(ClientBase): async def generate(self, prompt: str, parameters: dict, kind: str): loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, self._generate, prompt, parameters, kind) + return await loop.run_in_executor( + None, self._generate, prompt, parameters, kind + ) def _generate(self, prompt: str, parameters: dict, kind: str): """ Generates text from the given prompt and parameters. """ parameters["prompt"] = prompt.strip(" ") - + response = "" parameters["stream"] = True stream_response = requests.post( @@ -188,17 +187,16 @@ class TextGeneratorWebuiClient(ClientBase): stream=True, ) stream_response.raise_for_status() - + sse = sseclient.SSEClient(stream_response) - + for event in sse.events(): payload = json.loads(event.data) - chunk = payload['choices'][0]['text'] + chunk = payload["choices"][0]["text"] response += chunk self.update_request_tokens(self.count_tokens(chunk)) - - return response + return response def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict: """ diff --git a/src/talemate/client/utils.py b/src/talemate/client/utils.py index f6614ae4..25e0fb5e 100644 --- a/src/talemate/client/utils.py +++ b/src/talemate/client/utils.py @@ -1,5 +1,3 @@ -import copy -import random from urllib.parse import urljoin as _urljoin __all__ = ["urljoin"] diff --git a/src/talemate/commands/__init__.py b/src/talemate/commands/__init__.py index 65439f5d..4465f439 100644 --- a/src/talemate/commands/__init__.py +++ b/src/talemate/commands/__init__.py @@ -1,14 +1,27 @@ -from .base import TalemateCommand -from .cmd_characters import * -from .cmd_debug_tools import * -from .cmd_rebuild_archive import CmdRebuildArchive -from .cmd_rename import CmdRename -from .cmd_regenerate import * -from .cmd_reset import CmdReset -from .cmd_save import CmdSave -from .cmd_save_as import CmdSaveAs -from .cmd_setenv import CmdSetEnvironmentToCreative, CmdSetEnvironmentToScene -from .cmd_time_util import * -from .cmd_tts import * -from .cmd_world_state import * -from .manager import Manager +from .base import TalemateCommand # noqa: F401 +from .cmd_characters import CmdActivateCharacter, CmdDeactivateCharacter # noqa: F401 +from .cmd_debug_tools import ( + CmdPromptChangeSectioning, # noqa: F401 + CmdSummarizerUpdateLayeredHistory, # noqa: F401 + CmdSummarizerResetLayeredHistory, # noqa: F401 + CmdSummarizerContextInvestigation, # noqa: F401 +) +from .cmd_rebuild_archive import CmdRebuildArchive # noqa: F401 +from .cmd_rename import CmdRename # noqa: F401 +from .cmd_regenerate import CmdRegenerate # noqa: F401 +from .cmd_reset import CmdReset # noqa: F401 +from .cmd_save import CmdSave # noqa: F401 +from .cmd_save_as import CmdSaveAs # noqa: F401 +from .cmd_setenv import CmdSetEnvironmentToCreative, CmdSetEnvironmentToScene # noqa: F401 +from .cmd_time_util import CmdAdvanceTime # noqa: F401 +from .cmd_tts import CmdTestTTS # noqa: F401 +from .cmd_world_state import ( + CmdAddReinforcement, # noqa: F401 + CmdApplyWorldStateTemplate, # noqa: F401 + CmdCheckPinConditions, # noqa: F401 + CmdDetermineCharacterDevelopment, # noqa: F401 + CmdRemoveReinforcement, # noqa: F401 + CmdSummarizeAndPin, # noqa: F401 + CmdUpdateReinforcements, # noqa: F401 +) +from .manager import Manager # noqa: F401 diff --git a/src/talemate/commands/cmd_debug_tools.py b/src/talemate/commands/cmd_debug_tools.py index c6b90550..1853e1cb 100644 --- a/src/talemate/commands/cmd_debug_tools.py +++ b/src/talemate/commands/cmd_debug_tools.py @@ -1,6 +1,4 @@ import asyncio -import json -import logging import structlog @@ -19,7 +17,6 @@ __all__ = [ log = structlog.get_logger("talemate.commands.cmd_debug_tools") - @register class CmdPromptChangeSectioning(TalemateCommand): """ @@ -95,7 +92,8 @@ class CmdSummarizerUpdateLayeredHistory(TalemateCommand): summarizer = get_agent("summarizer") await summarizer.summarize_to_layered_history() - + + @register class CmdSummarizerResetLayeredHistory(TalemateCommand): """ @@ -108,16 +106,17 @@ class CmdSummarizerResetLayeredHistory(TalemateCommand): async def run(self): summarizer = get_agent("summarizer") - + # if arg is provided remove the last n layers if self.args: n = int(self.args[0]) self.scene.layered_history = self.scene.layered_history[:-n] else: self.scene.layered_history = [] - + await summarizer.summarize_to_layered_history() - + + @register class CmdSummarizerContextInvestigation(TalemateCommand): """ @@ -135,11 +134,10 @@ class CmdSummarizerContextInvestigation(TalemateCommand): if not self.args: self.emit("system", "You must specify a query") return - + await summarizer.request_context_investigations(self.args[0], max_calls=1) - - - + + @register class CmdMemoryCompareStrings(TalemateCommand): """ @@ -156,13 +154,14 @@ class CmdMemoryCompareStrings(TalemateCommand): if not self.args: self.emit("system", "You must specify two strings to compare") return - + string1 = self.args[0] string2 = self.args[1] result = await memory.compare_strings(string1, string2) self.emit("system", f"The strings are {result['cosine_similarity']} similar") - + + @register class CmdRunEditorRevision(TalemateCommand): """ @@ -172,14 +171,13 @@ class CmdRunEditorRevision(TalemateCommand): name = "run_editor_revision" description = "Run the editor revision" aliases = ["run_revision"] - + async def run(self): editor = get_agent("editor") scene = self.scene - + last_message = scene.history[-1] - + result = await editor.revision_detect_bad_prose(str(last_message)) - + self.emit("system", f"Result: {result}") - diff --git a/src/talemate/commands/cmd_rebuild_archive.py b/src/talemate/commands/cmd_rebuild_archive.py index 0bf04833..0182c16f 100644 --- a/src/talemate/commands/cmd_rebuild_archive.py +++ b/src/talemate/commands/cmd_rebuild_archive.py @@ -1,5 +1,3 @@ -import asyncio - from talemate.commands.base import TalemateCommand from talemate.commands.manager import register from talemate.emit import emit @@ -29,7 +27,7 @@ class CmdRebuildArchive(TalemateCommand): ] self.scene.ts = "PT0S" - + memory.delete({"typ": "history"}) entries = 0 @@ -42,9 +40,9 @@ class CmdRebuildArchive(TalemateCommand): ) more = await summarizer.agent.build_archive(self.scene) self.scene.sync_time() - + entries += 1 - + if not more: break diff --git a/src/talemate/commands/cmd_reset.py b/src/talemate/commands/cmd_reset.py index 1cb505ec..0d6af779 100644 --- a/src/talemate/commands/cmd_reset.py +++ b/src/talemate/commands/cmd_reset.py @@ -1,6 +1,6 @@ from talemate.commands.base import TalemateCommand from talemate.commands.manager import register -from talemate.emit import emit, wait_for_input, wait_for_input_yesno +from talemate.emit import wait_for_input_yesno from talemate.exceptions import ResetScene diff --git a/src/talemate/commands/cmd_save_as.py b/src/talemate/commands/cmd_save_as.py index bd227482..53c7479d 100644 --- a/src/talemate/commands/cmd_save_as.py +++ b/src/talemate/commands/cmd_save_as.py @@ -1,5 +1,3 @@ -import asyncio - from talemate.commands.base import TalemateCommand from talemate.commands.manager import register diff --git a/src/talemate/commands/cmd_setenv.py b/src/talemate/commands/cmd_setenv.py index ce903c5a..b7f14a30 100644 --- a/src/talemate/commands/cmd_setenv.py +++ b/src/talemate/commands/cmd_setenv.py @@ -22,10 +22,13 @@ class CmdSetEnvironmentToScene(TalemateCommand): player_character = self.scene.get_player_character() if not player_character: - self.system_message("No characters found - cannot switch to gameplay mode.", meta={ - "icon": "mdi-alert", - "color": "warning", - }) + self.system_message( + "No characters found - cannot switch to gameplay mode.", + meta={ + "icon": "mdi-alert", + "color": "warning", + }, + ) return True self.scene.set_environment("scene") diff --git a/src/talemate/commands/cmd_time_util.py b/src/talemate/commands/cmd_time_util.py index 3fe62c35..2f8ea151 100644 --- a/src/talemate/commands/cmd_time_util.py +++ b/src/talemate/commands/cmd_time_util.py @@ -2,11 +2,6 @@ Commands to manage scene timescale """ -import asyncio -import logging - -import isodate - import talemate.instance as instance from talemate.commands.base import TalemateCommand from talemate.commands.manager import register diff --git a/src/talemate/commands/cmd_tts.py b/src/talemate/commands/cmd_tts.py index d16ba923..87851f79 100644 --- a/src/talemate/commands/cmd_tts.py +++ b/src/talemate/commands/cmd_tts.py @@ -1,10 +1,6 @@ -import asyncio -import logging - from talemate.commands.base import TalemateCommand from talemate.commands.manager import register from talemate.instance import get_agent -from talemate.prompts.base import set_default_sectioning_handler __all__ = [ "CmdTestTTS", diff --git a/src/talemate/commands/cmd_world_state.py b/src/talemate/commands/cmd_world_state.py index 729d2ca4..8e9251f6 100644 --- a/src/talemate/commands/cmd_world_state.py +++ b/src/talemate/commands/cmd_world_state.py @@ -109,8 +109,6 @@ class CmdUpdateReinforcements(TalemateCommand): aliases = ["ws_ur"] async def run(self): - scene = self.scene - world_state = get_agent("world_state") await world_state.update_reinforcements(force=True) @@ -234,18 +232,16 @@ class CmdDetermineCharacterDevelopment(TalemateCommand): scene = self.scene world_state = get_agent("world_state") - creator = get_agent("creator") if not len(self.args): raise ValueError("No character name provided.") character_name = self.args[0] - + character = scene.get_character(character_name) - + if not character: raise ValueError(f"Character {character_name} not found.") - instructions = await world_state.determine_character_development(character) - - # updates = await creator.update_character_sheet(character, instructions) \ No newline at end of file + await world_state.determine_character_development(character) + # updates = await creator.update_character_sheet(character, instructions) diff --git a/src/talemate/commands/manager.py b/src/talemate/commands/manager.py index 3b3192b6..a8306248 100644 --- a/src/talemate/commands/manager.py +++ b/src/talemate/commands/manager.py @@ -36,7 +36,7 @@ class Manager(Emitter): aliases[alias] = name.replace("cmd_", "") return aliases - async def execute(self, cmd, emit_on_unknown:bool = True, state:dict = None): + async def execute(self, cmd, emit_on_unknown: bool = True, state: dict = None): # commands start with ! and are followed by a command name cmd = cmd.strip() cmd_args = "" @@ -56,7 +56,6 @@ class Manager(Emitter): for command_cls in self.command_classes: if command_cls.is_command(cmd_name): - if command_cls.argument_cls: cmd_kwargs = json.loads(cmd_args_unsplit) cmd_args = [] diff --git a/src/talemate/config.py b/src/talemate/config.py index b3bd4258..df2db8d5 100644 --- a/src/talemate/config.py +++ b/src/talemate/config.py @@ -1,4 +1,3 @@ -import copy import datetime import os from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, TypeVar, Union, Literal @@ -6,8 +5,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, TypeVar, Union, import pydantic import structlog import yaml -from enum import Enum -from pydantic import BaseModel, Field +from pydantic import BaseModel from typing_extensions import Annotated from talemate.agents.registry import get_agent_class @@ -43,7 +41,7 @@ class Client(BaseModel): rate_limit: Union[int, None] = None data_format: Literal["json", "yaml"] | None = None enabled: bool = True - + system_prompts: SystemPrompts = SystemPrompts() preset_group: str | None = None @@ -165,6 +163,7 @@ class DeepSeekConfig(BaseModel): class OpenRouterConfig(BaseModel): api_key: Union[str, None] = None + class RunPodConfig(BaseModel): api_key: Union[str, None] = None @@ -202,6 +201,7 @@ class RecentScene(BaseModel): date: str cover_image: Union[Asset, None] = None + class EmbeddingFunctionPreset(BaseModel): embeddings: str = "sentence-transformer" model: str = "all-MiniLM-L6-v2" @@ -217,12 +217,11 @@ class EmbeddingFunctionPreset(BaseModel): client: str | None = None - def generate_chromadb_presets() -> dict[str, EmbeddingFunctionPreset]: """ Returns a dict of default embedding presets """ - + return { "default": EmbeddingFunctionPreset(), "Alibaba-NLP/gte-base-en-v1.5": EmbeddingFunctionPreset( @@ -274,25 +273,24 @@ class InferenceParameters(BaseModel): frequency_penalty: float | None = 0.05 repetition_penalty: float | None = 1.0 repetition_penalty_range: int | None = 1024 - + xtc_threshold: float | None = 0.1 xtc_probability: float | None = 0.0 - + dry_multiplier: float | None = 0.0 dry_base: float | None = 1.75 dry_allowed_length: int | None = 2 dry_sequence_breakers: str | None = '"\\n", ":", "\\"", "*"' - + smoothing_factor: float | None = 0.0 smoothing_curve: float | None = 1.0 - + # this determines whether or not it should be persisted # to the config file changed: bool = False class InferencePresets(BaseModel): - analytical: InferenceParameters = InferenceParameters( temperature=0.7, presence_penalty=0, @@ -325,19 +323,26 @@ class InferencePresets(BaseModel): presence_penalty=0.0, ) + class InferencePresetGroup(BaseModel): name: str presets: InferencePresets + class Presets(BaseModel): inference_defaults: InferencePresets = InferencePresets() inference: InferencePresets = InferencePresets() - - inference_groups: dict[str, InferencePresetGroup] = pydantic.Field(default_factory=dict) - - embeddings_defaults: dict[str, EmbeddingFunctionPreset] = pydantic.Field(default_factory=generate_chromadb_presets) - embeddings: dict[str, EmbeddingFunctionPreset] = pydantic.Field(default_factory=generate_chromadb_presets) + inference_groups: dict[str, InferencePresetGroup] = pydantic.Field( + default_factory=dict + ) + + embeddings_defaults: dict[str, EmbeddingFunctionPreset] = pydantic.Field( + default_factory=generate_chromadb_presets + ) + embeddings: dict[str, EmbeddingFunctionPreset] = pydantic.Field( + default_factory=generate_chromadb_presets + ) def gnerate_intro_scenes(): @@ -471,28 +476,28 @@ AnnotatedClient = Annotated[ class HistoryMessageStyle(BaseModel): italic: bool = False bold: bool = False - + # Leave None for default color - color: str | None = None + color: str | None = None class HidableHistoryMessageStyle(HistoryMessageStyle): # certain messages can be hidden, but all messages are shown by default show: bool = True - + class SceneAppearance(BaseModel): - narrator_messages: HistoryMessageStyle = HistoryMessageStyle(italic=True) character_messages: HistoryMessageStyle = HistoryMessageStyle() director_messages: HidableHistoryMessageStyle = HidableHistoryMessageStyle() time_messages: HistoryMessageStyle = HistoryMessageStyle() - context_investigation_messages: HidableHistoryMessageStyle = HidableHistoryMessageStyle() - + context_investigation_messages: HidableHistoryMessageStyle = ( + HidableHistoryMessageStyle() + ) + + class Appearance(BaseModel): - scene: SceneAppearance = SceneAppearance() - class Config(BaseModel): @@ -505,7 +510,7 @@ class Config(BaseModel): creator: CreatorConfig = CreatorConfig() openai: OpenAIConfig = OpenAIConfig() - + deepseek: DeepSeekConfig = DeepSeekConfig() mistralai: MistralAIConfig = MistralAIConfig() @@ -531,11 +536,11 @@ class Config(BaseModel): recent_scenes: RecentScenes = RecentScenes() presets: Presets = Presets() - + appearance: Appearance = Appearance() - + system_prompts: SystemPrompts = SystemPrompts() - + class Config: extra = "ignore" @@ -595,18 +600,18 @@ def save_config(config, file_path: str = "./config.yaml"): # we dont want to persist the following, so we drop them: # - presets.inference_defaults # - presets.embeddings_defaults - + if "inference_defaults" in config["presets"]: config["presets"].pop("inference_defaults") - + if "embeddings_defaults" in config["presets"]: config["presets"].pop("embeddings_defaults") - + # for normal presets we only want to persist if they have changed for preset_name, preset in list(config["presets"]["inference"].items()): if not preset.get("changed"): config["presets"]["inference"].pop(preset_name) - + # in inference groups also only keep if changed for group_name, group in list(config["presets"]["inference_groups"].items()): for preset_name, preset in list(group["presets"].items()): @@ -616,7 +621,7 @@ def save_config(config, file_path: str = "./config.yaml"): # if presets is empty, remove it if not config["presets"]["inference"]: config["presets"].pop("inference") - + # if system_prompts is empty, remove it if not config["system_prompts"]: config.pop("system_prompts") @@ -624,12 +629,13 @@ def save_config(config, file_path: str = "./config.yaml"): # set any client preset_group to "" if it references an # entry that no longer exists in inference_groups for client in config["clients"].values(): - if not client.get("preset_group"): continue - + if client["preset_group"] not in config["presets"].get("inference_groups", {}): - log.warning(f"Client {client['name']} references non-existent preset group {client['preset_group']}, setting to default") + log.warning( + f"Client {client['name']} references non-existent preset group {client['preset_group']}, setting to default" + ) client["preset_group"] = "" with open(file_path, "w") as file: @@ -639,7 +645,6 @@ def save_config(config, file_path: str = "./config.yaml"): def cleanup() -> Config: - log.info("cleaning up config") config = load_config(as_model=True) diff --git a/src/talemate/context.py b/src/talemate/context.py index 39c5a202..514828e4 100644 --- a/src/talemate/context.py +++ b/src/talemate/context.py @@ -35,11 +35,12 @@ regeneration_context = ContextVar("regeneration_context", default=None) active_scene = ContextVar("active_scene", default=None) interaction = ContextVar("interaction", default=InteractionState()) + def handle_generation_cancelled(exc: GenerationCancelled): # set cancel_requested to False on the active_scene - + scene = active_scene.get() - + if scene: scene.cancel_requested = False @@ -102,6 +103,6 @@ class Interaction: def assert_active_scene(scene: object): if not active_scene.get(): raise SceneInactiveError("Scene is not active") - + if active_scene.get() != scene: - raise SceneInactiveError("Scene has changed") \ No newline at end of file + raise SceneInactiveError("Scene has changed") diff --git a/src/talemate/emit/__init__.py b/src/talemate/emit/__init__.py index 17399c1a..3be0fd14 100644 --- a/src/talemate/emit/__init__.py +++ b/src/talemate/emit/__init__.py @@ -1,12 +1,12 @@ -import talemate.emit.signals as signals +import talemate.emit.signals as signals # noqa: F401 from .base import ( - AbortCommand, - Emission, - Emitter, - Receiver, - abort_wait_for_input, - emit, - wait_for_input, - wait_for_input_yesno, -) \ No newline at end of file + AbortCommand, # noqa: F401 + Emission, # noqa: F401 + Emitter, # noqa: F401 + Receiver, # noqa: F401 + abort_wait_for_input, # noqa: F401 + emit, # noqa: F401 + wait_for_input, # noqa: F401 + wait_for_input_yesno, # noqa: F401 +) diff --git a/src/talemate/emit/async_signals.py b/src/talemate/emit/async_signals.py index 7461b32f..ef567183 100644 --- a/src/talemate/emit/async_signals.py +++ b/src/talemate/emit/async_signals.py @@ -6,6 +6,7 @@ __all__ = [ handlers = {} + class AsyncSignal: def __init__(self, name): self.receivers = [] @@ -21,11 +22,11 @@ class AsyncSignal: self.receivers.remove(handler) except ValueError: pass - + async def send(self, emission): for receiver in self.receivers: await receiver(emission) - + def _register(name: str): """ diff --git a/src/talemate/emit/base.py b/src/talemate/emit/base.py index 31097cf0..901cbd19 100644 --- a/src/talemate/emit/base.py +++ b/src/talemate/emit/base.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio import dataclasses -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Callable import structlog @@ -123,16 +123,16 @@ async def wait_for_input( while input_received["message"] is None: await asyncio.sleep(sleep_time) - + interaction_state = interaction.get() - + if abort_condition and (await abort_condition()): raise AbortWaitForInput() - + if interaction_state.reset_requested: interaction_state.reset_requested = False raise RestartSceneLoop() - + if interaction_state.input: input_received["message"] = interaction_state.input input_received["interaction"] = interaction_state @@ -140,7 +140,7 @@ async def wait_for_input( interaction_state.input = None interaction_state.from_choice = None break - + handlers["receive_input"].disconnect(input_receiver) if input_received["message"] == "!abort": @@ -194,6 +194,6 @@ class Emitter: def player_message(self, message: str, character: Character): self.emit("player", message, character=character) - + def context_investigation_message(self, message: str): self.emit("context_investigation", message) diff --git a/src/talemate/events.py b/src/talemate/events.py index cee42277..746d3112 100644 --- a/src/talemate/events.py +++ b/src/talemate/events.py @@ -31,6 +31,7 @@ class ArchiveEvent(Event): memory_id: str ts: str = None + @dataclass class CharacterStateEvent(Event): state: str @@ -62,11 +63,13 @@ class GameLoopActorIterEvent(GameLoopBase): actor: Actor game_loop: GameLoopEvent + @dataclass class GameLoopCharacterIterEvent(GameLoopBase): character: Character game_loop: GameLoopEvent + @dataclass class GameLoopNewMessageEvent(GameLoopBase): message: SceneMessage @@ -76,6 +79,7 @@ class GameLoopNewMessageEvent(GameLoopBase): class PlayerTurnStartEvent(Event): pass + @dataclass class RegenerateGeneration(Event): message: "SceneMessage" @@ -89,4 +93,4 @@ async_signals.register( "regenerate.msg.context_investigation", "game_loop_player_character_iter", "game_loop_ai_character_iter", -) \ No newline at end of file +) diff --git a/src/talemate/exceptions.py b/src/talemate/exceptions.py index f7b735a4..adc10ef8 100644 --- a/src/talemate/exceptions.py +++ b/src/talemate/exceptions.py @@ -33,11 +33,13 @@ class ResetScene(TalemateInterrupt): pass + class GenerationCancelled(TalemateInterrupt): """ - Interrupt current scene and return action to the user + Interrupt current scene and return action to the user """ - pass + + pass class RenderPromptError(TalemateError): @@ -76,19 +78,21 @@ class UnknownDataSpec(TalemateError): pass + class ActedAsCharacter(Exception): """ Raised when the user acts as another character than the main player character """ - def __init__(self, character_name:str): + def __init__(self, character_name: str): self.character_name = character_name super().__init__(f"Acted as character: {character_name}") - - + + class AbortCommand(IOError): pass + class AbortWaitForInput(IOError): - pass \ No newline at end of file + pass diff --git a/src/talemate/files.py b/src/talemate/files.py index e3be5516..01ef0adb 100644 --- a/src/talemate/files.py +++ b/src/talemate/files.py @@ -1,8 +1,6 @@ import fnmatch import os -from talemate.config import load_config - def list_scenes_directory(path: str = ".") -> list: """ @@ -10,8 +8,6 @@ def list_scenes_directory(path: str = ".") -> list: :param directory: Directory to list scene files from. :return: List of scene files in the given directory. """ - config = load_config() - current_dir = os.getcwd() scenes = _list_files_and_directories(os.path.join(current_dir, "scenes"), path) @@ -36,9 +32,9 @@ def _list_files_and_directories(root: str, path: str) -> list: # Check each file if it matches any of the patterns for filename in filenames: # Skip JSON files inside 'nodes' directories - if filename.endswith('.json') and 'nodes' in dirpath.split(os.sep): + if filename.endswith(".json") and "nodes" in dirpath.split(os.sep): continue - + # Get the relative file path rel_path = os.path.relpath(dirpath, root) for pattern in patterns: diff --git a/src/talemate/game/engine/__init__.py b/src/talemate/game/engine/__init__.py index bc15e46d..25636cba 100644 --- a/src/talemate/game/engine/__init__.py +++ b/src/talemate/game/engine/__init__.py @@ -19,16 +19,18 @@ nest_asyncio.apply() DEV_MODE = True + def empty_function(*args, **kwargs): pass -def exec_restricted(code: str, filename:str, **kwargs): + +def exec_restricted(code: str, filename: str, **kwargs): compiled_code = compile_restricted(code, filename=filename, mode="exec") - + # Create a restricted globals dictionary restricted_globals = safe_globals.copy() safe_locals = {} - + # Add custom variables, functions, or objects to the restricted globals restricted_globals.update(kwargs) restricted_globals["__name__"] = "__main__" @@ -44,6 +46,7 @@ def exec_restricted(code: str, filename:str, **kwargs): # Execute the compiled code with the restricted globals return exec(compiled_code, restricted_globals, safe_locals) + def compile_scene_module(module_code: str, **kwargs) -> dict[str, callable]: # Compile the module code using RestrictedPython compiled_code = compile_restricted( @@ -71,7 +74,9 @@ def compile_scene_module(module_code: str, **kwargs) -> dict[str, callable]: return { "game": safe_locals.get("game"), - "on_generation_cancelled": safe_locals.get("on_generation_cancelled", empty_function) + "on_generation_cancelled": safe_locals.get( + "on_generation_cancelled", empty_function + ), } @@ -180,18 +185,22 @@ class GameInstructionsMixin: # read thje file into _module property with open(module_path, "r") as f: module_code = f.read() - + scene_modules = compile_scene_module(module_code) - + if "game" not in scene_modules: - raise ValueError(f"`game` function not found in scene module {module_path}") - + raise ValueError( + f"`game` function not found in scene module {module_path}" + ) + scene._module = GameInstructionScope( director=self, log=log, scene=scene, module_function=scene_modules["game"], - on_generation_cancelled=scene_modules.get("on_generation_cancelled", empty_function) + on_generation_cancelled=scene_modules.get( + "on_generation_cancelled", empty_function + ), ) async def scene_has_module(self, scene: "Scene"): diff --git a/src/talemate/game/engine/api/__init__.py b/src/talemate/game/engine/api/__init__.py index 5d54557f..cb65a664 100644 --- a/src/talemate/game/engine/api/__init__.py +++ b/src/talemate/game/engine/api/__init__.py @@ -1,10 +1,10 @@ -import talemate.game.engine.api.agents.creator as agent_creator -import talemate.game.engine.api.agents.director as agent_director -import talemate.game.engine.api.agents.narrator as agent_narrator -import talemate.game.engine.api.agents.visual as agent_visual -import talemate.game.engine.api.agents.world_state as agent_world_state -import talemate.game.engine.api.game_state as game_state -import talemate.game.engine.api.log as log -import talemate.game.engine.api.prompt as prompt -import talemate.game.engine.api.scene as scene -import talemate.game.engine.api.signals as signals +import talemate.game.engine.api.agents.creator as agent_creator # noqa: F401 +import talemate.game.engine.api.agents.director as agent_director # noqa: F401 +import talemate.game.engine.api.agents.narrator as agent_narrator # noqa: F401 +import talemate.game.engine.api.agents.visual as agent_visual # noqa: F401 +import talemate.game.engine.api.agents.world_state as agent_world_state # noqa: F401 +import talemate.game.engine.api.game_state as game_state # noqa: F401 +import talemate.game.engine.api.log as log # noqa: F401 +import talemate.game.engine.api.prompt as prompt # noqa: F401 +import talemate.game.engine.api.scene as scene # noqa: F401 +import talemate.game.engine.api.signals as signals # noqa: F401 diff --git a/src/talemate/game/engine/api/agents/creator.py b/src/talemate/game/engine/api/agents/creator.py index 5ffe0667..22d22076 100644 --- a/src/talemate/game/engine/api/agents/creator.py +++ b/src/talemate/game/engine/api/agents/creator.py @@ -16,7 +16,6 @@ if TYPE_CHECKING: def create(scene: "Scene") -> "ScopedAPI": class API(ScopedAPI): - def determine_content_context_for_description( self, description: str, diff --git a/src/talemate/game/engine/api/agents/director.py b/src/talemate/game/engine/api/agents/director.py index 885544a4..17e011af 100644 --- a/src/talemate/game/engine/api/agents/director.py +++ b/src/talemate/game/engine/api/agents/director.py @@ -18,7 +18,6 @@ __all__ = ["create"] def create(scene: "Scene") -> "ScopedAPI": class API(ScopedAPI): - def persist_character( self, character_name: str, diff --git a/src/talemate/game/engine/api/agents/narrator.py b/src/talemate/game/engine/api/agents/narrator.py index dc83741f..093cddf9 100644 --- a/src/talemate/game/engine/api/agents/narrator.py +++ b/src/talemate/game/engine/api/agents/narrator.py @@ -21,7 +21,6 @@ __all__ = ["create"] def create(scene: "Scene") -> "ScopedAPI": class API(ScopedAPI): - def action_to_narration( self, action_name: str, diff --git a/src/talemate/game/engine/api/agents/visual.py b/src/talemate/game/engine/api/agents/visual.py index be3627ee..25117785 100644 --- a/src/talemate/game/engine/api/agents/visual.py +++ b/src/talemate/game/engine/api/agents/visual.py @@ -17,7 +17,6 @@ __all__ = ["create"] def create(scene: "Scene") -> "ScopedAPI": class API(ScopedAPI): - def generate_character_portrait( self, character_name: str, diff --git a/src/talemate/game/engine/api/agents/world_state.py b/src/talemate/game/engine/api/agents/world_state.py index 3cfbd815..bac7abd9 100644 --- a/src/talemate/game/engine/api/agents/world_state.py +++ b/src/talemate/game/engine/api/agents/world_state.py @@ -17,7 +17,6 @@ __all__ = ["create"] def create(scene: "Scene") -> "ScopedAPI": class API(ScopedAPI): - def activate_character(self, character_name: str): """ Activates a character. diff --git a/src/talemate/game/engine/api/base.py b/src/talemate/game/engine/api/base.py index 22c9585f..9e612e1d 100644 --- a/src/talemate/game/engine/api/base.py +++ b/src/talemate/game/engine/api/base.py @@ -3,7 +3,7 @@ Functions for scene direction and manipulation """ import asyncio -from typing import TYPE_CHECKING, Any, Callable, Coroutine +from typing import Coroutine __all__ = [ "run_async", diff --git a/src/talemate/game/engine/api/game_state.py b/src/talemate/game/engine/api/game_state.py index 1940c0c5..64e2f251 100644 --- a/src/talemate/game/engine/api/game_state.py +++ b/src/talemate/game/engine/api/game_state.py @@ -1,7 +1,5 @@ from typing import TYPE_CHECKING -import pydantic - import talemate.game.engine.api.schema as schema from talemate.game.engine.api.base import ScopedAPI @@ -13,7 +11,6 @@ __all__ = ["create"] def create(game_state: "GameState") -> "ScopedAPI": class API(ScopedAPI): - help_text = """Functions for game state management""" ### Variables diff --git a/src/talemate/game/engine/api/log.py b/src/talemate/game/engine/api/log.py index b66a122a..458abcbb 100644 --- a/src/talemate/game/engine/api/log.py +++ b/src/talemate/game/engine/api/log.py @@ -1,5 +1,3 @@ -from typing import TYPE_CHECKING - import structlog from talemate.game.engine.api.base import ScopedAPI @@ -9,7 +7,6 @@ __all__ = ["create"] def create(log: structlog.BoundLogger) -> "ScopedAPI": class LogAPI(ScopedAPI): - def info(self, event, *args, **kwargs): log.info(event, *args, **kwargs) diff --git a/src/talemate/game/engine/api/prompt.py b/src/talemate/game/engine/api/prompt.py index 1c468a2e..975f7340 100644 --- a/src/talemate/game/engine/api/prompt.py +++ b/src/talemate/game/engine/api/prompt.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: def create(scene: "Scene", client: "ClientBase") -> "ScopedAPI": class API(ScopedAPI): - def request( self, template_name: str, diff --git a/src/talemate/game/engine/api/scene.py b/src/talemate/game/engine/api/scene.py index e54cfbed..ee1b9c05 100644 --- a/src/talemate/game/engine/api/scene.py +++ b/src/talemate/game/engine/api/scene.py @@ -4,7 +4,7 @@ import pydantic import talemate.game.engine.api.schema as schema from talemate.game.engine.api.base import ScopedAPI, run_async -from talemate.game.engine.api.exceptions import SceneInactive, UnknownCharacter +from talemate.game.engine.api.exceptions import UnknownCharacter if TYPE_CHECKING: from talemate.tale_mate import Scene @@ -14,7 +14,6 @@ __all__ = ["create"] def create(scene: "Scene") -> "ScopedAPI": class API(ScopedAPI): - help_text = """Functions for scene direction and manipulation""" @property @@ -93,7 +92,9 @@ def create(scene: "Scene") -> "ScopedAPI": validated = Arguments(budget=budget, keep_director=keep_director) - return scene.context_history(validated.budget, keep_director=validated.keep_director) + return scene.context_history( + validated.budget, keep_director=validated.keep_director + ) def get_player_character(self) -> schema.CharacterSchema | None: """ diff --git a/src/talemate/game/engine/api/schema.py b/src/talemate/game/engine/api/schema.py index 798a1a62..147f9400 100644 --- a/src/talemate/game/engine/api/schema.py +++ b/src/talemate/game/engine/api/schema.py @@ -57,9 +57,9 @@ class CharacterSchema(pydantic.BaseModel): def from_character(cls, character: "Character") -> "CharacterSchema": from talemate.tale_mate import Character - assert isinstance( - character, Character - ), f"Expected Character, got {type(character)}" + assert isinstance(character, Character), ( + f"Expected Character, got {type(character)}" + ) return cls( name=character.name, @@ -103,10 +103,9 @@ class CharacterMessageSchema(pydantic.BaseModel): @classmethod def from_message(cls, message: "CharacterMessage") -> "CharacterMessageSchema": - - assert isinstance( - message, CharacterMessage - ), f"Expected CharacterMessage, got {type(message)}" + assert isinstance(message, CharacterMessage), ( + f"Expected CharacterMessage, got {type(message)}" + ) return cls( message=message.message, @@ -136,10 +135,9 @@ class NarratorMessageSchema(pydantic.BaseModel): @classmethod def from_message(cls, message: "NarratorMessage") -> "NarratorMessageSchema": - - assert isinstance( - message, NarratorMessage - ), f"Expected NarratorMessage, got {type(message)}" + assert isinstance(message, NarratorMessage), ( + f"Expected NarratorMessage, got {type(message)}" + ) return cls( message=message.message, @@ -167,10 +165,9 @@ class DirectorMessageSchema(pydantic.BaseModel): @classmethod def from_message(cls, message: "DirectorMessage") -> "DirectorMessageSchema": - - assert isinstance( - message, DirectorMessage - ), f"Expected DirectorMessage, got {type(message)}" + assert isinstance(message, DirectorMessage), ( + f"Expected DirectorMessage, got {type(message)}" + ) return cls( message=message.message, @@ -203,10 +200,9 @@ class TimePassageMessageSchema(pydantic.BaseModel): @classmethod def from_message(cls, message: "TimePassageMessage") -> "TimePassageMessageSchema": - - assert isinstance( - message, TimePassageMessage - ), f"Expected TimePassageMessage, got {type(message)}" + assert isinstance(message, TimePassageMessage), ( + f"Expected TimePassageMessage, got {type(message)}" + ) return cls( message=message.message, @@ -241,10 +237,9 @@ class ReinforcementMessageSchema(pydantic.BaseModel): def from_message( cls, message: "ReinforcementMessage" ) -> "ReinforcementMessageSchema": - - assert isinstance( - message, ReinforcementMessage - ), f"Expected ReinforcementMessage, got {type(message)}" + assert isinstance(message, ReinforcementMessage), ( + f"Expected ReinforcementMessage, got {type(message)}" + ) return cls( message=message.message, diff --git a/src/talemate/game/engine/api/signals.py b/src/talemate/game/engine/api/signals.py index af0e9d06..ffaa9021 100644 --- a/src/talemate/game/engine/api/signals.py +++ b/src/talemate/game/engine/api/signals.py @@ -13,7 +13,6 @@ __all__ = ["create"] def create() -> "ScopedAPI": class API(ScopedAPI): - def status(self, status: str, message: str, as_scene_message: bool = False): """ Emits a status message to the scene diff --git a/src/talemate/game/engine/nodes/__init__.py b/src/talemate/game/engine/nodes/__init__.py index 2bdb4e69..56d3b479 100644 --- a/src/talemate/game/engine/nodes/__init__.py +++ b/src/talemate/game/engine/nodes/__init__.py @@ -10,10 +10,10 @@ TALEMATE_ROOT = os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", SEARCH_PATHS = [ # third party node module definitions os.path.join(TALEMATE_ROOT, "templates", "modules"), - # agentic node modules os.path.join(TALEMATE_ROOT, "src", "talemate", "agents"), - # game engine node modules - os.path.join(TALEMATE_ROOT, "src", "talemate", "game", "engine", "nodes", "modules"), + os.path.join( + TALEMATE_ROOT, "src", "talemate", "game", "engine", "nodes", "modules" + ), ] diff --git a/src/talemate/game/engine/nodes/agent.py b/src/talemate/game/engine/nodes/agent.py index bd6c0cf4..acbf8f34 100644 --- a/src/talemate/game/engine/nodes/agent.py +++ b/src/talemate/game/engine/nodes/agent.py @@ -3,11 +3,11 @@ import pydantic import structlog from typing import ClassVar from talemate.game.engine.nodes.core import ( - Node, - register, - GraphState, - InputValueError, - PropertyField, + Node, + register, + GraphState, + InputValueError, + PropertyField, NodeVerbosity, NodeStyle, UNRESOLVED, @@ -17,104 +17,115 @@ from talemate.agents.registry import get_agent_types, get_agent_class from talemate.agents.base import Agent, DynamicInstruction as DynamicInstructionType from talemate.instance import get_agent -from .state import ConditionalSetState, ConditionalUnsetState, ConditionalCounterState, StateManipulation, HasState, GetState +from .state import ( + ConditionalSetState, + ConditionalUnsetState, + ConditionalCounterState, + StateManipulation, + HasState, + GetState, +) log = structlog.get_logger("talemate.game.engine.nodes.agent") -TYPE_CHOICES.extend([ - "dynamic_instruction", -]) +TYPE_CHOICES.extend( + [ + "dynamic_instruction", + ] +) + class AgentNode(Node): - - _agent_name:ClassVar[str | None] = None + _agent_name: ClassVar[str | None] = None - @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( node_color="#392c34", title_color="#572e44", - icon="F1719", #robot-happy + icon="F1719", # robot-happy ) - + @property def agent(self): return get_agent(self._agent_name) - - async def get_agent(self, agent_name:str): + + async def get_agent(self, agent_name: str): return get_agent(agent_name) class AgentSettingsNode(Node): """ Base node to render agent settings. - + Will take an _agent_name class property, then create outputs based on the agents' AgentAction and AgentActionConfig """ - - _agent_name:ClassVar[str | None] = None - + + _agent_name: ClassVar[str | None] = None + def setup(self): agent_cls = get_agent_class(self._agent_name) - + if not agent_cls: - raise InputValueError(self, "_agent_name", f"Could not find agent: {self._agent_name}") - + raise InputValueError( + self, "_agent_name", f"Could not find agent: {self._agent_name}" + ) + self.add_output("agent_enabled", socket_type="bool") - + actions = agent_cls.init_actions() - + for action_name, action in actions.items(): self.add_output(f"{action_name}_enabled", socket_type="bool") - + if not action.config: continue - + for config_name, config in action.config.items(): self.add_output(f"{action_name}_{config_name}", socket_type=config.type) - + async def run(self, state: GraphState): agent = get_agent(self._agent_name) - + if not agent: - raise InputValueError(self, "_agent_name", f"Could not find agent: {self._agent_name}") - - outputs = { - "agent_enabled": agent.enabled - } - + raise InputValueError( + self, "_agent_name", f"Could not find agent: {self._agent_name}" + ) + + outputs = {"agent_enabled": agent.enabled} + for action_name, action in agent.actions.items(): outputs[f"{action_name}_enabled"] = action.enabled - + if not action.config: continue - + for config_name, config in action.config.items(): outputs[f"{action_name}_{config_name}"] = config.value - + self.set_output_values(outputs) + @register("agents/ToggleAgentAction") class ToggleAgentAction(Node): """ Allows disabling or enabling an agent action - + Inputs: - + - agent: str,agent - action_name: str - enabled: bool - + Outputs: - + - agent: agent - action_name: str - enabled: bool """ - + class Fields: agent = PropertyField( name="agent", @@ -122,68 +133,72 @@ class ToggleAgentAction(Node): default="", description="The agent to toggle the action on", choices=[], - generate_choices=lambda: get_agent_types() + generate_choices=lambda: get_agent_types(), ) - + action_name = PropertyField( name="action_name", type="str", default="", - description="The name of the action to toggle" + description="The name of the action to toggle", ) - + enabled = PropertyField( name="enabled", type="bool", default=True, - description="Whether to enable or disable the action" + description="Whether to enable or disable the action", ) - - + def setup(self): self.add_input("state") self.add_input("agent", socket_type="str,agent", optional=True) self.add_input("action_name", socket_type="str", optional=True) self.add_input("enabled", socket_type="bool", optional=True) - + self.set_property("agent", "") self.set_property("action_name", "") self.set_property("enabled", True) - + self.add_output("agent", socket_type="agent") self.add_output("action_name", socket_type="str") self.add_output("enabled", socket_type="bool") - + async def run(self, state: GraphState): agent = self.get_input_value("agent") action_name = self.get_input_value("action_name") enabled = self.get_input_value("enabled") - + if isinstance(agent, str): agent_name = agent agent = get_agent(agent_name) if not agent: - raise InputValueError(self, "agent", f"Could not find agent: {agent_name}") - + raise InputValueError( + self, "agent", f"Could not find agent: {agent_name}" + ) + action = agent.actions.get(action_name) - + if not action: - raise InputValueError(self, "action_name", f"Could not find action {action_name} in agent {agent}") - + raise InputValueError( + self, + "action_name", + f"Could not find action {action_name} in agent {agent}", + ) + action.enabled = enabled - - self.set_output_values({ - "agent": agent, - "action_name": action_name, - "enabled": enabled - }) + + self.set_output_values( + {"agent": agent, "action_name": action_name, "enabled": enabled} + ) + @register("agents/CallAgentFunction") class CallAgentFunction(Node): """ Call an agent function """ - + class Fields: agent = PropertyField( name="agent", @@ -191,65 +206,74 @@ class CallAgentFunction(Node): default="", description="The agent to call the function on", choices=[], - generate_choices=lambda: get_agent_types() + generate_choices=lambda: get_agent_types(), ) - + function_name = PropertyField( name="function_name", type="str", default="", - description="The name of the function to call on the agent" + description="The name of the function to call on the agent", ) - + arguments = PropertyField( name="arguments", type="dict", default={}, - description="The arguments to pass to the function" + description="The arguments to pass to the function", ) - + def __init__(self, title="Call Agent Function", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("agent", socket_type="str,agent") self.add_input("function_name", socket_type="str") self.add_input("arguments", socket_type="dict") - + self.set_property("agent", "") self.set_property("function_name", "") self.set_property("arguments", {}) - + self.add_output("result", socket_type="any") - + async def run(self, state: GraphState): agent = self.get_input_value("agent") function_name = self.get_input_value("function_name") arguments = self.get_input_value("arguments") - + if isinstance(agent, str): agent_name = agent agent = get_agent(agent_name) if not agent: - raise InputValueError(self, "agent", f"Could not find agent: {agent_name}") - + raise InputValueError( + self, "agent", f"Could not find agent: {agent_name}" + ) + function = getattr(agent, function_name, None) - + if not function: - raise InputValueError(self, "function_name", f"Could not find function {function_name} in agent {agent}") - + raise InputValueError( + self, + "function_name", + f"Could not find function {function_name} in agent {agent}", + ) + # is function a coroutine? if inspect.iscoroutinefunction(function): result = await function(**arguments) else: result = function(**arguments) - + if state.verbosity >= NodeVerbosity.VERBOSE: - log.debug(f"Called agent function {function_name} on agent {agent}", result=result, arguments=arguments) - - self.set_output_values({ - "result": result - }) + log.debug( + f"Called agent function {function_name} on agent {agent}", + result=result, + arguments=arguments, + ) + + self.set_output_values({"result": result}) + @register("agents/GetAgent") class GetAgent(Node): @@ -264,10 +288,9 @@ class GetAgent(Node): default="", description="The name of the agent to get the client for", choices=[], - generate_choices=lambda: get_agent_types() + generate_choices=lambda: get_agent_types(), ) - - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: @@ -275,31 +298,32 @@ class GetAgent(Node): node_color="#313150", title_color="#403f71", auto_title="{agent_name}", - icon="F0D3D", #transit-connection-variant + icon="F0D3D", # transit-connection-variant ) - + def __init__(self, title="Get Agent", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("agent_name", "") self.add_output("agent", socket_type="agent") - + async def run(self, state: GraphState): agent_name = self.get_property("agent_name") - + if not agent_name: return - + agent = get_agent(agent_name) - + if not agent: - raise InputValueError(self, "agent_name", f"Could not find agent: {agent_name}") - - self.set_output_values({ - "agent": agent - }) - + raise InputValueError( + self, "agent_name", f"Could not find agent: {agent_name}" + ) + + self.set_output_values({"agent": agent}) + + class AgentStateManipulation(StateManipulation): class Fields: agent = PropertyField( @@ -308,37 +332,39 @@ class AgentStateManipulation(StateManipulation): type="str", default=UNRESOLVED, choices=[], - generate_choices=lambda: get_agent_types() + generate_choices=lambda: get_agent_types(), ) scope = PropertyField( name="scope", description="Which scope to manipulate", type="str", default="scene", - choices=["scene", "context"] + choices=["scene", "context"], ) name = PropertyField( name="name", description="The name of the variable to manipulate", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + def setup(self): super().setup() self.add_input("agent", socket_type="str,agent") self.set_property("agent", UNRESOLVED) - + def get_state_container(self, state: GraphState): scope = self.get_property("scope") - agent:Agent | str = self.get_input_value("agent") - + agent: Agent | str = self.get_input_value("agent") + if isinstance(agent, str): agent_name = agent - agent:Agent | None = get_agent(agent_name) + agent: Agent | None = get_agent(agent_name) if not agent: - raise InputValueError(self, "agent", f"Could not find agent: {agent_name}") - + raise InputValueError( + self, "agent", f"Could not find agent: {agent_name}" + ) + if scope == "scene": try: return agent.scene.agent_state[agent.agent_type] @@ -349,102 +375,106 @@ class AgentStateManipulation(StateManipulation): return agent.dump_context_state() else: raise InputValueError(self, "scope", f"Unknown scope: {scope}") - - + @register("agents/SetAgentState") class SetAgentState(AgentStateManipulation, ConditionalSetState): """ Set an agent state variable - + Provides a required `state` input causing the node to only run when a state is provided """ - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( title_color="#2e4657", - icon="F06F4", # download-network - auto_title="SET {agent}.{scope}.{name}" + icon="F06F4", # download-network + auto_title="SET {agent}.{scope}.{name}", ) def __init__(self, title="Set Agent State", **kwargs): super().__init__(title=title, **kwargs) - + + @register("agents/GetAgentState") class GetAgentState(AgentStateManipulation, GetState): """ Get an agent state variable """ - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( title_color="#44552f", - icon="F06F6", # upload-network - auto_title="GET {agent}.{scope}.{name}" + icon="F06F6", # upload-network + auto_title="GET {agent}.{scope}.{name}", ) - + def __init__(self, title="Get Agent State", **kwargs): super().__init__(title=title, **kwargs) - + + @register("agents/UnsetAgentState") class UnsetAgentState(AgentStateManipulation, ConditionalUnsetState): """ Unset an agent state variable - + Provides a required `state` input causing the node to only run when a state is provided """ - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( title_color="#7f2e2e", - icon="F0AED", # account-remove-outline - auto_title="UNSET {agent}.{scope}.{name}" + icon="F0AED", # account-remove-outline + auto_title="UNSET {agent}.{scope}.{name}", ) - + def __init__(self, title="Unset Agent State", **kwargs): super().__init__(title=title, **kwargs) - + + class HasAgentState(AgentStateManipulation, HasState): """ Check if an agent state variable exists - + Provides a required `state` input causing the node to only run when a state is provided """ - + def __init__(self, title="Has Agent State", **kwargs): super().__init__(title=title, **kwargs) - + + @register("agents/CounterAgentState") class CounterAgentState(AgentStateManipulation, ConditionalCounterState): """ Increment or decrement an agent state variable """ - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( title_color="#2e4657", - icon="F0199", # counter - auto_title="COUNT {agent}.{scope}.{name}" + icon="F0199", # counter + auto_title="COUNT {agent}.{scope}.{name}", ) - + def __init__(self, title="Counter Agent State", **kwargs): super().__init__(title=title, **kwargs) - + + @register("agents/DynamicInstruction") class DynamicInstruction(Node): """ Dynamic instruction object to use for instruction injection in event handlers """ - + class Fields: header = PropertyField( name="header", @@ -452,33 +482,37 @@ class DynamicInstruction(Node): type="str", default=UNRESOLVED, ) - + content = PropertyField( name="content", description="The content of the dynamic instruction", type="text", default=UNRESOLVED, ) - + def __init__(self, title="Dynamic Instruction", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("header", socket_type="str", optional=True) self.add_input("content", socket_type="str", optional=True) - + self.set_property("header", UNRESOLVED) self.set_property("content", UNRESOLVED) - + self.add_output("dynamic_instruction", socket_type="dynamic_instruction") - + async def run(self, state: GraphState): header = self.normalized_input_value("header") content = self.normalized_input_value("content") - + if not header or not content: return - - self.set_output_values({ - "dynamic_instruction": DynamicInstructionType(title=header, content=content) - }) \ No newline at end of file + + self.set_output_values( + { + "dynamic_instruction": DynamicInstructionType( + title=header, content=content + ) + } + ) diff --git a/src/talemate/game/engine/nodes/api.py b/src/talemate/game/engine/nodes/api.py index ce4f78b5..5a8aa8a0 100644 --- a/src/talemate/game/engine/nodes/api.py +++ b/src/talemate/game/engine/nodes/api.py @@ -7,29 +7,28 @@ from talemate.game.engine.nodes.core import ( NodeStyle, PropertyField, UNRESOLVED, - PASSTHROUGH_ERRORS, ) from talemate.instance import get_agent from talemate.context import active_scene -from talemate.game.engine.nodes.base_types import base_node_type from talemate.game.engine import exec_restricted from talemate.game.scope import OpenScopedContext, GameInstructionScope log = structlog.get_logger("talemate.game.engine.nodes.core.api") + @register("core/functions/ScopedAPIFunction") class ScopedAPIFunction(Node): """ Executes python code inside the quarantined scoped environment. """ - + class Fields: code = PropertyField( name="code", description="The code to execute", type="text", - default=UNRESOLVED + default=UNRESOLVED, ) @pydantic.computed_field(description="Node style") @@ -37,58 +36,48 @@ class ScopedAPIFunction(Node): def style(self) -> NodeStyle: return NodeStyle( title_color="#306c51", - icon="F10D6", #code-braces-box + icon="F10D6", # code-braces-box ) - def __init__(self, title="Scoped API Function", **kwargs): super().__init__(title=title, **kwargs) - - + def setup(self): self.add_input("state") self.add_input("agent", socket_type="agent") self.add_input("arguments", socket_type="dict", optional=True) - + self.set_property("code", UNRESOLVED) - + self.add_output("result") - + async def run(self, state: GraphState): - scene = active_scene.get() code = self.require_input("code") agent = self.require_input("agent") arguments = self.normalized_input_value("arguments") - + if arguments is None: arguments = {} - + result = {} - - + def exec_scoped_api(scope): - nonlocal result + nonlocal result exec_restricted( - code, - f"<{self.title}>", - arguments=arguments, - result=result, - TM=scope + code, f"<{self.title}>", arguments=arguments, result=result, TM=scope ) - + return result.get("value") - + _module = GameInstructionScope( director=get_agent("director"), log=log, scene=scene, - module_function=lambda s: exec_scoped_api(s) + module_function=lambda s: exec_scoped_api(s), ) - + with OpenScopedContext(scene, agent.client): _module() - - self.set_output_values({ - "result": result - }) \ No newline at end of file + + self.set_output_values({"result": result}) diff --git a/src/talemate/game/engine/nodes/base_types.py b/src/talemate/game/engine/nodes/base_types.py index 257dd82d..7ab60b9b 100644 --- a/src/talemate/game/engine/nodes/base_types.py +++ b/src/talemate/game/engine/nodes/base_types.py @@ -1,19 +1,21 @@ __all__ = [ - 'base_node_type', - 'get_base_type', - 'BASE_TYPES', + "base_node_type", + "get_base_type", + "BASE_TYPES", ] BASE_TYPES = {} -def get_base_type(base_type:str): + +def get_base_type(base_type: str): return BASE_TYPES.get(base_type) + class base_node_type: - def __init__(self, base_type:str): + def __init__(self, base_type: str): self.base_type = base_type - + def __call__(self, cls): cls.base_type = self.base_type BASE_TYPES[self.base_type] = cls - return cls \ No newline at end of file + return cls diff --git a/src/talemate/game/engine/nodes/command.py b/src/talemate/game/engine/nodes/command.py index ca57a435..050831f8 100644 --- a/src/talemate/game/engine/nodes/command.py +++ b/src/talemate/game/engine/nodes/command.py @@ -1,17 +1,14 @@ import structlog from typing import ClassVar -import talemate.emit.async_signals as signals -from .core import Listen, Node, Graph, GraphState, NodeVerbosity, PropertyField, UNRESOLVED, InputValueError +from .core import ( + GraphState, + PropertyField, +) from .run import FunctionWrapper from .base_types import base_node_type -from .registry import register from .run import Function -from talemate.emit import emit, Emission -from talemate.emit.signals import handlers -from talemate.context import active_scene -from talemate.game.engine.api.schema import StatusEnum -__all__ = ["Command", "InitializeCommand"] +__all__ = ["Command"] log = structlog.get_logger("talemate.game.engine.nodes.command") @@ -21,25 +18,20 @@ class Command(Function): """ A command is a node that can be executed by the player """ - + _isolated: ClassVar[bool] = True _export_definition: ClassVar[bool] = False - - + class Fields: name = PropertyField( - name="name", - description="The name of the command", - type="str", - default="" + name="name", description="The name of the command", type="str", default="" ) - - + def __init__(self, title="Command", **kwargs): super().__init__(title=title, **kwargs) if not self.get_property("name"): self.set_property("name", "") - - async def execute_command(self, state:GraphState, **kwargs): + + async def execute_command(self, state: GraphState, **kwargs): wrapped = FunctionWrapper(self, self, state) await wrapped(**kwargs) diff --git a/src/talemate/game/engine/nodes/core/__init__.py b/src/talemate/game/engine/nodes/core/__init__.py index 6a39fc28..352a3380 100644 --- a/src/talemate/game/engine/nodes/core/__init__.py +++ b/src/talemate/game/engine/nodes/core/__init__.py @@ -7,16 +7,20 @@ import asyncio import structlog import traceback import json -import asyncio import time import reprlib -import json import re from enum import IntEnum from talemate.game.engine.nodes.base_types import base_node_type, BASE_TYPES from talemate.game.engine.nodes.registry import get_node, register -from talemate.exceptions import ExitScene, ResetScene, RestartSceneLoop, ActedAsCharacter, GenerationCancelled +from talemate.exceptions import ( + ExitScene, + ResetScene, + RestartSceneLoop, + ActedAsCharacter, + GenerationCancelled, +) import talemate.emit.async_signals as async_signals from talemate.util.async_tools import shared_debounce from talemate.context import active_scene @@ -36,22 +40,24 @@ PYTHON_TYPE_TO_STRING = { "": "None", } -TYPE_CHOICES = sorted([ - "str", - "int", - "float", - "bool", - "list", - "dict", - "any", - "character", - "interaction_state", - "actor", - "event", - "client", - "agent", - "function", -]) +TYPE_CHOICES = sorted( + [ + "str", + "int", + "float", + "bool", + "list", + "dict", + "any", + "character", + "interaction_state", + "actor", + "event", + "client", + "agent", + "function", + ] +) TYPE_TO_CLASS = { "str": str, @@ -63,33 +69,41 @@ TYPE_TO_CLASS = { "any": Any, } + def get_type_class(type_str: str) -> Any: if TYPE_TO_CLASS.get(type_str): return TYPE_TO_CLASS[type_str] raise ValueError(f"Could not find class for type {type_str}") + class LoopContinue(Exception): pass + class LoopBreak(Exception): pass + class LoopExit(Exception): pass + class StopModule(Exception): pass + class StopGraphExecution(Exception): pass + class ModuleError(Exception): pass + PASSTHROUGH_ERRORS = ( - ExitScene, - ResetScene, - RestartSceneLoop, + ExitScene, + ResetScene, + RestartSceneLoop, ActedAsCharacter, LoopContinue, LoopBreak, @@ -97,170 +111,181 @@ PASSTHROUGH_ERRORS = ( GenerationCancelled, ) + class UNRESOLVED: def __bool__(self): return False - + def __str__(self): return "" - + def __repr__(self): return "" + class NodeVerbosity(IntEnum): SILENT = 0 NORMAL = 1 VERBOSE = 2 + async_signals.register("nodes_node_state") + class InputValueError(ValueError): - def __init__(self, node:"Node", input_name:str, message:str): + def __init__(self, node: "Node", input_name: str, message: str): self.node = node self.input_name = input_name super().__init__(f"Error in node {node.title} input {input_name}: {message}") - - -def load_extended_components(file_path:str, node_data:dict): +def load_extended_components(file_path: str, node_data: dict): """ Loads all extended components from a file """ - + with open(file_path, "r") as f: data = json.load(f) - + log.debug("loading extended components", file_path=file_path) - + if data.get("extends"): load_extended_components(data["extends"], node_data) - + for node_id, node in data.get("nodes", {}).items(): if node_id not in node_data.get("nodes"): node_data["nodes"][node_id] = node node_data["nodes"][node_id]["inherited"] = True - + for edge in data.get("edges", []): node_data["edges"][edge] = data["edges"][edge] - + for group in data.get("groups", []): group["inherited"] = True node_data["groups"].append(group) - + for comment in data.get("comments", []): comment["inherited"] = True node_data["comments"].append(comment) - + log.debug("loaded extended components", file_path=file_path) -def dynamic_node_import(node_data: dict, registry_name:str, registry_container:dict|None=None) -> "Graph | Loop": + +def dynamic_node_import( + node_data: dict, registry_name: str, registry_container: dict | None = None +) -> "Graph | Loop": """ Import a node definition from data - + If the node doesn't exist, dynamically create it a class for it using the data using Loop if registry name contains "Loop" otherwise use Graph """ - + base_type = node_data.get("base_type") node_cls = BASE_TYPES.get(base_type) - + if not node_cls: - raise ValueError(f"Cannont import node data with base type {node_data.get('base_type')}") - + raise ValueError( + f"Cannont import node data with base type {node_data.get('base_type')}" + ) + if node_data.get("extends"): log.debug("loading extended components", extends=node_data["extends"]) load_extended_components(node_data["extends"], node_data) - + @register(registry_name, container=registry_container) class DynamicNode(node_cls): def __init__(self, *args, **kwargs): node_data_copy = node_data.copy() node_data_copy.update(kwargs) super().__init__(*args, **node_data_copy) + DynamicNode.__name__ = registry_name.split("/")[-1] DynamicNode.__dynamic_imported__ = True DynamicNode._base_type = base_type - + return DynamicNode + def get_ancestors_with_forks(graph: nx.DiGraph, node_id: str) -> set[str]: """ Returns a set of node IDs that are ancestors of the given node, plus any nodes in forked branches that don't lead to the target. - + A drop-in replacement for nx.ancestors() with extended functionality. - + Parameters: - graph: A NetworkX directed graph - node_id: The target node ID - + Returns: - A set of node IDs including ancestors and forked paths """ # Get direct ancestors (standard behavior) ancestors = nx.ancestors(graph, node_id) - + # Find all forks from ancestors forked_nodes = set() for ancestor_id in ancestors: # For each ancestor, find its descendants that aren't already in ancestors descendants = nx.descendants(graph, ancestor_id) # Add descendants that aren't ancestors of our target node and aren't the target - forked_nodes.update(desc for desc in descendants - if desc not in ancestors and desc != node_id) - + forked_nodes.update( + desc for desc in descendants if desc not in ancestors and desc != node_id + ) + # Combine direct ancestors with forked nodes return ancestors.union(forked_nodes) + class NodeStyle(pydantic.BaseModel): title_color: str | None = None node_color: str | None = None icon: str | None = None auto_title: str | None = None + class NodeState(pydantic.BaseModel): node_id: str - start_time: float | None = pydantic.Field(default_factory=time.time) + start_time: float | None = pydantic.Field(default_factory=time.time) end_time: float | None = None deactivated: bool = False error: str | None = None - + input_values: dict[str, Any] = pydantic.Field(default_factory=dict) output_values: dict[str, Any] = pydantic.Field(default_factory=dict) properties: dict[str, Any] = pydantic.Field(default_factory=dict) - - + def __init__(self, node: "NodeBase", state: "GraphState", **kwargs): super().__init__(node_id=node.id, **kwargs) - + self.input_values = {socket.name: socket.value for socket in node.inputs} self.output_values = {socket.name: socket.value for socket in node.outputs} self.properties = node.properties.copy() - + def __eq__(self, value) -> bool: try: return self.node_id == value.node_id except AttributeError: return False - + def __hash__(self): return hash(self.node_id) - + def __lt__(self, value) -> bool: if not isinstance(value, NodeState): return NotImplemented return self.node_id < value.node_id - + def __gt__(self, value) -> bool: if not isinstance(value, NodeState): return NotImplemented return self.node_id > value.node_id - + def __str__(self): return f"NodeState {self.node_id}" - + def __repr__(self): return f"NodeState {self.node_id}" @@ -274,7 +299,7 @@ class NodeState(pydantic.BaseModel): r.maxlevel = 1 r.maxlist = 10 r.maxstring = 255 - + return { "node_id": self.node_id, "start_time": self.start_time, @@ -283,21 +308,20 @@ class NodeState(pydantic.BaseModel): "error": self.error, "input_values": {k: r.repr(v) for k, v in self.input_values.items()}, "output_values": {k: r.repr(v) for k, v in self.output_values.items()}, - "properties": {k: r.repr(v) for k, v in self.properties.items()} + "properties": {k: r.repr(v) for k, v in self.properties.items()}, } - - + class GraphState(pydantic.BaseModel): data: dict[str, Any] = pydantic.Field(default_factory=dict) outer: "GraphState | None" = None - + shared: dict[str, Any] = pydantic.Field(default_factory=dict) - + graph: "Graph | None" = None - + stack: list[NodeState] = pydantic.Field(default_factory=list) - + verbosity: NodeVerbosity = NodeVerbosity.NORMAL @property @@ -307,86 +331,90 @@ class GraphState(pydantic.BaseModel): "stack": [node_state.flattened for node_state in self.stack], } except Exception as e: - log.error("error dumping stack (stack probably contains circular references)", error=e) + log.error( + "error dumping stack (stack probably contains circular references)", + error=e, + ) self.stack = [] return {"stack": []} - - def node_property_key(self, node:"NodeBase", name:str) -> str: + + def node_property_key(self, node: "NodeBase", name: str) -> str: return f"{node.id}.{name}" - - def set_node_property(self, node:"NodeBase", name:str, value:Any): + + def set_node_property(self, node: "NodeBase", name: str, value: Any): self.data[self.node_property_key(node, name)] = value - - def get_node_property(self, node:"NodeBase", name:str) -> Any: - return self.data.get(self.node_property_key(node, name), node.properties.get(name, UNRESOLVED)) - - def node_socket_value_key(self, node:"NodeBase", socket_name:str) -> str: + + def get_node_property(self, node: "NodeBase", name: str) -> Any: + return self.data.get( + self.node_property_key(node, name), node.properties.get(name, UNRESOLVED) + ) + + def node_socket_value_key(self, node: "NodeBase", socket_name: str) -> str: return f"{node.id}__socket.{socket_name}" - - def set_node_socket_value(self, node:"NodeBase", socket_name:str, value:Any): + + def set_node_socket_value(self, node: "NodeBase", socket_name: str, value: Any): self.data[self.node_socket_value_key(node, socket_name)] = value - - def get_node_socket_value(self, node:"NodeBase", socket_name:str) -> Any: + + def get_node_socket_value(self, node: "NodeBase", socket_name: str) -> Any: return self.data.get(self.node_socket_value_key(node, socket_name), UNRESOLVED) - def node_socket_state_key(self, node:"NodeBase", socket_name:str) -> str: + def node_socket_state_key(self, node: "NodeBase", socket_name: str) -> str: return f"{node.id}__socket_deactivated.{socket_name}" - - def set_node_socket_state(self, node:"NodeBase", socket_name:str, value:bool): + + def set_node_socket_state(self, node: "NodeBase", socket_name: str, value: bool): self.data[self.node_socket_state_key(node, socket_name)] = value - - def get_node_socket_state(self, node:"NodeBase", socket_name:str) -> bool: + + def get_node_socket_state(self, node: "NodeBase", socket_name: str) -> bool: return self.data.get(self.node_socket_state_key(node, socket_name), False) class GraphContext: - def __init__(self, outer_state: GraphState = None, graph: "Graph" = None): self.outer_state = outer_state self.graph = graph self.token = None - + def __enter__(self) -> GraphState: state = GraphState(outer=self.outer_state, graph=self.graph) state.shared = self.outer_state.shared if self.outer_state else {} state.stack = self.outer_state.stack if self.outer_state else [] self.token = graph_state.set(state) return state - + def __exit__(self, exc_type, exc_value, traceback): graph_state.reset(self.token) + class SaveContext: - def __enter__(self): self.token = save_state.set(True) return self - + def __exit__(self, exc_type, exc_value, traceback): save_state.reset(self.token) - + class Socket(pydantic.BaseModel): id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4())) name: str node: "NodeBase | None" = pydantic.Field(exclude=True, default=None) - + source: "Socket" = pydantic.Field(default=None, exclude=True) optional: bool = False group: str | None = None - + socket_type: str | list = "any" - + @classmethod def as_bool(cls, value: Any) -> bool: if value is UNRESOLVED: return False return bool(value) - + @property def value(self): try: - state:GraphState = graph_state.get() + state: GraphState = graph_state.get() except LookupError: # we dont have a state, so we can't get the value return UNRESOLVED @@ -397,31 +425,31 @@ class Socket(pydantic.BaseModel): @value.setter def value(self, value): try: - state:GraphState = graph_state.get() + state: GraphState = graph_state.get() except LookupError: # we dont have a state, so we can't set the value return - + state.set_node_socket_value(self.node, self.name, value) @property def deactivated(self) -> bool: try: - state:GraphState = graph_state.get() + state: GraphState = graph_state.get() except LookupError: # we dont have a state, so we can't get the socket activation state return True - + return state.get_node_socket_state(self.node, self.name) - + @deactivated.setter def deactivated(self, value: bool): try: - state:GraphState = graph_state.get() + state: GraphState = graph_state.get() except LookupError: # we dont have a state, so we can't set the socket activation state return - + state.set_node_socket_state(self.node, self.name, value) @property @@ -430,35 +458,36 @@ class Socket(pydantic.BaseModel): def __hash__(self): return hash(self.id) - + def __eq__(self, other): return self.id == other.id def __str__(self): return f"{self.node.title}.{self.name}" if self.node else self.name - + def __repr__(self): return str(self) + class PropertyField(pydantic.BaseModel): """ Describe a property field for a node """ - + name: str description: str type: str default: Any = None choices: list[Any] = None readonly: bool = False - + step: float | int | None = None min: float | int | None = None max: float | int | None = None - + # if true value will not be saved in the graph, past the initial value ephemeral: bool = False - + generate_choices: Callable | None = pydantic.Field(default=None, exclude=True) def model_dump(self, **kwargs): @@ -468,6 +497,7 @@ class PropertyField(pydantic.BaseModel): data["choices"] = self.generate_choices() return data + class NodeBase(pydantic.BaseModel): title: str = "Node" id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4())) @@ -481,179 +511,180 @@ class NodeBase(pydantic.BaseModel): registry: str | None = None _registry: ClassVar[str | None] = None - + _export_definition: ClassVar[bool] = True - + _isolated: ClassVar[bool] = False - + _base_type: ClassVar[str] = "" - + _module_path: ClassVar[str] = "" - - @pydantic.computed_field(description="Base type") + + @pydantic.computed_field(description="Base type") @property def base_type(self) -> str: return self._base_type - + @property def field_definitions(self) -> dict[str, PropertyField]: """ Returns a dictionary of property field definitions - + Cycle through self.properties and return the PropertyField for each property """ - + fields = {} - + for name, value in self.properties.items(): fields[name] = self.get_property_field(name) - + # if the class has a Fields object, add any remaining fields from that if hasattr(self.__class__, "Fields"): for name, value in self.__class__.Fields.__dict__.items(): # if the field is a PropertyField and it's not already in the fields dictionary, add it if isinstance(value, PropertyField) and name not in fields: fields[name] = value - + return fields - + def __init__(self, *args, **kwargs): - if kwargs.get("title", "Node") == "Node": title = self.__class__.__name__ # replace camel case with spaces - title = re.sub(r'(? Any: - if isinstance(data, dict) and 'properties' in data: - properties = data['properties'] - data['properties'] = { - k: UNRESOLVED if v in ('UNRESOLVED', None) else v + if isinstance(data, dict) and "properties" in data: + properties = data["properties"] + data["properties"] = { + k: UNRESOLVED if v in ("UNRESOLVED", None) else v for k, v in properties.items() } return data - + def get_output_socket(self, name: str) -> Socket: for socket in self.outputs: if socket.name == name: return socket return None - + def get_input_socket(self, name: str) -> Socket: for socket in self.inputs: if socket.name == name: return socket return None - + def add_input(self, name: str, **kwargs) -> Socket: socket = Socket(name=name, node=self, **kwargs) self.inputs.append(socket) return socket - + def remove_input(self, name: str): socket = self.get_input_socket(name) if socket: self.inputs.remove(socket) - + def add_output(self, name: str, **kwargs) -> Socket: socket = Socket(name=name, node=self, **kwargs) self.outputs.append(socket) return socket - + def remove_output(self, name: str): socket = self.get_output_socket(name) if socket: self.outputs.remove(socket) - + def get_property_field(self, name: str) -> PropertyField: """ checks self.Fields for a field with the given name - + it will be defined as a meta class in the derived class - + returns the field if it exists, otherwise return a generic field - + class Node: class Fields: name = PropertyField(name="number", description="A number", type=int) """ - - - + if not hasattr(self.__class__, "Fields"): - type_str = PYTHON_TYPE_TO_STRING.get(str(type(self.properties.get(name, "")))) or "str" + type_str = ( + PYTHON_TYPE_TO_STRING.get(str(type(self.properties.get(name, "")))) + or "str" + ) return PropertyField(name=name, description=name, type=type_str) - + FieldMeta = self.__class__.Fields - + if not hasattr(FieldMeta, name): - type_str = PYTHON_TYPE_TO_STRING.get(str(type(self.properties.get(name, "")))) or "str" + type_str = ( + PYTHON_TYPE_TO_STRING.get(str(type(self.properties.get(name, "")))) + or "str" + ) return PropertyField(name=name, description=name, type=type_str) - + return getattr(FieldMeta, name) - - + def set_property(self, name: str, value: Any, state: GraphState | None = None): """Set a property value""" if state is None: self.properties[name] = value else: state.set_node_property(self, name, value) - + def get_property(self, name: str, state: GraphState | None = None) -> Any: """Get a property value""" - + if state is None: try: - state:GraphState = graph_state.get() + state: GraphState = graph_state.get() except LookupError: return self.properties.get(name, UNRESOLVED) - + return state.get_node_property(self, name) - + def get_input_value(self, name: str) -> Any: """Get value for a specific input, falling back to property if input not connected""" # Find matching input socket @@ -664,21 +695,23 @@ class NodeBase(pydantic.BaseModel): return socket.value # Otherwise fall back to property break - + return self.get_property(name) - + def get_input_values(self) -> dict[str, Any]: - """Get all input values as a dictionary mapped by socket name, + """Get all input values as a dictionary mapped by socket name, falling back to properties for unconnected inputs""" values = {} - - names = set([socket.name for socket in self.inputs] + list(self.properties.keys())) - + + names = set( + [socket.name for socket in self.inputs] + list(self.properties.keys()) + ) + for name in names: values[name] = self.get_input_value(name) - + return values - + def set_output_values(self, values: dict[str, Any]): """Set output values by socket name""" for socket in self.outputs: @@ -689,152 +722,158 @@ class NodeBase(pydantic.BaseModel): """Override this method in derived classes to implement node behavior""" pass - - def check_is_available(self, state:GraphState) -> bool: + def check_is_available(self, state: GraphState) -> bool: """ A node should only run if: 1. All of its inputs are available (connected to non-deactivated sources), taking into account required groups 2. At least one of its outputs will have an effect (connects to a path that isn't deactivated) """ - + if self._isolated: return False - + # Group sockets by their group grouped_sockets = {} ungrouped_sockets = [] - + for socket in self.inputs: if socket.optional: continue - + if socket.group: if socket.group not in grouped_sockets: grouped_sockets[socket.group] = [] grouped_sockets[socket.group].append(socket) else: ungrouped_sockets.append(socket) - + # Check ungrouped sockets - all must be available for socket in ungrouped_sockets: - if socket.source is None or socket.source.deactivated or socket.value is UNRESOLVED: + if ( + socket.source is None + or socket.source.deactivated + or socket.value is UNRESOLVED + ): if self.get_property(socket.name) is UNRESOLVED: - if state.verbosity >= NodeVerbosity.VERBOSE: - log.warning(f"Node {self.title} input {socket.name} is not available, missing socket {socket.name}") - + log.warning( + f"Node {self.title} input {socket.name} is not available, missing socket {socket.name}" + ) + for out_socket in self.outputs: out_socket.deactivated = True return False - + # Check grouped sockets - at least one socket in each group must be available for group_sockets in grouped_sockets.values(): group_has_active = False - + for socket in group_sockets: # Check if socket has an active source or a property value - if (socket.source and not socket.source.deactivated and socket.value is not UNRESOLVED) or self.get_property(socket.name) is not UNRESOLVED: + if ( + socket.source + and not socket.source.deactivated + and socket.value is not UNRESOLVED + ) or self.get_property(socket.name) is not UNRESOLVED: group_has_active = True break - + if not group_has_active: if state.verbosity >= NodeVerbosity.VERBOSE: - log.warning(f"Node {self.title} group {group_sockets[0].group} is not available") + log.warning( + f"Node {self.title} group {group_sockets[0].group} is not available" + ) # If no socket in the group is active, deactivate outputs and return False for out_socket in self.outputs: out_socket.deactivated = True return False - + # If we have no outputs, we're an endpoint node - run if we have our inputs if not self.outputs: return True - + # Keep track of visited nodes to handle cycles visited = set() - + def has_active_path(current_socket: Socket, visited_nodes: set) -> bool: # If we've seen this node already, skip it to avoid cycles if current_socket.node in visited_nodes: return False - + visited_nodes.add(current_socket.node) - - # If this output socket is already deactivated, path is dead + + # If this output socket is already deactivated, path is dead if current_socket.deactivated: return False - + # If any input socket uses this as a source and isn't deactivated, # this is a valid path for node in current_socket.node.outputs: if not node.deactivated: return True - + # For each output, look for nodes that use it as input # and check their outputs recursively if node.source and not node.source.deactivated: if has_active_path(node.source, visited_nodes.copy()): return True - + return False - + # Check if any output path leads somewhere active - is_available = any(has_active_path(socket, visited.copy()) - for socket in self.outputs) - + is_available = any( + has_active_path(socket, visited.copy()) for socket in self.outputs + ) + # If not available, mark all our outputs as deactivated if not is_available: for socket in self.outputs: socket.deactivated = True - + return is_available - - def is_set(self, value:Any, none_is_set:bool=False) -> bool: + + def is_set(self, value: Any, none_is_set: bool = False) -> bool: """ Helper function to check if a value is set """ - + if none_is_set: return value is not UNRESOLVED return value is not UNRESOLVED and value is not None - - def require_input(self, input_name:str, none_is_set:bool=False) -> Any: + + def require_input(self, input_name: str, none_is_set: bool = False) -> Any: """ Require an input to be set and return it - + If the input is not set, raise an InputValueError - + If none_is_set is True, None is considered a set value """ - + value = self.get_input_value(input_name) - + if not self.is_set(value, none_is_set): - raise InputValueError( - self, - input_name, - f"Value is not set: {value}" - ) - - return value - - def normalized_input_value(self, input_name:str, none_is_set:bool=False) -> Any: - """ - Helper function to check if a value is set - - UNRESOLVED values are returned as None - """ - - value = self.get_input_value(input_name) - - if not self.is_set(value, none_is_set): - return None - + raise InputValueError(self, input_name, f"Value is not set: {value}") + return value - def require_number_input(self, name:str, types:tuple=(int, float)): - + def normalized_input_value(self, input_name: str, none_is_set: bool = False) -> Any: + """ + Helper function to check if a value is set + + UNRESOLVED values are returned as None + """ + + value = self.get_input_value(input_name) + + if not self.is_set(value, none_is_set): + return None + + return value + + def require_number_input(self, name: str, types: tuple = (int, float)): value = self.require_input(name) - + if isinstance(value, str): try: if float in types: @@ -843,173 +882,152 @@ class NodeBase(pydantic.BaseModel): value = int(value) except ValueError: raise InputValueError(self, name, "Invalid number") - + if not isinstance(value, types): raise InputValueError(self, name, "Value must be a number") - + return value + @base_node_type("core/Node") class Node(NodeBase): inputs: list[Socket] = pydantic.Field(default_factory=list) outputs: list[Socket] = pydantic.Field(default_factory=list) - + + class Entry(Node): def __init__(self, title="Entry", **kwargs): super().__init__(title=title, **kwargs) self.add_output("state") - - async def run(self, state:GraphState): - self.set_output_values({ - "state": state - }) + + async def run(self, state: GraphState): + self.set_output_values({"state": state}) + class Router(Node): - selector: Callable = pydantic.Field(default_factory=lambda state: 0) - + num_outputs: int = 2 - + def __init__(self, num_outputs: int, title="Router", **kwargs): super().__init__(num_outputs=num_outputs, title=title, **kwargs) - + def setup(self): for i in range(self.num_outputs): self.add_output(f"output_{i}") self.add_input("input") - + async def run(self, state: GraphState): route_to = self.selector(state) - + for idx, socket in enumerate(self.outputs): if idx != route_to: socket.deactivated = True else: - print(f"Setting output {socket.name} to {self.get_input_value('input')}") - self.set_output_values({ - socket.name: self.get_input_value("input") - }) + print( + f"Setting output {socket.name} to {self.get_input_value('input')}" + ) + self.set_output_values({socket.name: self.get_input_value("input")}) @register("core/Input") class Input(Node): - class Fields: input_type = PropertyField( name="input_type", - description="Input Type", - type="str", - default="any", - choices=TYPE_CHOICES, - generate_choices=lambda: TYPE_CHOICES - ) - - input_name = PropertyField( - name="input_name", - description="Input Name", + description="Input Type", type="str", - default="state" + default="any", + choices=TYPE_CHOICES, + generate_choices=lambda: TYPE_CHOICES, ) - + + input_name = PropertyField( + name="input_name", description="Input Name", type="str", default="state" + ) + input_optional = PropertyField( name="input_optional", description="Input Optional", type="bool", - default=False - ) - - input_group = PropertyField( - name="input_group", - description="Input Group", - type="str", - default="" - ) - - num = PropertyField( - name="num", - description="Number", - type="int", - default=0 + default=False, ) - + input_group = PropertyField( + name="input_group", description="Input Group", type="str", default="" + ) + + num = PropertyField(name="num", description="Number", type="int", default=0) + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( node_color="#2d2c39", title_color="#312e57", - icon="F02FA", #import - auto_title="IN {input_name}" + icon="F02FA", # import + auto_title="IN {input_name}", ) - def __init__(self, title="Input Socket", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("input_type", "any") self.set_property("input_name", "state") self.set_property("input_optional", False) self.set_property("input_group", "") self.set_property("num", 0) - + self.add_output("value") + @register("core/Output") class Output(Node): - class Fields: output_type = PropertyField( name="output_type", - description="Output Type", - type="str", - default="", - choices=TYPE_CHOICES, - generate_choices=lambda: TYPE_CHOICES - ) - - output_name = PropertyField( - name="output_name", - description="Output Name", + description="Output Type", type="str", - default="state" + default="", + choices=TYPE_CHOICES, + generate_choices=lambda: TYPE_CHOICES, ) - - num = PropertyField( - name="num", - description="Number", - type="int", - default=0 + + output_name = PropertyField( + name="output_name", description="Output Name", type="str", default="state" ) - + + num = PropertyField(name="num", description="Number", type="int", default=0) + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( node_color="#2d392c", title_color="#30572e", - icon="F0207", #export + icon="F0207", # export auto_title="OUT {output_name}", ) - + def __init__(self, title="Output Socket", **kwargs): super().__init__(title=title, **kwargs) - - def setup(self): + + def setup(self): self.set_property("output_type", "any") self.set_property("output_name", "state") self.set_property("num", 0) - + self.add_input("value", optional=True) + @register("core/ModuleProperty") class ModuleProperty(Node): """ A node that can be placed to define a property of a Graph - + Properties: - + - property_name: The name of the property - proeprty_type: The type of the property - default: The default value of the property @@ -1017,14 +1035,13 @@ class ModuleProperty(Node): - readonly: Whether the property is readonly - ephemeral: Whether the property is ephemeral - required: Whether the property is required - + Outputs: - + - name: The name of the property - value: The value of the property """ - - + class Fields: property_name = PropertyField( name="property_name", @@ -1037,9 +1054,7 @@ class ModuleProperty(Node): description="Property Type", type="str", default="", - choices=[ - "str", "bool", "int", "float", "text" - ] + choices=["str", "bool", "int", "float", "text"], ) default = PropertyField( name="default", @@ -1066,18 +1081,17 @@ class ModuleProperty(Node): default=0, min=0, ) - - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( node_color="#2c3339", title_color="#2e4657", - icon="F0AE7", #variable - auto_title="PROP {property_name}" + icon="F0AE7", # variable + auto_title="PROP {property_name}", ) - + @property def to_property_field(self) -> PropertyField: return PropertyField( @@ -1087,10 +1101,10 @@ class ModuleProperty(Node): default=self.cast_value(self.get_property("default")), choices=self.get_property("choices"), ) - + def __init__(self, title="Module Property", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("property_name", "") self.set_property("property_type", "") @@ -1102,11 +1116,10 @@ class ModuleProperty(Node): self.add_output("value") def cast_value(self, value: Any) -> Any: - # if UNRESOLVED return default if value is UNRESOLVED: value = self.get_property("default") - + if self.get_property("property_type") in ["str", "text"]: return str(value) elif self.get_property("property_type") == "bool": @@ -1124,24 +1137,24 @@ class ModuleProperty(Node): return float(value) return str(value) + @register("core/Route") class Route(Node): """ Simply passes the value of the input to the output """ - + def __init__(self, title="Route", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("value") self.add_output("value") - + async def run(self, state: GraphState): value = self.get_input_value("value") - self.set_output_values({ - "value": value - }) + self.set_output_values({"value": value}) + @register("core/Watch") class Watch(Node): @@ -1155,75 +1168,68 @@ class Watch(Node): return NodeStyle( node_color="#2c3339", title_color="#2e4657", - icon="F06D0", #eye-outline + icon="F06D0", # eye-outline ) - def __init__(self, title="Watch", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("value") self.add_output("value") - + async def run(self, state: GraphState): value = self.get_input_value("value") - + if state.shared.get("creative_mode"): log.debug(f"Watch:{self.title}", value=value) - - self.set_output_values({ - "value": value - }) - + + self.set_output_values({"value": value}) + + @register("core/Stage") class Stage(Node): """ A node that can be connected in or out and defines a stage level for the nodes connected to it - + This stage level can be used to control the order of execution of nodes in the graph, the lowest stage will be executed first. - + Inputs: - + - state: Any value to pass through. If not connected, defaults to True - - state_b: Any value to pass through. + - state_b: Any value to pass through. - state_c: Any value to pass through. - state_d: Any value to pass through. - + Outputs: - + - state: The value of the input state or True if corresponding input is not connected - state_b: The value of the input state_b - state_c: The value of the input state_c - state_d: The value of the input state_d """ - + class Fields: stage = PropertyField( - name="stage", - description="Stage", - type="int", - default=0, - min=0, - step=1 + name="stage", description="Stage", type="int", default=0, min=0, step=1 ) - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( node_color="#2c2c38", title_color="#343055", - icon="F0603", #priority-high - auto_title="Stage {stage}" + icon="F0603", # priority-high + auto_title="Stage {stage}", ) - + def __init__(self, title="Stage", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state", optional=True) self.add_input("state_b", optional=True) @@ -1236,23 +1242,25 @@ class Stage(Node): self.add_output("state_d") async def run(self, state: GraphState): - state_value = self.get_input_value("state") state_b_value = self.get_input_value("state_b") state_c_value = self.get_input_value("state_c") state_d_value = self.get_input_value("state_d") - + # unconnected inputs default to True in the output - + if state_value is UNRESOLVED and not self.get_input_socket("state").source: state_value = True - - self.set_output_values({ - "state": state_value, - "state_b": state_b_value, - "state_c": state_c_value, - "state_d": state_d_value, - }) + + self.set_output_values( + { + "state": state_value, + "state_b": state_b_value, + "state_c": state_c_value, + "state_d": state_d_value, + } + ) + def validate_node( v: Any, @@ -1262,24 +1270,23 @@ def validate_node( # If it's already a Node instance, return it if isinstance(v, NodeBase): return v - + # If it's a dict, check registry and instantiate appropriate class if isinstance(v, dict): - registry_name = v.get('registry') - #print(f"Validating node with registry: {registry_name}") + registry_name = v.get("registry") + # print(f"Validating node with registry: {registry_name}") if registry_name: node_cls = get_node(registry_name) - #print(f"Found node class: {node_cls}") + # print(f"Found node class: {node_cls}") if node_cls: return node_cls(**v) - + raise ValueError(f"Could not validate node: {v}") + # Create annotated type for nodes with registry validation -RegistryNode = Annotated[ - NodeBase, - pydantic.WrapValidator(validate_node) -] +RegistryNode = Annotated[NodeBase, pydantic.WrapValidator(validate_node)] + class Group(pydantic.BaseModel): title: str = "Group" @@ -1291,6 +1298,7 @@ class Group(pydantic.BaseModel): font_size: int = 24 inherited: bool = False + class Comment(pydantic.BaseModel): text: str = "Comment" x: int = 0 @@ -1298,48 +1306,44 @@ class Comment(pydantic.BaseModel): width: int = 200 inherited: bool = False + @register("util/ModuleStyle") class ModuleStyle(Node): """ An isolated node that will define the Graph's style """ - + _isolated: ClassVar[bool] = True - + class Fields: node_color = PropertyField( - name="node_color", - description="Node Color", - type="str", - default=UNRESOLVED + name="node_color", description="Node Color", type="str", default=UNRESOLVED ) title_color = PropertyField( name="title_color", description="Title Color", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) auto_title = PropertyField( - name="auto_title", - description="Auto Title", - type="str", - default=UNRESOLVED + name="auto_title", description="Auto Title", type="str", default=UNRESOLVED ) icon = PropertyField( name="icon", description="Icon (Material Icon Codepoint)", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + def __init__(self, title="Module Style", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("node_color", UNRESOLVED) self.set_property("title_color", UNRESOLVED) self.set_property("auto_title", UNRESOLVED) self.set_property("icon", UNRESOLVED) + def get_style(self) -> NodeStyle: return NodeStyle( node_color=self.get_property("node_color"), @@ -1347,24 +1351,25 @@ class ModuleStyle(Node): auto_title=self.get_property("auto_title"), icon=self.get_property("icon"), ) - + async def run(self, state: GraphState): pass + @base_node_type("core/Graph") class Graph(NodeBase): nodes: dict[str, RegistryNode] = pydantic.Field(default_factory=dict) - edges: dict[str, list[str]] = pydantic.Field(default_factory=dict) + edges: dict[str, list[str]] = pydantic.Field(default_factory=dict) sockets: dict[str, Socket] = pydantic.Field(default_factory=dict, exclude=True) groups: list[Group] = pydantic.Field(default_factory=list) comments: list[Comment] = pydantic.Field(default_factory=list) extends: str | None = None - + error_handlers: list[Callable] = pydantic.Field(default_factory=list, exclude=True) callbacks: list[Callable] = pydantic.Field(default_factory=list, exclude=True) - + _interrupt: bool = False - + @property def input_nodes(self) -> list[Input]: return [node for node in self.nodes.values() if isinstance(node, Input)] @@ -1372,84 +1377,88 @@ class Graph(NodeBase): @property def output_nodes(self) -> list[Output]: return [node for node in self.nodes.values() if isinstance(node, Output)] - + @property def module_property_nodes(self) -> list[ModuleProperty]: - nodes = [node for node in self.nodes.values() if isinstance(node, ModuleProperty)] + nodes = [ + node for node in self.nodes.values() if isinstance(node, ModuleProperty) + ] # sort by num property nodes.sort(key=lambda x: x.get_property("num")) - + return nodes @pydantic.computed_field(description="Inputs") @property def inputs(self) -> list[Socket]: - if hasattr(self, "_inputs"): return self._inputs - + # find sub nodes of Input type and dynamically output Socket types inputs = [] - + nodes = [] - + # collect nodes and sort by num property for node in self.input_nodes: nodes.append(node) - + nodes.sort(key=lambda x: x.get_property("num")) - + for node in nodes: - inputs.append(Socket( - name=node.get_property("input_name"), - socket_type=node.get_property("input_type"), - optional=node.get_property("input_optional"), - group=node.get_property("input_group"), - node=self, - )) - + inputs.append( + Socket( + name=node.get_property("input_name"), + socket_type=node.get_property("input_type"), + optional=node.get_property("input_optional"), + group=node.get_property("input_group"), + node=self, + ) + ) + self._inputs = inputs - + return inputs - - @pydantic.computed_field(description="Outputs") + + @pydantic.computed_field(description="Outputs") @property def outputs(self) -> list[Socket]: - if hasattr(self, "_outputs"): return self._outputs - + # find sub nodes of Output type and dynamically output Socket types outputs = [] - + nodes = [] - + for node in self.output_nodes: nodes.append(node) - + nodes.sort(key=lambda x: x.get_property("num")) - + for node in nodes: - outputs.append(Socket( - name=node.get_property("output_name"), - socket_type=node.get_property("output_type"), - node=self, - )) - + outputs.append( + Socket( + name=node.get_property("output_name"), + socket_type=node.get_property("output_type"), + node=self, + ) + ) + self._outputs = outputs - + return outputs - + @pydantic.computed_field(description="Module Fields") @property def module_properties(self) -> dict[str, PropertyField]: # Dynamically find all ModuleProperty nodes and return them # as a list of PropertyField objects - + if hasattr(self, "_module_properties"): return self._module_properties - + properties = {} for node in self.module_property_nodes: name = node.get_property("property_name") @@ -1457,71 +1466,75 @@ class Graph(NodeBase): properties[name] = node.to_property_field else: log.warning("Duplicate module property", name=name) - + self._module_properties = properties - + return properties - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle | None: """ Find the style of the ModuleStyle node in the graph and return it """ - + for _node in self.nodes.values(): if isinstance(_node, ModuleStyle): return _node.get_style() - + return None def model_dump(self, **kwargs) -> dict: """ If save_state is active, drop all inherited nodes, groups, and comments from the serialized output. - + Nodes that are dropped need also be removed from the edges dictionary """ data = super().model_dump(**kwargs) - + try: if save_state.get(): # Filter out inherited nodes data["nodes"] = { - node_id: node - for node_id, node in data["nodes"].items() + node_id: node + for node_id, node in data["nodes"].items() if not node.get("inherited", False) } - + # Filter out inherited groups data["groups"] = [ - group for group in data["groups"] + group + for group in data["groups"] if not group.get("inherited", False) ] - + # Filter out inherited comments data["comments"] = [ - comment for comment in data["comments"] + comment + for comment in data["comments"] if not comment.get("inherited", False) ] - + # Remove edges that reference dropped nodes dropped_node_ids = { - node_id for node_id, node in self.nodes.items() - if node.inherited + node_id for node_id, node in self.nodes.items() if node.inherited } - + data["edges"] = { - output_id: input_ids + output_id: input_ids for output_id, input_ids in data["edges"].items() if output_id.split(".")[0] not in dropped_node_ids - and all(input_id.split(".")[0] not in dropped_node_ids for input_id in input_ids) + and all( + input_id.split(".")[0] not in dropped_node_ids + for input_id in input_ids + ) } - + except LookupError: # save_state not set, return full data pass - + return data def reinitialize(self) -> "Graph": @@ -1534,12 +1547,12 @@ class Graph(NodeBase): def reset(self): """ Reset all _value properties in all sockets - """ + """ for node in self.nodes.values(): for socket in node.inputs + node.outputs: socket.value = UNRESOLVED socket.deactivated = False - + self.reinitialize() def reset_sockets(self): @@ -1551,26 +1564,29 @@ class Graph(NodeBase): socket.value = UNRESOLVED socket.deactivated = False - def ensure_connections(self): """ Loop through edges and ensure all sockets are connected """ - + for output_socket_id, input_socket_ids in self.edges.items(): - output_node_id, output_socket_name = output_socket_id.split(".",1) + output_node_id, output_socket_name = output_socket_id.split(".", 1) output_node = self.nodes[output_node_id] output_socket = output_node.get_output_socket(output_socket_name) - + for input_socket_id in input_socket_ids: - input_node_id, input_socket_name = input_socket_id.split(".",1) + input_node_id, input_socket_name = input_socket_id.split(".", 1) input_node = self.nodes[input_node_id] input_socket = input_node.get_input_socket(input_socket_name) - + if not input_socket: - log.warning("Input socket not found", input_socket_name=input_socket_name, input_node_id=input_node_id) + log.warning( + "Input socket not found", + input_socket_name=input_socket_name, + input_node_id=input_node_id, + ) continue - + if not input_socket.source: self.connect(output_socket, input_socket) @@ -1578,7 +1594,7 @@ class Graph(NodeBase): """ Reset all ephemeral properties in all nodes """ - + for node in self.nodes.values(): for property_name in node.properties: field = node.get_property_field(property_name) @@ -1589,26 +1605,28 @@ class Graph(NodeBase): """ Loops through all nodes and sets their socket.node references """ - + for node in self.nodes.values(): for socket in node.inputs + node.outputs: socket.node = node self.sockets[socket.full_id] = socket - + return self - + def set_socket_source_references(self) -> "Graph": """ Loops through all nodes and their input sockets and sets the `source` reference based on the edge connections """ - + for node in self.nodes.values(): for socket in node.inputs: for output_socket_id, input_socket_ids in self.edges.items(): - output_node_id, output_socket_name = output_socket_id.split(".",1) + output_node_id, output_socket_name = output_socket_id.split(".", 1) if socket.id in input_socket_ids: - output_socket = self.sockets[f"{output_node_id}.{output_socket_name}"] + output_socket = self.sockets[ + f"{output_node_id}.{output_socket_name}" + ] socket.source = output_socket break return self @@ -1618,60 +1636,60 @@ class Graph(NodeBase): def add_node(self, node: NodeBase): self.nodes[node.id] = node - + for socket in node.inputs + node.outputs: self.sockets[socket.id] = socket - + def connect(self, output_socket: Socket | str, input_socket: Socket | str): """ Connect an output socket to an input socket. One output socket can connect to multiple input sockets. """ - + if isinstance(output_socket, str): output_socket = self.sockets[output_socket] - + if isinstance(input_socket, str): input_socket = self.sockets[input_socket] - + if not output_socket or not input_socket: - log.warning("Could not connect sockets", output_socket=output_socket, input_socket=input_socket) + log.warning( + "Could not connect sockets", + output_socket=output_socket, + input_socket=input_socket, + ) return - + if output_socket.full_id not in self.edges: self.edges[output_socket.full_id] = [] - + if input_socket.full_id not in self.edges[output_socket.full_id]: self.edges[output_socket.full_id].append(input_socket.full_id) input_socket.source = output_socket - - + def build(self) -> nx.DiGraph: """ Build and return directed acyclic graph """ graph = nx.DiGraph() - + # Add edges between nodes based on socket connections for output_socket_id, input_socket_ids in self.edges.items(): - output_node_id, _ = output_socket_id.split(".",1) - output_node_title = self.nodes[output_node_id].title - + output_node_id, _ = output_socket_id.split(".", 1) + for input_socket_id in input_socket_ids: - input_node_id, _ = input_socket_id.split(".",1) - - input_node_title = self.nodes[input_node_id].title + input_node_id, _ = input_socket_id.split(".", 1) + graph.add_edge(output_node_id, input_node_id) - - + return graph - + def assign_priority(self, node_chain: nx.DiGraph) -> int: """ - Will search for a Stage type + Will search for a Stage type """ - - min_stage = float('inf') + + min_stage = float("inf") for node_id in node_chain: node = self.nodes[node_id] if isinstance(node, Stage): @@ -1701,16 +1719,16 @@ class Graph(NodeBase): async def signal_note_state(self, state: GraphState): if not state.shared.get("creative_mode"): return - + await async_signals.get("nodes_node_state").send(state) state.stack.clear() - + async def node_state_push( - self, - node: NodeBase, - state: GraphState, - inactive: bool = False, - reset: bool = False + self, + node: NodeBase, + state: GraphState, + inactive: bool = False, + reset: bool = False, ) -> NodeState: if not state.shared.get("creative_mode"): return @@ -1718,52 +1736,55 @@ class Graph(NodeBase): if inactive: node_exec.deactivated = True - + if reset: node_exec.start_time = None node_exec.end_time = None - + # push to the end of the stack state.stack.append(node_exec) await self.signal_note_state(state) - + return node_exec - async def node_state_pop( - self, - prev_state: NodeState, - node:NodeBase, - state: GraphState, - error:str=None + self, + prev_state: NodeState, + node: NodeBase, + state: GraphState, + error: str = None, ) -> NodeState: if not state.shared.get("creative_mode"): return - - node_exec = NodeState(node, state, start_time=prev_state.start_time, end_time=prev_state.end_time) + + node_exec = NodeState( + node, state, start_time=prev_state.start_time, end_time=prev_state.end_time + ) node_exec.end_time = time.time() - + if error: node_exec.error = error - + state.stack.append(node_exec) await self.signal_note_state(state) - + return node_exec - + async def node_state_sync_all(self, state: GraphState): for node in self.nodes.values(): await self.node_state_push(node, state, reset=True) - - async def get_nodes(self, fn_filter:Callable=None) -> list[NodeBase]: + + async def get_nodes(self, fn_filter: Callable = None) -> list[NodeBase]: """ Returns a list of nodes in the graph """ if not fn_filter: return list(self.nodes.values()) return [node for node in self.nodes.values() if fn_filter(node)] - - async def get_node(self, fn_filter:Callable=None, require_unique:bool=True) -> NodeBase: + + async def get_node( + self, fn_filter: Callable = None, require_unique: bool = True + ) -> NodeBase: """ Returns a single node from the graph """ @@ -1772,7 +1793,9 @@ class Graph(NodeBase): raise ValueError("Multiple nodes found") return nodes[0] if nodes else None - async def get_nodes_connected_to(self, node: NodeBase, fn_filter:Callable=None) -> list[NodeBase]: + async def get_nodes_connected_to( + self, node: NodeBase, fn_filter: Callable = None + ) -> list[NodeBase]: """ Returns a list of nodes connected to the given node """ @@ -1780,71 +1803,76 @@ class Graph(NodeBase): predecessors = get_ancestors_with_forks(graph, node.id) if not fn_filter: return [self.nodes[node_id] for node_id in predecessors] - return [self.nodes[node_id] for node_id in predecessors if fn_filter(self.nodes[node_id])] - + return [ + self.nodes[node_id] + for node_id in predecessors + if fn_filter(self.nodes[node_id]) + ] + async def execute_to_node( self, stop_at_node: NodeBase, - outer_state: GraphState | None = None, + outer_state: GraphState | None = None, callbacks: list[Callable] = [], emit_state: bool = False, state_values: dict[str, Any] = None, - execute_forks: bool=False, - run_isolated: bool=True + execute_forks: bool = False, + run_isolated: bool = True, ): """Execute the graph in topological order""" graph = self.build() - + # check that node exists if stop_at_node.id not in self.nodes: raise ValueError(f"Node {stop_at_node.id} not found in graph") - + # Get all predecessor nodes including the target node if not execute_forks: predecessors = nx.ancestors(graph, stop_at_node.id) else: predecessors = get_ancestors_with_forks(graph, stop_at_node.id) predecessors.add(stop_at_node.id) - + # Get subgraph of only the nodes we need to execute subgraph = graph.subgraph(predecessors) - + # Check for cycles if not nx.is_directed_acyclic_graph(subgraph): raise ValueError("Graph contains cycles") - + with GraphContext(outer_state, self) as state: - if state_values: state.data.update(state_values) - - await self._execute_inner(subgraph, state, emit_state=emit_state, run_isolated=run_isolated) - + + await self._execute_inner( + subgraph, state, emit_state=emit_state, run_isolated=run_isolated + ) + for callback in callbacks: await callback(state) - + return state - + async def execute( self, outer_state: GraphState | None = None, state_values: dict[str, Any] = None, - callbacks: list[Callable] = [] + callbacks: list[Callable] = [], ): """Execute the graph in topological order""" - + graph = self.build() - + # Check for cycles if not nx.is_directed_acyclic_graph(graph): raise ValueError("Graph contains cycles") - + with GraphContext(outer_state, self) as state: self.reset() - + if state_values: state.data.update(state_values) - + await self.node_state_sync_all(state) await self._execute_inner(graph, state) for callback in self.callbacks: @@ -1853,78 +1881,72 @@ class Graph(NodeBase): await callback(state) async def _execute_inner( - self, - graph: nx.DiGraph, - state: GraphState, - emit_state:bool = True, - run_isolated:bool=False + self, + graph: nx.DiGraph, + state: GraphState, + emit_state: bool = True, + run_isolated: bool = False, ): - verbosity: NodeVerbosity = state.verbosity - + try: # route input socket values to their corresponding Input nodes for node in self.input_nodes: socket = self.get_input_socket(node.get_property("input_name")) if not socket or not socket.source: continue - - socket_value = state.outer.get_node_socket_value(socket.source.node, socket.source.name) - + + socket_value = state.outer.get_node_socket_value( + socket.source.node, socket.source.name + ) + if state.verbosity == NodeVerbosity.VERBOSE: log.debug(f"Setting input value for {node.title} to {socket_value}") - node.set_output_values({ - "value": socket_value - }) - + node.set_output_values({"value": socket_value}) + # for module property nodes we need to set their output socket values # base on the property value for node in self.module_property_nodes: name = node.get_property("property_name") value = self.get_property(name) - node.set_output_values({ - "name": name, - "value": node.cast_value(value) - }) - + node.set_output_values({"name": name, "value": node.cast_value(value)}) + # Separate into weakly connected components (isolated chains) chains = list(nx.weakly_connected_components(graph)) - + # sort chains by priority chains.sort(key=lambda chain: self.assign_priority(chain)) - + for chain in chains: - subgraph = graph.subgraph(chain) - + sorted_nodes = list(nx.topological_sort(subgraph)) - + # check if the final in the chain is _isolated, and if so, skip the chain if self.nodes[sorted_nodes[-1]]._isolated and not run_isolated: continue - + # Execute nodes in topological order for node_id in sorted_nodes: - if self._interrupt: self._interrupt = False break - + node = self.nodes[node_id] if verbosity == NodeVerbosity.VERBOSE: log.debug(f"Running node {node.title} (pre check)") - + if not node.check_is_available(state): if emit_state: await self.node_state_push(node, state, inactive=True) continue - + if verbosity == NodeVerbosity.VERBOSE: log.debug(f"Running node {node.title}") - + if emit_state: node_state = await self.node_state_push(node, state) - + # run node try: await node.run(state) @@ -1937,28 +1959,38 @@ class Graph(NodeBase): except PASSTHROUGH_ERRORS as exc: if emit_state: await self.node_state_pop(node_state, node, state) - + await self.attempt_catch_with_node_error_handler(state, exc) raise exc - except ModuleError as exc: + except ModuleError: if emit_state: - await self.node_state_pop(node_state, node, state, error=traceback.format_exc()) + await self.node_state_pop( + node_state, node, state, error=traceback.format_exc() + ) except Exception as exc: if emit_state: - await self.node_state_pop(node_state, node, state, error=traceback.format_exc()) - + await self.node_state_pop( + node_state, node, state, error=traceback.format_exc() + ) + await self.attempt_catch_with_node_error_handler(state, exc) - + # route Output node values to their corresponding output sockets if isinstance(node, Output): - socket = self.get_output_socket(node.get_property("output_name")) + socket = self.get_output_socket( + node.get_property("output_name") + ) if socket: socket.value = node.get_input_socket("value").value - state.outer.set_node_socket_value(self, socket.name, socket.value) - + state.outer.set_node_socket_value( + self, socket.name, socket.value + ) + if verbosity == NodeVerbosity.VERBOSE: - log.debug(f"Setting output value for {socket.full_id} to {socket.value}") - + log.debug( + f"Setting output value for {socket.full_id} to {socket.value}" + ) + except StopGraphExecution: pass except PASSTHROUGH_ERRORS as exc: @@ -1971,8 +2003,9 @@ class Graph(NodeBase): finally: raise exc - async def attempt_catch_with_node_error_handler(self, state: GraphState, exc: Exception): - + async def attempt_catch_with_node_error_handler( + self, state: GraphState, exc: Exception + ): """Attempt to catch an exception with a node error handler Args: @@ -1982,59 +2015,64 @@ class Graph(NodeBase): Raises: Will re-raise the exception if no error handler is found """ - + error_handlers = await self.get_nodes(lambda n: hasattr(n, "catch")) if not error_handlers: raise exc - + caught = False for error_handler in error_handlers: if await error_handler.catch(state, exc): caught = True break - + if not caught: raise exc - + @base_node_type("core/Loop") class Loop(Graph): - - exit_condition: Callable = pydantic.Field(default_factory=lambda: lambda state: False, exclude=True) - + exit_condition: Callable = pydantic.Field( + default_factory=lambda: lambda state: False, exclude=True + ) + sleep: float = 0.001 - - def __init__(self, **kwargs): + + def __init__(self, **kwargs): super().__init__(**kwargs) - + def setup(self): self.add_input("state") self.add_output("state") - + async def on_loop_start(self, state: GraphState): pass - + async def on_loop_end(self, state: GraphState): pass - + async def on_loop_error(self, state: GraphState, exc: Exception): pass - - async def execute(self, outer_state: GraphState, state_values:dict=None, run_isolated:bool=False): + + async def execute( + self, + outer_state: GraphState, + state_values: dict = None, + run_isolated: bool = False, + ): """Execute the graph in topological order""" graph = self.build() - + if not nx.is_directed_acyclic_graph(graph): raise ValueError("Graph contains cycles") - + with GraphContext(outer_state, self) as state: self.reset() - + if state_values: state.data.update(state_values) - + try: - # Separate into weakly connected components (isolated chains) chains = list(nx.weakly_connected_components(graph)) @@ -2042,38 +2080,40 @@ class Loop(Graph): chains.sort(key=lambda chain: self.assign_priority(chain)) while True: - self.reset_sockets() - + BREAK_LOOP = False - + # LOOP START - + try: await self.on_loop_start(state) except Exception as exc: try: await self.handle_error(state, exc) await self.on_loop_error(state, exc) - log.error("Error in on_loop_start", exc=exc, traceback=traceback.format_exc()) + log.error( + "Error in on_loop_start", + exc=exc, + traceback=traceback.format_exc(), + ) except LoopBreak: BREAK_LOOP = True except LoopContinue: pass except LoopExit: return - + # PROCESS NODE CHAINS - + for chain in chains: - if BREAK_LOOP: break - + await self.node_state_sync_all(state) - + subgraph = graph.subgraph(chain) - + sorted_nodes = list(nx.topological_sort(subgraph)) # check if the final in the chain is _isolated, and if so, skip the chain @@ -2086,30 +2126,42 @@ class Loop(Graph): if self._interrupt: self._interrupt = False raise LoopExit() - + node = self.nodes[node_id] if state.verbosity == NodeVerbosity.VERBOSE: log.debug(f"Running node {node.title} (pre check)") - + if not node.check_is_available(state): - await self.node_state_push(node, state, inactive=True) + await self.node_state_push( + node, state, inactive=True + ) continue - + if state.verbosity == NodeVerbosity.VERBOSE: log.debug(f"Running node {node.title}") - + node_state = await self.node_state_push(node, state) try: await node.run(state) await self.node_state_pop(node_state, node, state) - except PASSTHROUGH_ERRORS as exc: + except PASSTHROUGH_ERRORS: raise - except ModuleError as exc: - await self.node_state_pop(node_state, node, state, error=traceback.format_exc()) - except Exception as exc: - await self.node_state_pop(node_state, node, state, error=traceback.format_exc()) + except ModuleError: + await self.node_state_pop( + node_state, + node, + state, + error=traceback.format_exc(), + ) + except Exception: + await self.node_state_pop( + node_state, + node, + state, + error=traceback.format_exc(), + ) raise - + # Check for loop exit condition if self.exit_condition(state): return @@ -2123,13 +2175,24 @@ class Loop(Graph): except StopGraphExecution: BREAK_LOOP = True break - except (ExitScene, ResetScene, RestartSceneLoop, StopModule) as exc: + except ( + ExitScene, + ResetScene, + RestartSceneLoop, + StopModule, + ): raise except Exception as exc: try: await self.handle_error(state, exc) await self.on_loop_error(state, exc) - log.error("Error in Loop", graph=self.title, graph_cls=self.__class__, exc=exc, traceback=traceback.format_exc()) + log.error( + "Error in Loop", + graph=self.title, + graph_cls=self.__class__, + exc=exc, + traceback=traceback.format_exc(), + ) await asyncio.sleep(1.0) except LoopBreak: BREAK_LOOP = True @@ -2138,38 +2201,45 @@ class Loop(Graph): continue except LoopExit: return - except (ExitScene, ResetScene, RestartSceneLoop, StopModule) as exc: + except ( + ExitScene, + ResetScene, + RestartSceneLoop, + StopModule, + ) as exc: raise - except Exception as exc: - log.error("Error in on_loop_error", exc=traceback.format_exc()) + except Exception: + log.error( + "Error in on_loop_error", + exc=traceback.format_exc(), + ) BREAK_LOOP = True break - - + if BREAK_LOOP: break - + await asyncio.sleep(self.sleep) - + # LOOP END - + await self.on_loop_end(state) - #except Exception as e: + # except Exception as e: # log.error("Error in loop", exc=e, traceback=traceback.format_exc()) # raise finally: for callback in self.callbacks: await callback(state) - -@base_node_type("core/Event") + +@base_node_type("core/Event") class Listen(Graph): """ Listens for an event """ - + _isolated: ClassVar[bool] = True - + class Fields: event_name = PropertyField( name="event_name", @@ -2177,35 +2247,34 @@ class Listen(Graph): type="str", default="", ) - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: - # If a style is defined in the graph it overrides the default defined_style = super().style if defined_style: return defined_style - + return NodeStyle( node_color="#39382c", title_color="#57532e", - icon="F0BF8", #alpha-e-circle + icon="F0BF8", # alpha-e-circle ) - + def __init__(self, title="Listen", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("event_name", UNRESOLVED) - - async def run(self, state:GraphState): - log.warning("Listen node run directly", node=self) + + async def run(self, state: GraphState): + log.warning("Listen node run directly", node=self) return await super().run(state) - - async def execute_from_event(self, event:object): + + async def execute_from_event(self, event: object): try: - state:GraphState = graph_state.get() + state: GraphState = graph_state.get() except LookupError: # we are outside of the graph state, however we can still attempt # to get the scene's graph state @@ -2215,21 +2284,24 @@ class Listen(Graph): else: log.error("Event node executed outside of active graph state") return - + node_state = await self.node_state_push(self, state) try: await self.execute(state, state_values={"event": event}) except Exception as exc: - await self.node_state_pop(node_state, self, state, error=traceback.format_exc()) + await self.node_state_pop( + node_state, self, state, error=traceback.format_exc() + ) raise exc await self.node_state_pop(node_state, self, state) - + + @base_node_type("core/EventTrigger") class Trigger(Node): """ Triggers an event """ - + class Fields: event_name = PropertyField( name="event_name", @@ -2237,63 +2309,61 @@ class Trigger(Node): type="str", default="", ) - + @property def signals(self): return async_signals - + @property def signal_name(self) -> str | UNRESOLVED: return self.get_input_value("event_name") def __init__(self, title="Trigger Event", **kwargs): super().__init__(title=title, **kwargs) - + def setup_properties(self): self.set_property("event_name", "") - + def setup_required_inputs(self): self.add_input("trigger") - + def setup_optional_inputs(self): self.add_input("event_name", socket_type="str", optional=True) - + def setup_outputs(self): self.add_output("event", socket_type="event") - + def setup(self): self.setup_required_inputs() self.setup_optional_inputs() self.setup_properties() self.setup_outputs() - def make_event_object(self, state:GraphState) -> object: + def make_event_object(self, state: GraphState) -> object: raise NotImplementedError("Event object not defined") - - async def after(self, state:GraphState, event:object): - pass - - async def run(self, state:GraphState): + + async def after(self, state: GraphState, event: object): + pass + + async def run(self, state: GraphState): event_name = self.signal_name - + if not event_name or event_name == UNRESOLVED: log.error("Event name not set") return - + signal = async_signals.get(event_name) if not signal: log.error("Signal not found", event_name=event_name) return - + event = self.make_event_object(state) - + await signal.send(event) - + if state.verbosity >= NodeVerbosity.VERBOSE: log.debug(f"Triggered event {event_name}") - - self.set_output_values({ - "event": event - }) - + + self.set_output_values({"event": event}) + await self.after(state, event) diff --git a/src/talemate/game/engine/nodes/core/exception.py b/src/talemate/game/engine/nodes/core/exception.py index b81a8192..dca2672d 100644 --- a/src/talemate/game/engine/nodes/core/exception.py +++ b/src/talemate/game/engine/nodes/core/exception.py @@ -4,6 +4,7 @@ __all__ = [ "ExceptionWrapper", ] + class ExceptionWrapper(pydantic.BaseModel): name: str - message: str \ No newline at end of file + message: str diff --git a/src/talemate/game/engine/nodes/data.py b/src/talemate/game/engine/nodes/data.py index 99c234fa..6d67e152 100644 --- a/src/talemate/game/engine/nodes/data.py +++ b/src/talemate/game/engine/nodes/data.py @@ -3,40 +3,41 @@ import json import structlog import pydantic from .core import ( - Node, + Node, GraphState, - UNRESOLVED, + UNRESOLVED, PropertyField, InputValueError, TYPE_CHOICES, NodeStyle, - NodeVerbosity + NodeVerbosity, ) from .registry import register log = structlog.get_logger("talemate.game.engine.nodes.data") + @register("data/Sort") class Sort(Node): """ Sorts a list of items - + Inputs: - + - items: List of items to sort - sort_keys: List of keys to sort by - reverse: Reverse sort - + Properties: - + - sort_keys: List of keys to sort by - reverse: Reverse sort - + Outputs: - + - sorted_items: Sorted list of items """ - + class Fields: sort_keys = PropertyField( name="sort_keys", @@ -44,93 +45,94 @@ class Sort(Node): type="list", default=UNRESOLVED, ) - + reverse = PropertyField( name="reverse", description="Reverse sort", type="bool", default=False, ) - + def __init__(self, title="Sort", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") - self.add_input("items", socket_type="list") + self.add_input("items", socket_type="list") self.add_input("sort_keys", socket_type=["str", "list"], optional=True) - + self.set_property("reverse", False) self.set_property("sort_keys", UNRESOLVED) - + self.add_output("sorted_items", socket_type="list") - + async def run(self, state: GraphState): items = self.get_input_value("items") - + sort_keys = self.get_input_value("sort_keys") - + if isinstance(sort_keys, str): sort_keys = json.loads(sort_keys) - + if sort_keys != UNRESOLVED and not isinstance(sort_keys, list): log.error("Sort keys must be a list", sort_keys=sort_keys) raise InputValueError(self, "sort_keys", "Sort keys must be a list") - + new_items = [i for i in items] reverse = self.get_property("reverse") if self.is_set(sort_keys) and len(sort_keys) > 0: - new_items.sort(key=lambda x: tuple([getattr(x,k,None) for k in sort_keys]), reverse=reverse) + new_items.sort( + key=lambda x: tuple([getattr(x, k, None) for k in sort_keys]), + reverse=reverse, + ) else: new_items.sort(reverse=reverse) - - self.set_output_values({ - "sorted_items": new_items - }) + + self.set_output_values({"sorted_items": new_items}) + @register("data/JSON") class JSON(Node): """ Node that converts a JSON string to a Python object - + Inputs: - + - json: JSON string - + Outputs: - + - data: Python object (dict or list) """ - + def __init__(self, title="JSON", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("json", socket_type="str") self.add_output("data", socket_type="dict,list") - + async def run(self, state: GraphState): json_string = self.get_input_value("json") - + # convert json string to python object # support list as root object data = json.loads(json_string) - self.set_output_values({ - "data": data - }) + self.set_output_values({"data": data}) + @register("data/Contains") class Contains(Node): """ Checks if a value is in a list or dictionary - + Inputs: - + - object: Object (list, dict, etc.) - if a generator is provided, it will be converted to a list - value: Value - + Outputs: - + - contains: True if value is in object, False otherwise """ @@ -144,59 +146,57 @@ class Contains(Node): def __init__(self, title="Contains", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("object", socket_type="any") self.add_input("value", socket_type="any") self.set_property("value", UNRESOLVED) - + self.add_output("contains", socket_type="bool") - + async def run(self, state: GraphState): object = self.get_input_value("object") value = self.get_input_value("value") - + # If object is a generator, convert it to a list if hasattr(object, "__iter__") and not isinstance(object, (dict, list, str)): object = list(object) - + contains = False - + # Check if value is in object if isinstance(object, dict): contains = value in object elif isinstance(object, (list, str)) or hasattr(object, "__contains__"): contains = value in object - + if state.verbosity >= NodeVerbosity.NORMAL: log.debug("Contains check", object=object, value=value, contains=contains) - - self.set_output_values({ - "contains": contains - }) + + self.set_output_values({"contains": contains}) + @register("data/DictGet") class DictGet(Node): - """ Retrieves a value from a dictionary - + Inputs: - + - dict: Dictionary - key: Key Properties: - + - key: Key - + Outputs: - + - value: Value """ - + class Fields: key = PropertyField( name="key", @@ -204,46 +204,43 @@ class DictGet(Node): type="str", default=UNRESOLVED, ) - + def __init__(self, title="Dict Get", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("dict", socket_type="dict") self.add_input("key", socket_type="str") self.add_output("value", socket_type="any") self.add_output("key", socket_type="str") - + self.set_property("key", UNRESOLVED) - + async def run(self, state: GraphState): data = self.get_input_value("dict") key = self.get_input_value("key") - + value = data.get(key) - - self.set_output_values({ - "value": value, - "key": key - }) + + self.set_output_values({"value": value, "key": key}) + @register("data/DictPop") class DictPop(Node): - """ Pops a value from a dictionary - + Inputs: - + - dict: Dictionary - key: Key - + Properties: - + - key: Key - + Outputs: - + - dict: Dictionary - value: Value - key: Key @@ -256,57 +253,53 @@ class DictPop(Node): type="str", default=UNRESOLVED, ) - + def __init__(self, title="Dict Pop", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("dict", socket_type="dict") self.add_input("key", socket_type="str") self.add_output("dict", socket_type="dict") self.add_output("value", socket_type="any") self.add_output("key", socket_type="str") - + self.set_property("key", UNRESOLVED) - + async def run(self, state: GraphState): data = self.get_input_value("dict") key = self.get_input_value("key") - + value = data.pop(key, None) - + if state.verbosity >= NodeVerbosity.VERBOSE: log.debug("Dict pop", key=key, value=value) - - self.set_output_values({ - "dict": data, - "value": value, - "key": key - }) - + + self.set_output_values({"dict": data, "value": value, "key": key}) + + @register("data/DictSet") class DictSet(Node): - """ Set a value in a dictionary - + Inputs: - + - dict: Dictionary - if not provided, a new dictionary will be created - key: Key - value: Value - + Properties: - + - key: Key - + Outputs: - + - dict: Dictionary - key: Key - value: Value """ - + class Fields: key = PropertyField( name="key", @@ -314,56 +307,53 @@ class DictSet(Node): type="str", default=UNRESOLVED, ) - + def __init__(self, title="Dict Set", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("dict", socket_type="dict", optional=True) self.add_input("key", socket_type="str", optional=True) self.add_input("value", socket_type="any") - + self.add_output("dict", socket_type="dict") self.add_output("key", socket_type="str") self.add_output("value", socket_type="any") - + self.set_property("key", UNRESOLVED) - + async def run(self, state: GraphState): data = self.get_input_value("dict") - + if not self.is_set(data): data = {} - + key = self.get_input_value("key") value = self.get_input_value("value") - + data[key] = value - - self.set_output_values({ - "dict": data, - "key": key, - "value": value - }) - + + self.set_output_values({"dict": data, "key": key, "value": value}) + + @register("data/MakeDict") class MakeDict(Node): """ Creates a new empty dictionary - + Inputs: - + - state: Graph state - + Properties: - data: Data to initialize the dictionary with - + Outputs: - + - dict: Dictionary """ - + class Fields: data = PropertyField( name="data", @@ -371,56 +361,55 @@ class MakeDict(Node): type="dict", default={}, ) - + def __init__(self, title="Make Dict", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state", optional=True) - + self.set_property("data", {}) - + self.add_output("dict", socket_type="dict") - + async def run(self, state: GraphState): new_dict = self.get_property("data") - - self.set_output_values({ - "dict": new_dict - }) + + self.set_output_values({"dict": new_dict}) + @register("data/Get") class Get(Node): """ Get a value from an object using getattr - + Can be used on dictionaries as well. - + Inputs: - + - object: Object - attribute: Attribute - + Properties: - + - attribute: Attribute - + Outputs: - + - value: Value - attribute: Attribute - object: Object """ - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( title_color="#44552f", - icon="F0552", # upload - auto_title="GET obj.{attribute}" + icon="F0552", # upload + auto_title="GET obj.{attribute}", ) - + class Fields: attribute = PropertyField( name="attribute", @@ -428,63 +417,64 @@ class Get(Node): type="str", default=UNRESOLVED, ) - + def __init__(self, title="Get", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("object", socket_type="any") self.add_input("attribute", socket_type="str") - + self.set_property("attribute", UNRESOLVED) - + self.add_output("value", socket_type="any") self.add_output("attribute", socket_type="str") self.add_output("object", socket_type="any") - + async def run(self, state: GraphState): obj = self.get_input_value("object") attribute = self.get_input_value("attribute") - + if isinstance(obj, dict): value = obj.get(attribute) elif isinstance(obj, list): try: index = int(attribute) except (ValueError, TypeError): - raise InputValueError(self, "attribute", "Attribute must be an integer if object is a list") + raise InputValueError( + self, + "attribute", + "Attribute must be an integer if object is a list", + ) try: value = obj[index] except IndexError: value = UNRESOLVED else: value = getattr(obj, attribute, None) - - self.set_output_values({ - "value": value, - "attribute": attribute, - "object": obj - }) + + self.set_output_values({"value": value, "attribute": attribute, "object": obj}) + @register("data/Set") class Set(Node): """ Set a value on an object using setattr - + Can be used on dictionaries as well. - + Inputs: - + - object: Object - attribute: Attribute - value: Value - + Properties: - + - attribute: Attribute - + Outputs: - + - object: Object - attribute: Attribute - value: Value @@ -495,10 +485,10 @@ class Set(Node): def style(self) -> NodeStyle: return NodeStyle( title_color="#2e4657", - icon="F01DA", # upload - auto_title="SET obj.{attribute}" + icon="F01DA", # upload + auto_title="SET obj.{attribute}", ) - + class Fields: attribute = PropertyField( name="attribute", @@ -506,188 +496,185 @@ class Set(Node): type="str", default=UNRESOLVED, ) - - + def __init__(self, title="Set", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("object", socket_type="any") self.add_input("attribute", socket_type="str") self.add_input("value", socket_type="any") - + self.set_property("attribute", UNRESOLVED) - + self.add_output("object", socket_type="any") self.add_output("attribute", socket_type="str") self.add_output("value", socket_type="any") - + async def run(self, state: GraphState): obj = self.get_input_value("object") attribute = self.get_input_value("attribute") value = self.get_input_value("value") - + if isinstance(obj, dict): obj[attribute] = value elif isinstance(obj, list): try: index = int(attribute) except (ValueError, IndexError): - raise InputValueError(self, "attribute", "Attribute must be an integer if object is a list") + raise InputValueError( + self, + "attribute", + "Attribute must be an integer if object is a list", + ) obj[index] = value else: setattr(obj, attribute, value) - - self.set_output_values({ - "object": obj, - "attribute": attribute, - "value": value - }) + + self.set_output_values({"object": obj, "attribute": attribute, "value": value}) + @register("data/MakeList") class MakeList(Node): """ Creates a new empty list - + Inputs: - + - state: Graph state - + Outputs: - + - list: List """ - + class Fields: item_type = PropertyField( name="item_type", description="Type of items in the list", type="str", default="any", - generate_choices=lambda: TYPE_CHOICES + generate_choices=lambda: TYPE_CHOICES, ) - + items = PropertyField( name="items", description="Initial items in the list", type="list", default=[], ) - + def __init__(self, title="Make List", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state", optional=True) self.add_input("item_type", socket_type="str", optional=True) - + self.set_property("item_type", "any") self.set_property("items", []) - + self.add_output("list", socket_type="list") - + async def run(self, state: GraphState): item_type = self.get_input_value("item_type") if item_type == UNRESOLVED: item_type = self.get_property("item_type") - + if state.verbosity >= NodeVerbosity.VERBOSE: log.debug("Creating new list", item_type=item_type) - + # Create a new empty list new_list = self.get_property("items") - - self.set_output_values({ - "list": new_list - }) + + self.set_output_values({"list": new_list}) + @register("data/ListAppend") class ListAppend(Node): """ Appends an item to a list - + Inputs: - + - list: List - item: Item - + Outputs: - + - list: List - item: Item """ - + def __init__(self, title="List Append", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("list", socket_type="list", optional=True) self.add_input("item", socket_type="any") - + self.add_output("list", socket_type="list") self.add_output("item", socket_type="any") - + async def run(self, state: GraphState): list_obj = self.get_input_value("list") item = self.get_input_value("item") - + if list_obj == UNRESOLVED or list_obj is None: list_obj = [] - + if state.verbosity >= NodeVerbosity.VERBOSE: log.debug("Appending item to list", list_length=len(list_obj), item=item) - - #list_node = self.get_input_socket("list").source.node - + + # list_node = self.get_input_socket("list").source.node + # validate item type - #if list_node.get_property("item_type") - + # if list_node.get_property("item_type") + # Append the item to the list list_obj.append(item) - - self.set_output_values({ - "list": list_obj, - "item": item - }) + + self.set_output_values({"list": list_obj, "item": item}) + @register("data/ListRemove") class ListRemove(Node): """ Removes an item from a list - + Inputs: - + - list: List - item: Item - + Outputs: - + - list: List - item: Item - removed: True if item was removed, False if not """ - + def __init__(self, title="List Remove", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("list", socket_type="list") self.add_input("item", socket_type="any") - + self.add_output("list", socket_type="list") self.add_output("item", socket_type="any") self.add_output("removed", socket_type="bool") - + async def run(self, state: GraphState): list_obj = self.get_input_value("list") item = self.get_input_value("item") - + if list_obj == UNRESOLVED or list_obj is None: raise InputValueError(self, "list", "List must be provided") - + if state.verbosity >= NodeVerbosity.VERBOSE: log.debug("Removing item from list", list_length=len(list_obj), item=item) - + # Try to remove the item from the list removed = False try: @@ -700,77 +687,73 @@ class ListRemove(Node): removed = False if state.verbosity >= NodeVerbosity.VERBOSE: log.debug("Item not found in list", item=item) - - self.set_output_values({ - "list": list_obj, - "item": item, - "removed": removed - }) + + self.set_output_values({"list": list_obj, "item": item, "removed": removed}) + @register("data/Length") class Length(Node): """ Gets the length of an iterable - + Inputs: - + - object: Object (list, dict, etc.) - + Outputs: - + - length: Length of the object (number of items) """ - + def __init__(self, title="Length", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("object") - + self.add_output("length", socket_type="int") - + async def run(self, state: GraphState): obj = self.get_input_value("object") - + # if object is generator convert to list if hasattr(obj, "__iter__") and not isinstance(obj, (dict, list)): obj = list(obj) - - self.set_output_values({ - "length": len(obj) - }) -@register("data/SelectItem") + self.set_output_values({"length": len(obj)}) + + +@register("data/SelectItem") class SelectItem(Node): """ Node that takes in a list of items and selects one based on the selection function - + - random - cycle - sorted_cycle - + Inputs: - + - items: List of items - except: Item to exclude from selection - + Properties: - + - index: Index of item to select - selection_function: Selection function - cycle_index: Cycle index (ephemeral, read-only) - + Outputs: - + - selected_item: Selected item """ - + class Fields: cycle_index = PropertyField( - name="cycle_index", - description="cycle index", + name="cycle_index", + description="cycle index", type="int", - ephemeral=True, + ephemeral=True, default=0, readonly=True, ) @@ -785,52 +768,58 @@ class SelectItem(Node): description="Selection function", type="str", default="cycle", - choices=["random", "cycle", "sorted_cycle", "direct"] + choices=["random", "cycle", "sorted_cycle", "direct"], ) - + def __init__(self, title="Select Item", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("items", socket_type="list") - + self.add_input("except", socket_type="any", optional=True) - + self.add_output("selected_item", socket_type="any") - + self.set_property("index", 0) self.set_property("cycle_index", 0) self.set_property("selection_function", "cycle") - + async def run(self, state: GraphState): items = self.get_input_value("items") index = self.get_property("index") selection_function = self.get_property("selection_function") - + # Determine which state object to use - state_data = state.outer.data if getattr(state, 'outer', None) else state.data - + state_data = state.outer.data if getattr(state, "outer", None) else state.data + # Initialize cycle_index in state if it doesn't exist, using self.id in the key cycle_key = f"{self.id}_cycle_index" if cycle_key not in state_data: state_data[cycle_key] = 0 - + except_items = self.get_input_value("except") - + if not isinstance(except_items, list) and except_items is not None: except_items = [except_items] - + items = items.copy() - + if except_items: items = [i for i in items if i not in except_items] - + if state_data[cycle_key] >= len(items): state_data[cycle_key] = 0 - + if state.verbosity >= NodeVerbosity.VERBOSE: - log.debug("Select item", items=items, except_items=except_items, selection_function=selection_function, cycle_index=state_data[cycle_key]) - + log.debug( + "Select item", + items=items, + except_items=except_items, + selection_function=selection_function, + cycle_index=state_data[cycle_key], + ) + if selection_function == "direct": try: selected_item = items[index] @@ -851,7 +840,5 @@ class SelectItem(Node): items_copy.sort() selected_item = items_copy[state_data[cycle_key]] state_data[cycle_key] = (state_data[cycle_key] + 1) % len(items) - - self.set_output_values({ - "selected_item": selected_item - }) + + self.set_output_values({"selected_item": selected_item}) diff --git a/src/talemate/game/engine/nodes/event.py b/src/talemate/game/engine/nodes/event.py index 1d58de6e..9ba38355 100644 --- a/src/talemate/game/engine/nodes/event.py +++ b/src/talemate/game/engine/nodes/event.py @@ -5,9 +5,8 @@ from .core import Listen, Node, Graph, GraphState, NodeVerbosity, PropertyField from .registry import register from talemate.agents.registry import get_agent_types from talemate.agents.base import Agent -from talemate.emit import emit, Emission -from talemate.emit.signals import handlers -from talemate.util.colors import COLOR_NAMES, COLOR_MAP +from talemate.emit import emit +from talemate.util.colors import COLOR_NAMES from talemate.context import active_scene from talemate.game.engine.api.schema import StatusEnum @@ -23,63 +22,66 @@ log = structlog.get_logger("talemate.game.engine.nodes.event") def collect_listeners(graph: Graph) -> dict[str, list["Listen"]]: """ Does a deep search of the graph to find all Listen nodes - + Args: graph: Graph to search - + Returns: dict[str, list["Listen"]]: A dictionary of event names to Listen nodes """ event_listeners = {} for node in graph.nodes.values(): - if isinstance(node, Listen): event_name = node.get_property("event_name") - + if not event_name: log.warning("Listen node has no event name", node=node) continue - + event_listeners.setdefault(event_name, []).append(node) elif isinstance(node, Graph): event_listeners.update(collect_listeners(node)) - + return event_listeners + def connect_listeners(graph: Graph, state: GraphState, disconnect: bool = False): """ Connects all Listen nodes in the graph to the event bus - + Args: graph: Graph to search """ - + event_listeners = collect_listeners(graph) - + for event_name, listeners in event_listeners.items(): for listener in listeners: signal = signals.get(event_name) if not signal: log.warning("Event not found", event_name=event_name) continue - + if state.verbosity == NodeVerbosity.NORMAL: - log.debug("Connecting listener", listener=listener, event_name=event_name) - + log.debug( + "Connecting listener", listener=listener, event_name=event_name + ) + if disconnect: signal.disconnect(listener.execute_from_event) - + signal.connect(listener.execute_from_event) + def disconnect_listeners(graph: Graph, state: GraphState): """ Disconnects all Listen nodes in the graph from the event bus - + Args: graph: Graph to search """ - + event_listeners = collect_listeners(graph) for event_name, listeners in event_listeners.items(): for listener in listeners: @@ -87,10 +89,12 @@ def disconnect_listeners(graph: Graph, state: GraphState): if not signal: log.warning("Event not found", event_name=event_name) continue - + if state.verbosity == NodeVerbosity.NORMAL: - log.debug("Disconnecting listener", listener=listener, event_name=event_name) - + log.debug( + "Disconnecting listener", listener=listener, event_name=event_name + ) + signal.disconnect(listener.execute_from_event) @@ -98,43 +102,43 @@ def disconnect_listeners(graph: Graph, state: GraphState): class State(Graph): """ Returns the current event object when inside a Listen node module. - + Outputs: - + - event: The current event object """ - + def __init__(self, title="Event", **kwargs): super().__init__(title=title, **kwargs) - def setup(self): self.add_output("event", socket_type="event") - - + async def run(self, state: GraphState): - self.set_output_values({ - "event": state.data.get("event"), - }) - - + self.set_output_values( + { + "event": state.data.get("event"), + } + ) + + @register("event/EmitStatus") class EmitStatus(Node): """ Emits a status message - + Inputs: - + - message: The message text to emit - status: The status of the message - as_scene_message: Whether to emit the message as a scene message (optional) - - + + Outputs: - + - emitted: Whether the message was emitted (True) or not (False) """ - + class Fields: message = PropertyField( name="message", @@ -142,7 +146,7 @@ class EmitStatus(Node): type="str", default="", ) - + status = PropertyField( name="status", description="The status of the message", @@ -150,61 +154,64 @@ class EmitStatus(Node): default="info", generate_choices=lambda: sorted(list(StatusEnum.__members__.keys())), ) - + as_scene_message = PropertyField( name="as_scene_message", description="Whether to emit the message as a scene message", type="bool", default=False, ) - + def __init__(self, title="Emit Status", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("message", socket_type="str", optional=True) self.add_input("status", socket_type="str", optional=True) self.add_input("as_scene_message", socket_type="bool", optional=True) - + self.set_property("message", "") self.set_property("status", "info") self.set_property("as_scene_message", False) - + self.add_output("emitted", socket_type="bool") - + async def run(self, state: GraphState): message_text = self.require_input("message") status = self.require_input("status") as_scene_message = self.get_input_value("as_scene_message") scene = active_scene.get() - + data = {} - + if as_scene_message is True: data["as_scene_message"] = True - + emit("status", message=message_text, status=status, scene=scene, data=data) - - self.set_output_values({ - "emitted": True, - }) - + + self.set_output_values( + { + "emitted": True, + } + ) + + @register("event/EmitSystemMessage") class EmitSystemMessage(EmitStatus): """ Emits a system message - + Inputs: - + - state: The graph state - message: The message text to emit - - + + Outputs: - + - state: The graph state """ - + class Fields: message_title = PropertyField( name="message_title", @@ -229,7 +236,7 @@ class EmitSystemMessage(EmitStatus): name="icon", description="The icon of the message", type="str", - default="mdi-information", # information + default="mdi-information", # information ) display = PropertyField( name="display", @@ -244,10 +251,10 @@ class EmitSystemMessage(EmitStatus): type="bool", default=False, ) - + def __init__(self, title="Emit System Message", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("message", socket_type="str", optional=True) @@ -259,7 +266,7 @@ class EmitSystemMessage(EmitStatus): self.set_property("display", "text") self.set_property("as_markdown", False) self.add_output("state") - + async def run(self, state: GraphState): state = self.get_input_value("state") message = self.require_input("message") @@ -268,93 +275,106 @@ class EmitSystemMessage(EmitStatus): icon = self.get_property("icon") display = self.get_property("display") as_markdown = self.get_property("as_markdown") - emit("system", message=message, meta={ - "color": font_color, - "icon": icon, - "title": message_title, - "display": display, - "as_markdown": as_markdown, - }) - self.set_output_values({ - "state": state, - }) - + emit( + "system", + message=message, + meta={ + "color": font_color, + "icon": icon, + "title": message_title, + "display": display, + "as_markdown": as_markdown, + }, + ) + self.set_output_values( + { + "state": state, + } + ) + + @register("event/EmitStatusConditional") class EmitStatusConditional(EmitStatus): """ Emits a status message if a condition is met - + Inputs: - + - state: The graph state - message: The message text to emit - status: The status of the message - as_scene_message: Whether to emit the message as a scene message (optional) - + Outputs: - + - state: The graph state - emitted: Whether the message was emitted (True) or not (False) """ - + def __init__(self, title="Emit Status (Conditional)", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_output("state") super().setup() - + async def run(self, state: GraphState): _state = self.get_input_value("state") await super().run(state) - self.set_output_values({ - "state": _state, - }) - + self.set_output_values( + { + "state": _state, + } + ) + + @register("event/EmitSceneStatus") class EmitSceneStatus(Node): """ Emits the scene status object to the UX - + Inputs: - + - state: The graph state - + Outputs: - + - state: The scene status object """ - + def __init__(self, title="Emit Scene Status", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_output("state") - + async def run(self, state: GraphState): _state = self.get_input_value("state") scene = active_scene.get() scene.emit_status() - self.set_output_values({ - "state": _state, - }) - + self.set_output_values( + { + "state": _state, + } + ) + + @register("event/EmitAgentMessage") class EmitAgentMessage(Node): """ Emits an agent message - + EXAMPLE - emit("agent_message", + emit("agent_message", message=message, data={ "uuid": str(uuid.uuid4()), "agent": "editor", "header": "Removed repetition", "color": "highlight4", - }, + }, meta={ "action": "revision_dedupe", "similarity": dedupe['similarity'], @@ -363,24 +383,23 @@ class EmitAgentMessage(Node): }, websocket_passthrough=True ) - - + + Inputs: - + - state: The graph state - message: The message text to emit - agent: The agent - header: The header of the message - color: The color of the message - meta: The meta data of the message - + Outputs: - + - emitted: Whether the message was emitted (True) or not (False) """ - - + class Fields: message = PropertyField( name="message", @@ -388,23 +407,23 @@ class EmitAgentMessage(Node): type="str", default="", ) - + agent = PropertyField( name="agent", type="str", default="", description="The name of the agent to get the client for", choices=[], - generate_choices=lambda: get_agent_types() + generate_choices=lambda: get_agent_types(), ) - + header = PropertyField( name="header", description="The header of the message", type="str", default="", ) - + message_color = PropertyField( name="message_color", description="The color of the message", @@ -412,60 +431,63 @@ class EmitAgentMessage(Node): default="grey", generate_choices=lambda: COLOR_NAMES, ) - + meta = PropertyField( name="meta", description="The meta data of the message", type="dict", default={}, ) - + def __init__(self, title="Emit Agent Message", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") - + self.add_input("message", socket_type="str", optional=True) self.add_input("agent", socket_type="agent,str", optional=True) self.add_input("header", socket_type="str", optional=True) self.add_input("message_color", socket_type="str", optional=True) self.add_input("meta", socket_type="dict", optional=True) - + self.set_property("message", "") self.set_property("agent", "") self.set_property("header", "") self.set_property("message_color", "grey") self.set_property("meta", {}) - + self.add_output("emitted", socket_type="bool") - + async def run(self, state: GraphState): message = self.require_input("message") agent = self.require_input("agent") header = self.require_input("header") message_color = self.require_input("message_color") meta = self.require_input("meta") - + if isinstance(agent, Agent): agent_name = agent.name else: agent_name = agent - + data = { "uuid": str(uuid.uuid4()), "agent": agent_name, "header": header, "color": message_color, } - - emit("agent_message", + + emit( + "agent_message", message=message, data=data, meta=meta, - websocket_passthrough=True + websocket_passthrough=True, + ) + + self.set_output_values( + { + "emitted": True, + } ) - - self.set_output_values({ - "emitted": True, - }) diff --git a/src/talemate/game/engine/nodes/focal.py b/src/talemate/game/engine/nodes/focal.py index 2069e377..af24709d 100644 --- a/src/talemate/game/engine/nodes/focal.py +++ b/src/talemate/game/engine/nodes/focal.py @@ -3,24 +3,24 @@ This set of nodes represents integration with talemate's FOCAL system for function orchestration with abstraction of creative tasks. """ -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING import structlog from talemate.game.engine.nodes.core import ( PropertyField, - Node, - GraphState, + Node, + GraphState, InputValueError, - register, + register, UNRESOLVED, - TYPE_CHOICES as SOCKET_TYPES + TYPE_CHOICES as SOCKET_TYPES, ) from talemate.context import active_scene from talemate.prompts.base import PrependTemplateDirectories -from talemate.game.engine.nodes.run import FunctionWrapper, FunctionArgument +from talemate.game.engine.nodes.run import FunctionWrapper import talemate.game.focal as focal if TYPE_CHECKING: - from talemate.tale_mate import Scene, Character + from talemate.tale_mate import Scene __all__ = [ "Focal", @@ -38,12 +38,12 @@ SOCKET_TYPES.extend( ] ) + @register("focal/Focal") class Focal(Node): - """ Main node for calling AI functions using the FOCAL system. - + Inputs: - state: The current graph state - template: The prompt template name; This template will be used to generate the prompt that facilitates the AI function call(s) @@ -51,17 +51,17 @@ class Focal(Node): - agent: The agent to use for the AI function call - template_vars: A dictionary of variables to use in the template - max_calls: The maximum number of calls to make - + Properties: - template: The prompt template name - max_calls: The maximum number of calls to make - + Outputs: - state: The current graph state - calls: The list of calls made - response: The raw response from the processed prompt """ - + class Fields: template = PropertyField( name="template", @@ -69,7 +69,7 @@ class Focal(Node): type="str", default=UNRESOLVED, ) - + max_calls = PropertyField( name="max_calls", description="The maximum number of calls to make", @@ -79,7 +79,7 @@ class Focal(Node): min=1, max=10, ) - + retries = PropertyField( name="retries", description="The number of retries to make", @@ -92,28 +92,26 @@ class Focal(Node): def __init__(self, title="AI Function Calling", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): - self.add_input("state") self.add_input("template", socket_typoe="str") self.add_input("callbacks", socket_type="list") self.add_input("agent", socket_type="agent") self.add_input("template_vars", socket_type="dict", optional=True) self.add_input("max_calls", socket_type="int", optional=True) - + self.set_property("template", UNRESOLVED) self.set_property("max_calls", 1) self.set_property("retries", 0) - + self.add_output("state") self.add_output("calls", socket_type="list") self.add_output("response", socket_type="str") - - - async def run(self, state:GraphState): - scene:"Scene" = active_scene.get() - + + async def run(self, state: GraphState): + scene: "Scene" = active_scene.get() + in_state = self.get_input_value("state") template = self.get_input_value("template") callbacks = self.get_input_value("callbacks") @@ -121,25 +119,25 @@ class Focal(Node): template_vars = self.get_input_value("template_vars") max_calls = self.require_number_input("max_calls", types=(int,)) retries = self.require_number_input("retries", types=(int,)) - + if not hasattr(agent, "client"): raise InputValueError( self, "agent", - "The specified agent does not have an appropriate LLM client configured." + "The specified agent does not have an appropriate LLM client configured.", ) - + for callback in callbacks: if not isinstance(callback, focal.Callback): raise InputValueError( self, "callbacks", - f"Callback must be a focal.Callback instance. Got {type(callback)} instead." + f"Callback must be a focal.Callback instance. Got {type(callback)} instead.", ) - + if template_vars: template_vars.pop("scene", None) - + focal_handler = focal.Focal( agent.client, callbacks=callbacks, @@ -150,39 +148,41 @@ class Focal(Node): "scene_loop": state.shared.get("scene_loop", {}), "local": state.data, }, - **template_vars + **template_vars, ) - + async def process(*args, **kwargs): return await focal_handler.request(template) + process.__name__ = self.title.replace(" ", "_").lower() - + with PrependTemplateDirectories([scene.template_dir]): response = await agent.delegate(process) - - self.set_output_values({ - "state": in_state, - "calls": focal_handler.state.calls, - "response": response, - }) - - + + self.set_output_values( + { + "state": in_state, + "calls": focal_handler.state.calls, + "response": response, + } + ) + + @register("focal/Callback") class Callback(Node): - """ Defines an AI function callback for use with the FOCAL system. - + Inputs: - fn: The function to call (Returned from an GetFunction node) - + Properties: - name: The name of the callback as the AI will see it - + Outputs: - callback: The focal.Callback instance """ - + class Fields: name = PropertyField( name="name", @@ -190,83 +190,82 @@ class Callback(Node): type="str", default="my_function", ) - + allow_multiple_calls = PropertyField( name="allow_multiple_calls", description="Whether the function can be called multiple times", type="bool", default=False, ) - + def __init__(self, title="AI Function Callback", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): - - #self.add_input("arguments", socket_type="list") + # self.add_input("arguments", socket_type="list") self.add_input("fn", socket_type="function") self.set_property("name", "my_function") self.set_property("allow_multiple_calls", False) - + self.add_output("callback", socket_type="focal/callback") - - async def run(self, state:GraphState): - + + async def run(self, state: GraphState): fn = self.get_input_value("fn") - + if not isinstance(fn, FunctionWrapper): raise InputValueError( - self, - "fn", - f"Function must be FunctionWrapper. Got {type(fn)} instead." + self, "fn", f"Function must be FunctionWrapper. Got {type(fn)} instead." ) - + fn_arg_nodes = await fn.get_argument_nodes() - + arguments = [ focal.Argument( name=node.get_property("name"), type=node.get_property("typ"), - ) for node in fn_arg_nodes + ) + for node in fn_arg_nodes ] - + callback = focal.Callback( name=self.get_property("name"), arguments=arguments, fn=fn, multiple=self.get_property("allow_multiple_calls"), ) - + log.debug("Callback created", callback=callback, fn=fn) - - self.set_output_values({ - "callback": callback, - }) + + self.set_output_values( + { + "callback": callback, + } + ) + @register("focal/ProcessCall") class ProcessCall(Node): - """ Process the AI function call result. - + Inputs: - + - calls: The list of calls made (focal.Call instances) - + Properties: - + - name: The name of the call to process - + Outputs: - + - name: The name of the call - arguments: The arguments of the call - result: The result of the call - uid: The UID of the call - called: Whether the call was made (if this is False, likely something went wrong) """ - + class Fields: name = PropertyField( name="name", @@ -274,33 +273,34 @@ class ProcessCall(Node): type="str", default=UNRESOLVED, ) - - + def __init__(self, title="Process AI Function Call", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("calls", socket_type="list") - + self.set_property("name", UNRESOLVED) - + self.add_output("name", socket_type="str") self.add_output("arguments", socket_type="dict") self.add_output("result", socket_type="any") self.add_output("uid", socket_type="str") self.add_output("called", socket_type="bool") - - async def run(self, state:GraphState): + + async def run(self, state: GraphState): name = self.require_input("name") calls = self.require_input("calls") - + for call in calls: if call.name == name: - self.set_output_values({ - "name": call.name, - "arguments": call.arguments, - "result": call.result, - "uid": call.uid, - "called": call.called, - }) - break \ No newline at end of file + self.set_output_values( + { + "name": call.name, + "arguments": call.arguments, + "result": call.result, + "uid": call.uid, + "called": call.called, + } + ) + break diff --git a/src/talemate/game/engine/nodes/history.py b/src/talemate/game/engine/nodes/history.py index ce421e78..9578d51b 100644 --- a/src/talemate/game/engine/nodes/history.py +++ b/src/talemate/game/engine/nodes/history.py @@ -1,193 +1,197 @@ import structlog -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING from .core import ( - Loop, - Node, - Entry, - GraphState, - UNRESOLVED, - LoopBreak, - LoopContinue, - NodeVerbosity, - InputValueError, + Node, + GraphState, + UNRESOLVED, + InputValueError, PropertyField, - Trigger, ) from .registry import register -from .event import connect_listeners, disconnect_listeners -import talemate.events as events -from talemate.emit import wait_for_input, emit -from talemate.exceptions import ActedAsCharacter, AbortWaitForInput -from talemate.context import active_scene, InteractionState +from talemate.emit import emit +from talemate.context import active_scene from talemate.scene_message import MESSAGES import talemate.scene_message as scene_message from talemate.history import character_activity if TYPE_CHECKING: - from talemate.tale_mate import Scene, Character + from talemate.tale_mate import Scene log = structlog.get_logger("talemate.game.engine.nodes.history") + @register("scene/history/Push") class PushHistory(Node): """ Push a message to the scene history at the lowest (e.g., dialogue) layer - + This will emit the message to the the sreen as part of the ongoing scene - + Inputs: - + - message: The message to push - + Properties: - + - emit_message: Whether to emit the message to the screen - + Outputs: - + - message: The message object """ + class Fields: emit_message = PropertyField( name="emit_message", description="Emit the message to the screen", type="bool", - default=True + default=True, ) - + def __init__(self, title="Push History", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("message", socket_type="message_object") - + self.set_property("emit_message", True) - + self.add_output("message", socket_type="message_object") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() message = self.get_input_value("message") emit_message = self.get_property("emit_message") - + if not isinstance(message, scene_message.SceneMessage): - raise InputValueError(self, "message", "Input is not a SceneMessage instance") - + raise InputValueError( + self, "message", "Input is not a SceneMessage instance" + ) + scene.push_history(message) - + if emit_message: if isinstance(message, scene_message.CharacterMessage): - emit("character", message, character=scene.get_character(message.character_name)) + emit( + "character", + message, + character=scene.get_character(message.character_name), + ) elif isinstance(message, scene_message.NarratorMessage): emit("narrator", message) elif isinstance(message, scene_message.ContextInvestigationMessage): emit("context_investigation", message) elif isinstance(message, scene_message.DirectorMessage): - emit("director", message, character=scene.get_character(message.character_name) if message.character_name else None) + emit( + "director", + message, + character=scene.get_character(message.character_name) + if message.character_name + else None, + ) + + self.set_output_values({"message": message}) + - self.set_output_values({ - "message": message - }) - @register("scene/history/Pop") class PopHistory(Node): """ Pop a message from the scene history - + Inputs: - + - message: The message to pop - + Properties: - + - emit_removal: Whether to emit the removal of the message - + Outputs: - + - message: The message object """ - + class Fields: emit_removal = PropertyField( name="emit_removal", description="Emit the removal of the message", type="bool", - default=True + default=True, ) - + def __init__(self, title="Pop History", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("message", socket_type="message_object") - + self.set_property("emit_removal", True) - + self.add_output("message", socket_type="message_object") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() emit_removal = self.get_property("emit_removal") message = self.get_input_value("message") - + if not isinstance(message, scene_message.SceneMessage): - raise InputValueError(self, "message", "Input is not a SceneMessage instance") - + raise InputValueError( + self, "message", "Input is not a SceneMessage instance" + ) + scene.pop_message(message) - + if emit_removal: - emit("remove_message", "", id=message.id) - - self.set_output_values({ - "message": message - }) + emit("remove_message", "", id=message.id) + + self.set_output_values({"message": message}) @register("scene/history/LastMessageOfType") class LastMessageOfType(Node): """ - Get the last message of a certain type from the history with + Get the last message of a certain type from the history with some basic filtering. - + Inputs: - + - message_type: The type of message to get (or a list of types) - filters: filter the messages by property values - + Properties: - + - message_type: The type of message to get (or a list of types) - max_iterations: The maximum number of iterations to go back - filters: filter the messages by property values - stop_on_time_passage: Stop when a time passage message is encountered - + Outputs: - + - message: The message object """ - + class Fields: max_iterations = PropertyField( name="max_iterations", description="The maximum number of iterations to go back", type="int", - default=100 + default=100, ) - + filters = PropertyField( name="filters", description="Filter the messages by property values", type="dict", - default={} + default={}, ) - + stop_on_time_passage = PropertyField( name="stop_on_time_passage", description="Stop when a time passage message is encountered", type="bool", default=False, ) - + message_type = PropertyField( name="message_type", description="The type of message to get (or a list of types)", @@ -195,110 +199,114 @@ class LastMessageOfType(Node): default="character", choices=list(MESSAGES.keys()), ) - + def __init__(self, title="Last Message of Type", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("message_type", socket_type="str,list") self.add_input("filters", socket_type="dict", optional=True) - self.set_property("message_type", UNRESOLVED) + self.set_property("message_type", UNRESOLVED) self.set_property("max_iterations", 100) self.set_property("stop_on_time_passage", False) self.set_property("filters", {}) - + self.add_output("message", socket_type="message_object") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() message_type = self.get_input_value("message_type") - + if not isinstance(message_type, list): message_type = [message_type] - + # validate message types against MESSAGE keys for mt in message_type: if mt not in MESSAGES: - raise InputValueError(self, "message_type", f"Message type {mt} is not valid") - + raise InputValueError( + self, "message_type", f"Message type {mt} is not valid" + ) + max_iterations = self.get_property("max_iterations") filters = self.get_input_value("filters") stop_on_time_passage = self.get_property("stop_on_time_passage") - + if not filters or filters is UNRESOLVED: filters = {} - - message = scene.last_message_of_type(message_type, max_iterations=max_iterations,stop_on_time_passage=stop_on_time_passage, **filters) - - self.set_output_values({ - "message": message - }) + + message = scene.last_message_of_type( + message_type, + max_iterations=max_iterations, + stop_on_time_passage=stop_on_time_passage, + **filters, + ) + + self.set_output_values({"message": message}) @register("scene/history/ContextHistory") class ContextHistory(Node): """ Compiles history for inclusion in a prompt context. - + Inputs: - + - budget: The budget for the history (number of tokens, defaults to 8192) - + Properties: - + - keep_direcctor_messages: Whether to keep director messages - keep_investigation_messages: Whether to keep investigation messages - keep_reinforcment_messages: Whether to keep reinforcement messages - show_hidden: Whether to show hidden messages - min_dialogue_length: The minimum length of dialogue to keep, this will ensure that there are always N dialogue messages in the history regardless of whether they are covered by summarization. (default 5) - label_chapters: Whether to label chapters in the summarized history - + Outputs: - + - messages: list of messages - compiled: compiled message """ - + class Fields: - budget = PropertyField( name="budget", description="The budget for the history (number of tokens, defaults to 8192)", type="int", default=8192, step=128, - min=0 + min=0, ) - + keep_director_messages = PropertyField( name="keep_director_messages", description="Whether to keep director messages", type="bool", - default=False + default=False, ) - + keep_investigation_messages = PropertyField( name="keep_investigation_messages", description="Whether to keep investigation messages", type="bool", - default=True + default=True, ) - + keep_reinforcement_messages = PropertyField( name="keep_reinforcement_messages", description="Whether to keep reinforcement messages", type="bool", - default=True + default=True, ) - + show_hidden = PropertyField( name="show_hidden", description="Whether to show hidden messages", type="bool", - default=False + default=False, ) - + min_dialogue_length = PropertyField( name="min_dialogue_length", description="The minimum length of dialogue to keep, this will ensure that there are always N dialogue messages in the history regardless of whether they are covered by summarization", @@ -307,20 +315,20 @@ class ContextHistory(Node): min=0, step=1, ) - + label_chapters = PropertyField( name="label_chapters", description="Whether to label chapters in the summarized history", type="bool", - default=False + default=False, ) - + def __init__(self, title="Context History", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("budget", socket_type="int", optional=True) - + self.set_property("budget", 8192) self.set_property("keep_director_messages", False) self.set_property("keep_investigation_messages", False) @@ -328,23 +336,25 @@ class ContextHistory(Node): self.set_property("show_hidden", False) self.set_property("min_dialogue_length", 5) self.set_property("label_chapters", False) - + self.add_output("messages", socket_type="list") self.add_output("compiled", socket_type="str") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() budget = self.require_number_input("budget", types=(int,)) - + keep_director_messages = self.get_property("keep_director_messages") keep_investigation_messages = self.get_property("keep_investigation_messages") keep_reinforcement_messages = self.get_property("keep_reinforcement_messages") show_hidden = self.get_property("show_hidden") - min_dialogue_length = self.require_number_input("min_dialogue_length", types=(int,)) + min_dialogue_length = self.require_number_input( + "min_dialogue_length", types=(int,) + ) label_chapters = self.get_property("label_chapters") - + messages = scene.context_history( - budget = budget, + budget=budget, keep_director=keep_director_messages, keep_context_investigation=keep_investigation_messages, include_reinforcements=keep_reinforcement_messages, @@ -352,55 +362,55 @@ class ContextHistory(Node): assured_dialogue_num=min_dialogue_length, chapter_labels=label_chapters, ) - - self.set_output_values({ - "messages": messages, - "compiled": "\n".join(messages) - }) - - + + self.set_output_values({"messages": messages, "compiled": "\n".join(messages)}) + + @register("scene/history/ActiveCharacterActivity") class ActiveCharacterActivity(Node): """ Returns a list of all active characters sorted by which were last active - + The most recently active character is first in the list. - + Properties: - + - since_time_passage: Only include characters that have acted since the last time passage message - + Outputs: - + - characters: list of characters - none_have_acted: whether no characters have acted """ - + class Fields: since_time_passage = PropertyField( name="since_time_passage", description="Only include characters that have acted since the last time passage message", type="bool", - default=False + default=False, ) - + def __init__(self, title="Character Activity", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("since_time_passage", False) self.add_output("none_have_acted", socket_type="bool") self.add_output("characters", socket_type="list") - - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() - + scene: "Scene" = active_scene.get() + since_time_passage = self.normalized_input_value("since_time_passage") - - activity = await character_activity(scene, since_time_passage=since_time_passage) - - self.set_output_values({ - "characters": activity.characters, - "none_have_acted": activity.none_have_acted - }) \ No newline at end of file + + activity = await character_activity( + scene, since_time_passage=since_time_passage + ) + + self.set_output_values( + { + "characters": activity.characters, + "none_have_acted": activity.none_have_acted, + } + ) diff --git a/src/talemate/game/engine/nodes/layout.py b/src/talemate/game/engine/nodes/layout.py index 62b2d0d9..cfa3c90d 100644 --- a/src/talemate/game/engine/nodes/layout.py +++ b/src/talemate/game/engine/nodes/layout.py @@ -4,62 +4,72 @@ import aiofiles import pydantic import structlog from pathlib import Path -from .core import Graph, UNRESOLVED, NodeBase, Group, Comment, load_extended_components, SaveContext +from .core import ( + Graph, + UNRESOLVED, + NodeBase, + Group, + Comment, + load_extended_components, + SaveContext, +) from .registry import get_node, NodeNotFoundError from .base_types import get_base_type from talemate.game.engine.nodes import SEARCH_PATHS, TALEMATE_ROOT __all__ = [ - 'load_graph', - 'save_graph', - 'export_flat_graph', - 'import_flat_graph', - 'list_node_files', - 'PathInfo', - 'JSONEncoder', - 'normalize_node_filename', - 'load_graph_from_file', + "load_graph", + "save_graph", + "export_flat_graph", + "import_flat_graph", + "list_node_files", + "PathInfo", + "JSONEncoder", + "normalize_node_filename", + "load_graph_from_file", ] log = structlog.get_logger("talemate.game.engine.nodes.layout") + class PathInfo(pydantic.BaseModel): full_path: str relative_path: str search_paths: list[str] = pydantic.Field(default_factory=list) + class JSONEncoder(json.JSONEncoder): """ Will serialize unknowns to strings """ - + def default(self, obj): - if isinstance(obj, UNRESOLVED): return None - + try: return super().default(obj) except TypeError: return str(obj) -def list_node_files(search_paths: list[str] = None, dedupe:bool=True) -> list[str]: + +def list_node_files(search_paths: list[str] = None, dedupe: bool = True) -> list[str]: if search_paths is None: search_paths = SEARCH_PATHS.copy() else: search_paths = search_paths.copy() + SEARCH_PATHS.copy() - + files = [] - + for base_path in search_paths: base_path = Path(base_path) for path in base_path.rglob("*.json"): if path.is_file(): files.append(str(path)) - + # we want semi relative paths, based of the talemate root dir files = [os.path.relpath(file, TALEMATE_ROOT) for file in files] - + # if dedupe is true we want to dedupe from the back (E.g., the first file found is the one we want) # this is for the filename, NOT the relative path if dedupe: @@ -70,19 +80,19 @@ def list_node_files(search_paths: list[str] = None, dedupe:bool=True) -> list[st if filename not in deduped: deduped.append(filename) _files.append(relative_path) - + files = _files - + return files -def export_flat_graph(graph:"Graph") -> dict: - + +def export_flat_graph(graph: "Graph") -> dict: fields = {} for name, field in graph.field_definitions.items(): fields[name] = field.model_dump() for name, field in graph.module_properties.items(): fields[name] = field.model_dump() - + flat = { "nodes": [], "connections": [], @@ -95,13 +105,13 @@ def export_flat_graph(graph:"Graph") -> dict: "title": graph.title, "extends": graph.extends, } - + graph.set_node_references() graph.set_socket_source_references() graph.ensure_connections() - + for node in graph.nodes.values(): - flat_node:dict = { + flat_node: dict = { "id": node.id, "registry": node.registry, "properties": node.properties, @@ -115,29 +125,31 @@ def export_flat_graph(graph:"Graph") -> dict: "inherited": node.inherited, } flat["nodes"].append(flat_node) - + for input in node.inputs: - if not input.source: continue - - flat["connections"].append({ - "from": input.source.full_id, - "to": input.full_id, - }) - + + flat["connections"].append( + { + "from": input.source.full_id, + "to": input.full_id, + } + ) + return flat -def import_flat_graph(flat_data: dict, main_graph:"Graph" = None) -> Graph: + +def import_flat_graph(flat_data: dict, main_graph: "Graph" = None) -> Graph: """ Import a flat graph representation and return a Graph object, handling nested graphs. - + Args: flat_data (dict): Dictionary containing flattened graph data with 'nodes' and 'connections' lists - + Returns: Graph: Reconstructed Graph object with all nodes and connections - + The flat_data format should be: { "nodes": [ @@ -147,7 +159,7 @@ def import_flat_graph(flat_data: dict, main_graph:"Graph" = None) -> Graph: "properties": dict, "x": int, "y": int, - "width": int, + "width": int, "height": int, "parent": str | None # ID of parent node if nested, None if top-level }, @@ -173,30 +185,30 @@ def import_flat_graph(flat_data: dict, main_graph:"Graph" = None) -> Graph: ... ], "registry": str # Registry value for the main graph - + } """ - + # if main_graph is not set get it from the root registry value if main_graph is None: main_graph_cls = get_node(flat_data["registry"]) if not main_graph_cls: main_graph_cls = Graph - + if getattr(main_graph_cls, "__dynamic_imported__", False): main_graph = main_graph_cls(nodes={}, edges={}, groups=[], comments=[]) else: main_graph = main_graph_cls() - + main_graph.properties = flat_data.get("properties", {}) main_graph.extends = flat_data.get("extends", None) - + def create_mode_module(node_data: dict) -> NodeBase: """Helper function to create a node instance from node data""" node_cls = get_node(node_data["registry"]) if not node_cls: raise ValueError(f"Unknown node type: {node_data['registry']}") - + node = node_cls( id=node_data["id"], x=node_data["x"], @@ -206,68 +218,70 @@ def import_flat_graph(flat_data: dict, main_graph:"Graph" = None) -> Graph: title=node_data["title"], collapsed=node_data.get("collapsed", False), ) - - + # this needs to happen after the node is created # so that inputs and outputs are created node.properties = node_data["properties"] - + return node def add_connections(graph: Graph, connections: list, node_map: dict): """Helper function to add connections to a graph""" graph.edges = {} - + for connection in connections: if connection["from"] not in graph.edges: graph.edges[connection["from"]] = [] - + if connection["to"] not in graph.edges[connection["from"]]: graph.edges[connection["from"]].append(connection["to"]) node_map = {} # Maps node IDs to node instances - + # First pass: Create all nodes and build hierarchy for node_data in flat_data["nodes"]: node = create_mode_module(node_data) node_map[node.id] = node - + # Add to parent if nested, otherwise to main graph parent_id = node_data.get("parent") if parent_id: if parent_id not in node_map: - raise ValueError(f"Parent node {parent_id} not found for node {node.id}") + raise ValueError( + f"Parent node {parent_id} not found for node {node.id}" + ) parent_node = node_map[parent_id] if not hasattr(parent_node, "nodes"): raise ValueError(f"Parent node {parent_id} cannot contain other nodes") parent_node.add_node(node) else: main_graph.add_node(node) - + # Second pass: Create all connections add_connections(main_graph, flat_data["connections"], node_map) - + # Third pass: Rebuild groups for group_data in flat_data["groups"]: group = Group(**group_data) main_graph.groups.append(group) - + # Fourth pass: Rebuild comments for comment_data in flat_data["comments"]: comment = Comment(**comment_data) - main_graph.comments.append(comment) - + main_graph.comments.append(comment) + if main_graph.extends: graph_data = main_graph.model_dump() load_extended_components(main_graph.extends, graph_data) main_graph = main_graph.__class__(**graph_data) - + # Initialize the graph return main_graph.reinitialize() -def load_graph(file_name: str, search_paths: list[str] = None, graph_cls = None) -> tuple[Graph, PathInfo]: - - + +def load_graph( + file_name: str, search_paths: list[str] = None, graph_cls=None +) -> tuple[Graph, PathInfo]: if search_paths is None: search_paths = SEARCH_PATHS.copy() else: @@ -283,31 +297,34 @@ def load_graph(file_name: str, search_paths: list[str] = None, graph_cls = None) search_paths = [path] break - # Convert all search paths to Path objects search_paths = [Path(path) for path in search_paths] - + for base_path in search_paths: # Look for the file in current directory file_path = base_path / file_name if file_path.exists(): return load_graph_from_file(file_path, graph_cls, search_paths) - + # Search recursively through subdirectories for path in base_path.rglob(file_name): if path.is_file(): return load_graph_from_file(path, graph_cls, search_paths) - raise FileNotFoundError(f"Could not find {file_name} in any of the search paths: {search_paths}") + raise FileNotFoundError( + f"Could not find {file_name} in any of the search paths: {search_paths}" + ) -def load_graph_from_file(file_path: str, graph_cls = None, search_paths: list[str] = None) -> tuple[Graph, PathInfo]: - - with open(file_path, 'r') as file: + +def load_graph_from_file( + file_path: str, graph_cls=None, search_paths: list[str] = None +) -> tuple[Graph, PathInfo]: + with open(file_path, "r") as file: data = json.load(file) - + if data.get("extends"): load_extended_components(data["extends"], data) - + if not graph_cls: try: graph_cls = get_node(data["registry"]) @@ -317,20 +334,21 @@ def load_graph_from_file(file_path: str, graph_cls = None, search_paths: list[st graph_cls = get_base_type(data["base_type"]) if not graph_cls: graph_cls = Graph - + remove_invalid_nodes(data["nodes"]) - + return graph_cls(**data).reinitialize(), PathInfo( full_path=str(file_path), relative_path=os.path.relpath(file_path, TALEMATE_ROOT), - search_paths=[str(path) for path in search_paths] if search_paths else [] + search_paths=[str(path) for path in search_paths] if search_paths else [], ) -def remove_invalid_nodes(nodes:dict): + +def remove_invalid_nodes(nodes: dict): """ Remove nodes that have no properties """ - + for node_id, node_data in list(nodes.items()): registry_name = node_data["registry"] try: @@ -339,11 +357,12 @@ def remove_invalid_nodes(nodes:dict): log.error("Removing UNKNOWN node", registry_name=registry_name) del nodes[node_id] + async def save_graph(graph: Graph, file_path: str): with SaveContext(): - async with aiofiles.open(file_path, 'w') as file: + async with aiofiles.open(file_path, "w") as file: await file.write(json.dumps(graph.model_dump(), indent=2, cls=JSONEncoder)) - - -def normalize_node_filename(node_name:str) -> str: - return node_name.lower().replace(" ", "-") + ".json" \ No newline at end of file + + +def normalize_node_filename(node_name: str) -> str: + return node_name.lower().replace(" ", "-") + ".json" diff --git a/src/talemate/game/engine/nodes/load_definitions.py b/src/talemate/game/engine/nodes/load_definitions.py index 7cc7ffcc..3d778992 100644 --- a/src/talemate/game/engine/nodes/load_definitions.py +++ b/src/talemate/game/engine/nodes/load_definitions.py @@ -1,19 +1,19 @@ -import talemate.game.engine.nodes.core -import talemate.game.engine.nodes.command -import talemate.game.engine.nodes.logic -import talemate.game.engine.nodes.state -import talemate.game.engine.nodes.scene -import talemate.game.engine.nodes.scene_intent -import talemate.game.engine.nodes.world_state -import talemate.game.engine.nodes.run -import talemate.game.engine.nodes.api -import talemate.game.engine.nodes.data -import talemate.game.engine.nodes.string -import talemate.game.engine.nodes.number -import talemate.game.engine.nodes.raise_errors -import talemate.game.engine.nodes.event -import talemate.game.engine.nodes.focal -import talemate.game.engine.nodes.util -import talemate.game.engine.nodes.history -import talemate.game.engine.nodes.prompt -import talemate.game.engine.nodes.packaging \ No newline at end of file +import talemate.game.engine.nodes.core # noqa: F401 +import talemate.game.engine.nodes.command # noqa: F401 +import talemate.game.engine.nodes.logic # noqa: F401 +import talemate.game.engine.nodes.state # noqa: F401 +import talemate.game.engine.nodes.scene # noqa: F401 +import talemate.game.engine.nodes.scene_intent # noqa: F401 +import talemate.game.engine.nodes.world_state # noqa: F401 +import talemate.game.engine.nodes.run # noqa: F401 +import talemate.game.engine.nodes.api # noqa: F401 +import talemate.game.engine.nodes.data # noqa: F401 +import talemate.game.engine.nodes.string # noqa: F401 +import talemate.game.engine.nodes.number # noqa: F401 +import talemate.game.engine.nodes.raise_errors # noqa: F401 +import talemate.game.engine.nodes.event # noqa: F401 +import talemate.game.engine.nodes.focal # noqa: F401 +import talemate.game.engine.nodes.util # noqa: F401 +import talemate.game.engine.nodes.history # noqa: F401 +import talemate.game.engine.nodes.prompt # noqa: F401 +import talemate.game.engine.nodes.packaging # noqa: F401 diff --git a/src/talemate/game/engine/nodes/logic.py b/src/talemate/game/engine/nodes/logic.py index e8b2c7a2..0e1b2c3f 100644 --- a/src/talemate/game/engine/nodes/logic.py +++ b/src/talemate/game/engine/nodes/logic.py @@ -25,20 +25,21 @@ __all__ = [ log = structlog.get_logger("talemate.game.engine.nodes.core.logic") + def is_truthy(value): return value is not None and value is not False and value is not UNRESOLVED + class LogicalRouter(Node): - """ Base node class for logical routers """ - - _op:ClassVar[str] = "and" - + + _op: ClassVar[str] = "and" + def __init__(self, title="OR Router", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("a", socket_type="bool", group="flags") self.add_input("b", socket_type="bool", group="flags") @@ -47,7 +48,7 @@ class LogicalRouter(Node): self.add_input("value", optional=True) self.add_output("yes") self.add_output("no") - + async def run(self, state: GraphState): # Get all flag values a = self.get_input_value("a") @@ -55,15 +56,15 @@ class LogicalRouter(Node): c = self.get_input_value("c") d = self.get_input_value("d") value = self.get_input_value("value") - + a_connected = self.get_input_socket("a").source is not None b_connected = self.get_input_socket("b").source is not None c_connected = self.get_input_socket("c").source is not None d_connected = self.get_input_socket("d").source is not None - + # Initialize flags list with required flags active_flags = [] - + if a_connected: active_flags.append(Socket.as_bool(a)) if b_connected: @@ -72,7 +73,7 @@ class LogicalRouter(Node): active_flags.append(Socket.as_bool(c)) if d_connected: active_flags.append(Socket.as_bool(d)) - + # If no valid flags are provided, treat as False if not active_flags: result = False @@ -83,243 +84,256 @@ class LogicalRouter(Node): result = all(active_flags) else: raise ValueError(f"Unknown operation: {self._op}") - + # Set output deactivation self.outputs[0].deactivated = result is False # yes - self.outputs[1].deactivated = result is True # no - - + self.outputs[1].deactivated = result is True # no + # return value should fall back to result if not provided if value is UNRESOLVED: - value = True - + value = True + # Set output values - self.set_output_values({ - "yes": value if result else UNRESOLVED, - "no": value if not result else UNRESOLVED, - }) + self.set_output_values( + { + "yes": value if result else UNRESOLVED, + "no": value if not result else UNRESOLVED, + } + ) # expand socket sources for debugging if state.verbosity == NodeVerbosity.VERBOSE: - log.debug(f"LogicalRouter {self.title} result: {result}", flags=active_flags, result=result, input_values=self.get_input_values()) - + log.debug( + f"LogicalRouter {self.title} result: {result}", + flags=active_flags, + result=result, + input_values=self.get_input_values(), + ) + for socket in self.inputs: if socket.source: - log.debug(f"LogicalRouter {self.title} Input {socket.name} source: {socket.source.node.title}.{socket.source.name} value: {socket.value} ! {socket.source.value}") + log.debug( + f"LogicalRouter {self.title} Input {socket.name} source: {socket.source.node.title}.{socket.source.name} value: {socket.value} ! {socket.source.value}" + ) else: - log.debug(f"LogicalRouter {self.title} Input {socket.name} source: NOT CONNECTED") + log.debug( + f"LogicalRouter {self.title} Input {socket.name} source: NOT CONNECTED" + ) + @register("core/ORRouter") class ORRouter(LogicalRouter): """ Route a value based on OR logic where any of a - d is truthy (if connected) - + Truthy values are considered as True, False and None are considered as False - + If a value is provided, it will be returned if the result is True If no value is provided, True will be returned on the output activated through the result, the other output will be deactivated - + Inputs: - + - a: flag A - b: flag B - c: flag C - d: flag D - + Outputs: - + - yes: if the result is True - no: if the result is False """ - + _op = "or" def __init__(self, title="OR Router", **kwargs): super().__init__(title=title, **kwargs) + @register("core/ANDRouter") class ANDRouter(LogicalRouter): """ Route a value based on AND logic where all of a - d are truthy (if connected) - + Truthy values are considered as True, False and None are considered as False - + If a value is provided, it will be returned if the result is True If no value is provided, True will be returned on the output activated through the result, the other output will be deactivated - + Inputs: - + - a: flag A - b: flag B - c: flag C - d: flag D - + Outputs: - + - yes: if the result is True - no: if the result is False """ - + _op = "and" - + def __init__(self, title="AND Router", **kwargs): super().__init__(title=title, **kwargs) + @register("core/Invert") class Invert(Node): """ Takes a boolean input and inverts it - + Inputs: - + - value: boolean value - + Outputs: - + - value: inverted boolean value """ - + def __init__(self, title="Invert", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("value", socket_type="bool") self.add_output("value", socket_type="bool") - + async def run(self, state: GraphState): value = self.get_input_value("value") result = not Socket.as_bool(value) - self.set_output_values({ - "value": result - }) + self.set_output_values({"value": result}) + @register("core/Switch") class Switch(Node): """ Checks if the input value is not None or False - + If the value is truthy, the yes output is activated, otherwise the no output is activated - + Inputs: - + - value: value to check - + Properties: - + - pass_through: if True, the value will be passed through to the output, otherwise True will be passed through - + Outputs: - + - yes: if the value is truthy - no: if the value is not truthy """ - + class Fields: - pass_through = PropertyField( name="pass_through", type="bool", default=True, - description="If True, the value will be passed through to the output, otherwise True will be passed through" + description="If True, the value will be passed through to the output, otherwise True will be passed through", ) - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( icon="F0641", ) - + def __init__(self, title="Switch", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("value") - + self.set_property("pass_through", True) - + self.add_output("yes") self.add_output("no") - + async def run(self, state: GraphState): value = self.get_input_value("value") pass_through = self.get_property("pass_through") - + result = is_truthy(value) - + if not pass_through: value = True - - self.set_output_values({ - "yes": value if result else UNRESOLVED, - "no": value if not result else UNRESOLVED, - }) - + + self.set_output_values( + { + "yes": value if result else UNRESOLVED, + "no": value if not result else UNRESOLVED, + } + ) + self.get_output_socket("yes").deactivated = not result self.get_output_socket("no").deactivated = result + @register("core/RSwitch") class RSwitch(Node): """ Checks if the a value is truthy - + If the value is truthy, the yes input is routed to the output, otherwise the no input is routed to the output - + Inputs: - + - check: value to check - yes: value to return if the check value is truthy - no: value to return if the check value is not truthy - + Outputs: - + - value: the value to return """ - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( icon="F0641", ) - + def __init__(self, title="RSwitch", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("check", optional=True) self.add_input("yes", optional=True) self.add_input("no", optional=True) self.add_output("value") - + async def run(self, state: GraphState): check = self.get_input_value("check") yes = self.get_input_value("yes") no = self.get_input_value("no") - + check = is_truthy(check) - + result = yes if check else no - - self.set_output_values({ - "value": result - }) + + self.set_output_values({"value": result}) + @register("core/RSwitchAdvanced") class RSwitchAdvanced(Node): """ Checks if the a value is truthy - + If the value is truthy, the yes input is routed to yes output and the no output is deactivated, otherwise the no input is routed to the no output and the yes output is deactivated - + Inputs: - + - check: value to check - yes: value to return if the check value is truthy - no: value to return if the check value is not truthy - + Outputs: - + - yes: the value to return if the check value is truthy - no: the value to return if the check value is not truthy """ @@ -330,121 +344,124 @@ class RSwitchAdvanced(Node): return NodeStyle( icon="F0641", ) - + def __init__(self, title="RSwitch Advanced", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("check", optional=True) self.add_input("yes", optional=True) self.add_input("no", optional=True) self.add_output("yes") self.add_output("no") - + async def run(self, state: GraphState): check = self.get_input_value("check") yes = self.get_input_value("yes") no = self.get_input_value("no") - + check = is_truthy(check) - + result = yes if check else no - - self.set_output_values({ - "yes": result if check else UNRESOLVED, - "no": result if not check else UNRESOLVED - }) - + + self.set_output_values( + { + "yes": result if check else UNRESOLVED, + "no": result if not check else UNRESOLVED, + } + ) + + @register("core/Case") class Case(Node): """ Route a value based on attribute value check (exact match) like a switch / case statement. - + Inputs: - + - value: value to check - + Properties: - + - attribute_name: the attribute name to check for the value. If not provided, the value itself will be used. - case_a: the value to compare to for case A - case_b: the value to compare to for case B - case_c: the value to compare to for case C - case_d: the value to compare to for case D - + Outputs: - + - a: if the value matches case A - b: if the value matches case B - c: if the value matches case C - d: if the value matches case D """ - + class Fields: attribute_name = PropertyField( name="attribute_name", type="str", default="", - description="The attribute name to check for the value" + description="The attribute name to check for the value", ) - + case_a = PropertyField( name="case_a", type="str", default="", - description="The value to compare to for case A" + description="The value to compare to for case A", ) - + case_b = PropertyField( name="case_b", type="str", default="", - description="The value to compare to for case B" + description="The value to compare to for case B", ) - + case_c = PropertyField( name="case_c", type="str", default="", - description="The value to compare to for case C" + description="The value to compare to for case C", ) - + case_d = PropertyField( name="case_d", type="str", default="", - description="The value to compare to for case D" + description="The value to compare to for case D", ) - + def __init__(self, title="Case", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("value") - + self.set_property("attribute_name", "") - + self.set_property("case_a", "") self.set_property("case_b", "") self.set_property("case_c", "") self.set_property("case_d", "") - + self.add_output("a") self.add_output("b") self.add_output("c") self.add_output("d") self.add_output("none") - + async def run(self, state: GraphState): value = self.get_input_value("value") attribute_name = self.get_property("attribute_name") - + if is_truthy(attribute_name) and attribute_name.strip(): compare_to = getattr(value, attribute_name) else: compare_to = str(value) - + case_a = self.get_property("case_a") case_b = self.get_property("case_b") case_c = self.get_property("case_c") @@ -460,45 +477,49 @@ class Case(Node): self.set_output_values({"d": value}) else: self.set_output_values({"none": value}) - + if state.verbosity >= NodeVerbosity.NORMAL: - log.debug(f"Case {self.title} value: {value} attribute_name: {attribute_name} cases: {case_a}, {case_b}, {case_c}, {case_d}", compare_to=compare_to) - -@register("core/Coallesce") + log.debug( + f"Case {self.title} value: {value} attribute_name: {attribute_name} cases: {case_a}, {case_b}, {case_c}, {case_d}", + compare_to=compare_to, + ) + + +@register("core/Coallesce") class Coallesce(Node): """ Takes a list of values and returns the first non-UNRESOLVED value - + Inputs: - + - a: value A - b: value B - c: value C - d: value D - + Outputs: - + - value: the first non-UNRESOLVED value """ - + def __init__(self, title="Coallesce", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("a", optional=True) self.add_input("b", optional=True) self.add_input("c", optional=True) self.add_input("d", optional=True) self.add_output("value") - + async def run(self, state: GraphState): a = self.get_input_value("a") b = self.get_input_value("b") c = self.get_input_value("c") d = self.get_input_value("d") - + result = UNRESOLVED - + if is_truthy(a): result = a elif is_truthy(b): @@ -507,165 +528,155 @@ class Coallesce(Node): result = c elif is_truthy(d): result = d - - self.set_output_values({ - "value": result - }) - + + self.set_output_values({"value": result}) + + @register("core/MakeBool") class MakeBool(Node): """ Creates a boolean value - + Properties: - + - value: boolean value - + Outputs: - + - value: boolean value """ - + class Fields: value = PropertyField( - name="value", - type="bool", - default=True, - description="The boolean value" + name="value", type="bool", default=True, description="The boolean value" ) - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( auto_title="{value}", ) - + def __init__(self, title="Make Bool", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("value", True) self.add_output("value", socket_type="bool") - + async def run(self, state: GraphState): value = self.get_input_value("value") - self.set_output_values({ - "value": value - }) - + self.set_output_values({"value": value}) + + @register("core/AsBool") class AsBool(Node): """ Converts a value to a boolean - + This specfically handles UNRESOLVED by casting it to a default value - + Inputs: - + - value: value to convert - + Properties: - + - default: the default value to return if the value is UNRESOLVED - + Outputs: - + - value: boolean value """ - + class Fields: default = PropertyField( name="default", type="bool", default=False, - description="The default value to return if the value is UNRESOLVED" + description="The default value to return if the value is UNRESOLVED", ) - + def __init__(self, title="As Bool", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("value", optional=True) self.set_property("default", False) self.add_output("value", socket_type="bool") - + async def run(self, state: GraphState): value = self.get_input_value("value") default = self.get_property("default") - + if value is UNRESOLVED: value = default - + if not isinstance(value, bool): try: value = bool(value) except Exception as e: raise InputValueError(f"Failed to convert value to bool: {e}") - - self.set_output_values({ - "value": value - }) - - + + self.set_output_values({"value": value}) + + @register("core/ApplyDefault") class ApplyDefault(Node): """ Applies a default value if the input value is UNRESOLVED - + Inputs: - + - value: value to apply the default to - default: the default value to apply - + Properties: - apply_on_unresolved: if True, the default will be applied if the value is UNRESOLVED - apply_on_none: if True, the default will be applied if the value is None - - Outputs: - + + Outputs: + - value: the value with the default applied """ - + class Fields: apply_on_none = PropertyField( name="apply_on_none", type="bool", default=False, - description="If True, the default will be applied if the value is None" + description="If True, the default will be applied if the value is None", ) - + apply_on_unresolved = PropertyField( name="apply_on_unresolved", type="bool", default=True, - description="If True, the default will be applied if the value is UNRESOLVED" + description="If True, the default will be applied if the value is UNRESOLVED", ) - - + def __init__(self, title="Apply Default", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("value", optional=True) self.add_input("default") self.set_property("apply_on_none", False) self.set_property("apply_on_unresolved", True) self.add_output("value") - + async def run(self, state: GraphState): value = self.get_input_value("value") default = self.require_input("default") apply_on_none = self.get_property("apply_on_none") apply_on_unresolved = self.get_property("apply_on_unresolved") - + if apply_on_unresolved and value is UNRESOLVED: value = default - + if apply_on_none and value is None: value = default - - self.set_output_values({ - "value": value - }) + + self.set_output_values({"value": value}) diff --git a/src/talemate/game/engine/nodes/number.py b/src/talemate/game/engine/nodes/number.py index 8a114212..fb3efc39 100644 --- a/src/talemate/game/engine/nodes/number.py +++ b/src/talemate/game/engine/nodes/number.py @@ -1,5 +1,4 @@ import random -import math import statistics import structlog from .core import Node, GraphState, UNRESOLVED, PropertyField, InputValueError @@ -7,15 +6,14 @@ from .registry import register log = structlog.get_logger("talemate.game.engine.nodes.number") + class NumberNode(Node): - - def normalized_number_input(self, name:str, types:tuple=(int, float)): - + def normalized_number_input(self, name: str, types: tuple = (int, float)): value = self.get_input_value(name) - + if value is UNRESOLVED: return UNRESOLVED - + try: if float in types: value = float(value) @@ -23,159 +21,153 @@ class NumberNode(Node): value = int(value) except ValueError: raise InputValueError(self, name, "Invalid number") - + return value + @register("data/number/Make") class MakeNumber(NumberNode): """Creates a number with a specified value - + Creates either an integer or floating point number with the specified value. - + Properties: - + - number_type: Type of number to create ("int" or "float") - value: The numeric value to create - + Outputs: - + - value: The created number value """ - + class Fields: number_type = PropertyField( name="number_type", description="Type of number to create", type="str", default="float", - choices=["int", "float"] + choices=["int", "float"], ) - + value = PropertyField( name="value", description="The numeric value to create", type="number", - default=0 + default=0, ) - + def setup(self): self.set_property("value", 0) self.set_property("number_type", "float") self.add_output("value", socket_type="any") # Can be int or float - + async def run(self, state: GraphState): - number_type = self.get_property("number_type") - + if number_type == "int": types = (int,) else: types = (float,) - + value = self.normalized_number_input("value", types) self.set_output_values({"value": value}) + @register("data/number/AsNumber") class AsNumber(NumberNode): """Converts a value to a number - + Converts a value to a number, handling both string and numeric inputs. - + Inputs: - + - value: The value to convert to a number - + Outputs: - + - value: The converted number value """ - + class Fields: number_type = PropertyField( name="number_type", description="Type of number to create", type="str", default="float", - choices=["int", "float"] + choices=["int", "float"], ) - + def setup(self): self.add_input("value", socket_type="any") self.set_property("number_type", "int") self.add_output("value", socket_type="int,float") - + async def run(self, state: GraphState): - if self.get_property("number_type") == "int": valid_types = (int,) else: valid_types = (float,) - + value = self.normalized_number_input("value", valid_types) self.set_output_values({"value": value}) + @register("data/number/BasicArithmetic") class BasicArithmetic(NumberNode): """Performs basic arithmetic operations - - Performs one of the following operations on two input values: add, subtract, + + Performs one of the following operations on two input values: add, subtract, multiply, divide, power, or modulo. - + Inputs: - + - a: First operand (number) - b: Second operand (number) - + Properties: - + - operation: Arithmetic operation to perform (add, subtract, multiply, divide, power, modulo) - + Outputs: - + - result: Result of the arithmetic operation """ - + class Fields: operation = PropertyField( name="operation", description="Arithmetic operation to perform", type="str", default="add", - choices=["add", "subtract", "multiply", "divide", "power", "modulo"] + choices=["add", "subtract", "multiply", "divide", "power", "modulo"], ) - + a = PropertyField( - name="a", - description="First value to compare", - type="number", - default=0 + name="a", description="First value to compare", type="number", default=0 ) b = PropertyField( - name="b", - description="Second value to compare", - type="number", - default=0 + name="b", description="Second value to compare", type="number", default=0 ) - - + def setup(self): self.add_input("a", socket_type="int,float") self.add_input("b", socket_type="int,float") self.add_output("result", socket_type="int,float") - + self.set_property("operation", "add") self.set_property("a", 0) self.set_property("b", 0) - + async def run(self, state: GraphState): a = self.normalized_number_input("a") b = self.normalized_number_input("b") - + if a is UNRESOLVED or b is UNRESOLVED: return - + operation = self.get_property("operation") - + try: if operation == "add": result = a + b @@ -188,86 +180,87 @@ class BasicArithmetic(NumberNode): raise InputValueError(self, "b", "Division by zero") result = a / b elif operation == "power": - result = a ** b + result = a**b elif operation == "modulo": if b == 0: raise InputValueError(self, "b", "Modulo by zero") result = a % b - + self.set_output_values({"result": result}) - + except Exception as e: raise InputValueError(self, "a", f"Calculation error: {str(e)}") + @register("data/number/Compare") class Compare(NumberNode): """Compares two numbers - + Performs comparison operations between two numeric values with optional tolerance for floating point comparisons. - + Inputs: - + - a: First value to compare (number) - b: Second value to compare (number) - + Properties: - + - operation: Comparison operation to perform (equals, not_equals, greater_than, less_than, greater_equal, less_equal) - tolerance: Tolerance for floating point equality comparison - + Outputs: - + - result: Boolean result of the comparison """ - + class Fields: operation = PropertyField( name="operation", description="Comparison operation to perform", type="str", default="equals", - choices=["equals", "not_equals", "greater_than", "less_than", - "greater_equal", "less_equal"] + choices=[ + "equals", + "not_equals", + "greater_than", + "less_than", + "greater_equal", + "less_equal", + ], ) tolerance = PropertyField( name="tolerance", description="Tolerance for floating point comparison", type="float", - default=0.0001 + default=0.0001, ) a = PropertyField( - name="a", - description="First value to compare", - type="number", - default=0 + name="a", description="First value to compare", type="number", default=0 ) b = PropertyField( - name="b", - description="Second value to compare", - type="number", - default=0 + name="b", description="Second value to compare", type="number", default=0 ) - + def setup(self): self.add_input("a", socket_type="int,float") self.add_input("b", socket_type="int,float") self.add_output("result", socket_type="bool") - + self.set_property("operation", "equals") self.set_property("tolerance", 0.0001) # For floating point comparison self.set_property("a", 0) self.set_property("b", 0) - + async def run(self, state: GraphState): a = self.normalized_number_input("a") b = self.normalized_number_input("b") operation = self.get_property("operation") tolerance = self.get_property("tolerance") - + if a is UNRESOLVED or b is UNRESOLVED: return - + if operation == "equals": result = abs(a - b) <= tolerance elif operation == "not_equals": @@ -280,144 +273,143 @@ class Compare(NumberNode): result = a >= b elif operation == "less_equal": result = a <= b - + self.set_output_values({"result": result}) @register("data/number/MinMax") class MinMax(NumberNode): """Finds minimum or maximum in a list of numbers - + Takes a list of numbers and finds either the minimum or maximum value, returning both the value and its index in the list. - + Inputs: - + - numbers: List of numbers to analyze - + Properties: - + - operation: Operation to perform (min or max) - + Outputs: - + - result: The minimum or maximum value - index: The index position of the minimum or maximum value in the list """ - + class Fields: operation = PropertyField( name="operation", description="Operation to perform", type="str", default="min", - choices=["min", "max"] + choices=["min", "max"], ) - + def setup(self): self.add_input("numbers", socket_type="list") self.add_output("result", socket_type="int,float") self.add_output("index", socket_type="int") - + self.set_property("operation", "min") - + async def run(self, state: GraphState): numbers = self.get_input_value("numbers") operation = self.get_property("operation") - + if not numbers: raise InputValueError(self, "numbers", "Empty list provided") - + if not all(isinstance(n, (int, float)) for n in numbers): raise InputValueError(self, "numbers", "All items must be numbers") - + if operation == "min": result = min(numbers) index = numbers.index(result) elif operation == "max": result = max(numbers) index = numbers.index(result) - - self.set_output_values({ - "result": result, - "index": index - }) + + self.set_output_values({"result": result, "index": index}) + @register("data/number/Sum") class Sum(NumberNode): """Sums a list of numbers - + Calculates the sum of all values in a list of numbers. - + Inputs: - + - numbers: List of numbers to sum - + Outputs: - + - result: The sum of all numbers in the list """ - + def setup(self): self.add_input("numbers", socket_type="list") self.add_output("result", socket_type="int,float") - + self.set_property("numbers", []) - + async def run(self, state: GraphState): numbers = self.get_input_value("numbers") - + if not all(isinstance(n, (int, float)) for n in numbers): raise InputValueError(self, "numbers", "All items must be numbers") - + result = sum(numbers) self.set_output_values({"result": result}) + @register("data/number/Average") class Average(NumberNode): """Calculates average of a list of numbers - + Calculates one of three types of average (mean, median, or mode) from a list of numeric values. - + Inputs: - + - numbers: List of numbers to calculate average from - + Properties: - + - method: Type of average to calculate (mean, median, mode) - + Outputs: - + - result: The calculated average value """ - + class Fields: method = PropertyField( name="method", description="Type of average to calculate", type="str", default="mean", - choices=["mean", "median", "mode"] + choices=["mean", "median", "mode"], ) - + def setup(self): self.add_input("numbers", socket_type="list") self.add_output("result", socket_type="int,float") - + self.set_property("method", "mean") - + async def run(self, state: GraphState): numbers = self.get_input_value("numbers") method = self.get_property("method") - + if not numbers: raise InputValueError(self, "numbers", "Empty list provided") - + if not all(isinstance(n, (int, float)) for n in numbers): raise InputValueError(self, "numbers", "All items must be numbers") - + try: if method == "mean": result = statistics.mean(numbers) @@ -429,190 +421,190 @@ class Average(NumberNode): except statistics.StatisticsError: # Handle multimodal or no mode case result = None - + self.set_output_values({"result": result}) - + except Exception as e: raise InputValueError(self, "numbers", f"Calculation error: {str(e)}") + @register("data/number/Random") class Random(NumberNode): """Generates random numbers - + Generates random numbers using various distributions (uniform, integer, normal) or selects a random item from a list of choices. - + Inputs: - + - min: Minimum value for uniform/integer distribution (optional) - max: Maximum value for uniform/integer distribution (optional) - mean: Mean value for normal distribution (optional) - std_dev: Standard deviation for normal distribution (optional) - choices: List to select a random item from (optional) - + Properties: - + - method: Type of random number to generate (uniform, integer, normal, choice) - min: Default minimum value - max: Default maximum value - mean: Default mean value - std_dev: Default standard deviation value - + Outputs: - + - result: The generated random number or selected item """ - + class Fields: method = PropertyField( name="method", description="Type of random number to generate", type="str", default="uniform", - choices=["uniform", "integer", "normal", "choice"] + choices=["uniform", "integer", "normal", "choice"], ) min = PropertyField( name="min", description="Minimum value for uniform/integer distribution", type="float", - default=0.0 + default=0.0, ) max = PropertyField( name="max", description="Maximum value for uniform/integer distribution", type="float", - default=1.0 + default=1.0, ) mean = PropertyField( name="mean", description="Mean value for normal distribution", type="float", - default=0.0 + default=0.0, ) std_dev = PropertyField( name="std_dev", description="Standard deviation for normal distribution", type="float", - default=1.0 + default=1.0, ) - + def setup(self): self.add_input("min", socket_type="int,float", optional=True) self.add_input("max", socket_type="int,float", optional=True) self.add_input("mean", socket_type="int,float", optional=True) self.add_input("std_dev", socket_type="int,float", optional=True) self.add_input("choices", socket_type="list", optional=True) - + self.add_output("result", socket_type="int,float") - + self.set_property("method", "uniform") self.set_property("min", 0.0) self.set_property("max", 1.0) self.set_property("mean", 0.0) self.set_property("std_dev", 1.0) - + async def run(self, state: GraphState): method = self.get_property("method") - + if method == "uniform": min_val = self.normalized_number_input("min") max_val = self.normalized_number_input("max") - + if min_val is UNRESOLVED or max_val is UNRESOLVED: return - + result = random.uniform(min_val, max_val) - + elif method == "integer": min_val = int(self.normalized_number_input("min")) max_val = int(self.normalized_number_input("max")) - + if min_val is UNRESOLVED or max_val is UNRESOLVED: return - + result = random.randint(min_val, max_val) - + elif method == "normal": mean = self.normalized_number_input("mean") std_dev = self.normalized_number_input("std_dev") - + if mean is UNRESOLVED or std_dev is UNRESOLVED: return - + if std_dev <= 0: - raise InputValueError(self, "std_dev", "Standard deviation must be positive") - + raise InputValueError( + self, "std_dev", "Standard deviation must be positive" + ) + result = random.normalvariate(mean, std_dev) - + elif method == "choice": choices = self.get_input_value("choices") - + if not choices: raise InputValueError(self, "choices", "Empty list provided") - + result = random.choice(choices) - + self.set_output_values({"result": result}) + @register("data/number/Clamp") class Clamp(NumberNode): """Constrains a number within a specified range - + Takes a value and ensures it falls within a specific minimum and maximum range. If the value is below the minimum, it returns the minimum. If it's above the maximum, it returns the maximum. - + Inputs: - + - value: The number to constrain - min: Minimum allowed value - max: Maximum allowed value - + Outputs: - + - result: The value constrained to the specified range """ - + class Fields: value = PropertyField( name="value", description="The number to constrain", type="number", - default=0 + default=0, ) min = PropertyField( - name="min", - description="Minimum allowed value", - type="number", - default=0 + name="min", description="Minimum allowed value", type="number", default=0 ) max = PropertyField( - name="max", - description="Maximum allowed value", - type="number", - default=1 + name="max", description="Maximum allowed value", type="number", default=1 ) - + def setup(self): self.add_input("value", socket_type="int,float") self.add_input("min", socket_type="int,float") self.add_input("max", socket_type="int,float") self.add_output("result", socket_type="int,float") - + self.set_property("value", 0) self.set_property("min", 0) self.set_property("max", 1) - + async def run(self, state: GraphState): value = self.normalized_number_input("value") min_val = self.normalized_number_input("min") max_val = self.normalized_number_input("max") - + if value is UNRESOLVED or min_val is UNRESOLVED or max_val is UNRESOLVED: return - + if min_val > max_val: - raise InputValueError(self, "min", "Minimum value cannot be greater than maximum") - + raise InputValueError( + self, "min", "Minimum value cannot be greater than maximum" + ) + result = max(min_val, min(value, max_val)) - self.set_output_values({"result": result}) \ No newline at end of file + self.set_output_values({"result": result}) diff --git a/src/talemate/game/engine/nodes/packaging.py b/src/talemate/game/engine/nodes/packaging.py index c936395d..c82f5ecc 100644 --- a/src/talemate/game/engine/nodes/packaging.py +++ b/src/talemate/game/engine/nodes/packaging.py @@ -13,10 +13,7 @@ from .core import ( Node, Graph, register, - GraphState, - GraphContext, UNRESOLVED, - InputValueError, PropertyField, NodeStyle, TYPE_CHOICES, @@ -26,14 +23,10 @@ from .registry import get_node, get_nodes_by_base_type from .scene import SceneLoop -from .base_types import base_node_type -from talemate.context import interaction - if TYPE_CHECKING: from talemate.tale_mate import Scene __all__ = [ - "PackageInfo", "PromoteConfig", "initialize_scene_package_info", "get_scene_package_info", @@ -49,10 +42,12 @@ __all__ = [ log = structlog.get_logger("talemate.game.engine.nodes.packaging") - -TYPE_CHOICES.extend([ - "node_module", -]) + +TYPE_CHOICES.extend( + [ + "node_module", + ] +) SCENE_PACKAGE_INFO_FILENAME = "modules.json" @@ -60,6 +55,7 @@ SCENE_PACKAGE_INFO_FILENAME = "modules.json" # MODELS # ------------------------------------------------------------------------------------------------ + class PackageProperty(pydantic.BaseModel): module: str name: str @@ -71,6 +67,7 @@ class PackageProperty(pydantic.BaseModel): required: bool = pydantic.Field(default=False) choices: list[str] | None = None + class PackageData(pydantic.BaseModel): name: str author: str @@ -78,98 +75,107 @@ class PackageData(pydantic.BaseModel): installable: bool registry: str status: Literal["installed", "not_installed"] = "not_installed" - + errors: list[str] = pydantic.Field(default_factory=list) - - package_properties: dict[str, PackageProperty] = pydantic.Field(default_factory=dict) + + package_properties: dict[str, PackageProperty] = pydantic.Field( + default_factory=dict + ) install_nodes: list[str] = pydantic.Field(default_factory=list) installed_nodes: list[str] = pydantic.Field(default_factory=list) restart_scene_loop: bool = pydantic.Field(default=False) - + @pydantic.computed_field(description="Whether the package is configured") @property def configured(self) -> bool: """ Whether the package is configured. """ - return all(prop.value is not None for prop in self.package_properties.values() if prop.required) - + return all( + prop.value is not None + for prop in self.package_properties.values() + if prop.required + ) + def properties_for_node(self, node_registry: str) -> dict[str, Any]: """ Get the properties for a node. """ - + return { - prop.name: prop.value for prop in self.package_properties.values() if prop.module == node_registry + prop.name: prop.value + for prop in self.package_properties.values() + if prop.module == node_registry } - - + + class ScenePackageInfo(pydantic.BaseModel): packages: list[PackageData] - + def has_package(self, package_registry: str) -> bool: return any(p.registry == package_registry for p in self.packages) - + def get_package(self, package_registry: str) -> PackageData | None: return next((p for p in self.packages if p.registry == package_registry), None) + # ------------------------------------------------------------------------------------------------ # FUNCTIONS # ------------------------------------------------------------------------------------------------ + async def initialize_scene_package_info(scene: "Scene"): """ Initialize the scene package info. - + This means creation of an empty json file in the scene's info directory. """ filepath = os.path.join(scene.info_dir, SCENE_PACKAGE_INFO_FILENAME) - + # if info dir does not exist, create it if not os.path.exists(scene.info_dir): os.makedirs(scene.info_dir) - + if not os.path.exists(filepath): with open(filepath, "w") as f: json.dump(ScenePackageInfo(packages=[]).model_dump(), f) - + async def get_scene_package_info(scene: "Scene") -> ScenePackageInfo: """ Get the scene package info. - + Returns: ScenePackageInfo: Scene package info. """ filepath = os.path.join(scene.info_dir, SCENE_PACKAGE_INFO_FILENAME) - + # if info dir does not exist, create it if not os.path.exists(scene.info_dir): os.makedirs(scene.info_dir) - + if not os.path.exists(filepath): return ScenePackageInfo(packages=[]) - + with open(filepath, "r") as f: return ScenePackageInfo.model_validate_json(f.read()) async def apply_scene_package_info(scene: "Scene", package_datas: list[PackageData]): - """ Will set the status to installed or not_installed for each package. - + Will update the property values for each installed package. - + Args: scene (Scene): The scene to apply the package info to. package_datas (list[PackageData]): The package data to apply. """ - + scene_package_info = await get_scene_package_info(scene) - + for package_data in package_datas: if scene_package_info.has_package(package_data.registry): package_data.status = "installed" @@ -178,28 +184,33 @@ async def apply_scene_package_info(scene: "Scene", package_datas: list[PackageDa else: package_data.status = "not_installed" + async def list_packages() -> list[PackageData]: """ List all installable packages. - + Returns: list[PackageData]: List of package data. """ - + packages = get_nodes_by_base_type("util/packaging/Package") package_datas = [] - + for package_module_cls in packages: package_module: "Package" = package_module_cls() - + # skip if not installable if not package_module.get_property("installable"): continue - + errors = [] - - install_node_modules = await package_module.get_nodes(lambda node: node.registry == "util/packaging/InstallNodeModule") - promoted_properties = await package_module.get_nodes(lambda node: node.registry == "util/packaging/PromoteConfig") + + install_node_modules = await package_module.get_nodes( + lambda node: node.registry == "util/packaging/InstallNodeModule" + ) + promoted_properties = await package_module.get_nodes( + lambda node: node.registry == "util/packaging/PromoteConfig" + ) install_nodes = [] module_properties = {} @@ -210,21 +221,35 @@ async def list_packages() -> list[PackageData]: module_properties[node_module.registry] = node_module.module_properties install_nodes.append(node_module.registry) - log.debug("package_module", package_module=package_module, module_properties=module_properties, promoted_properties=promoted_properties) + log.debug( + "package_module", + package_module=package_module, + module_properties=module_properties, + promoted_properties=promoted_properties, + ) package_properties = {} - + for promoted_property in promoted_properties: property_name = promoted_property.properties["property_name"] - exposed_property_name = promoted_property.properties["exposed_property_name"] + exposed_property_name = promoted_property.properties[ + "exposed_property_name" + ] target_node_registry = promoted_property.properties["node_registry"] - + try: module_property = module_properties[target_node_registry][property_name] except KeyError: - log.warning("module property not found", target_node_registry=target_node_registry, property_name=property_name, module_properties=module_properties) - errors.append(f"Module property {property_name} not found in {target_node_registry}") + log.warning( + "module property not found", + target_node_registry=target_node_registry, + property_name=property_name, + module_properties=module_properties, + ) + errors.append( + f"Module property {property_name} not found in {target_node_registry}" + ) continue - + package_property = PackageProperty( module=target_node_registry, name=property_name, @@ -236,7 +261,7 @@ async def list_packages() -> list[PackageData]: required=promoted_property.properties.get("required", False), ) package_properties[exposed_property_name] = package_property - + package_data = PackageData( name=package_module.properties["package_name"], author=package_module.properties["author"], @@ -248,172 +273,195 @@ async def list_packages() -> list[PackageData]: restart_scene_loop=package_module.properties["restart_scene_loop"], errors=errors, ) - + package_datas.append(package_data) - + return package_datas + async def get_package_by_registry(package_registry: str) -> PackageData | None: """ Get a package by its registry. - + Args: package_registry (str): The registry of the package to get. """ - + packages = await list_packages() - + return next((p for p in packages if p.registry == package_registry), None) + async def save_scene_package_info(scene: "Scene", scene_package_info: ScenePackageInfo): """ Save the scene package info to the scene's info directory. """ - + # if info dir does not exist, create it if not os.path.exists(scene.info_dir): os.makedirs(scene.info_dir) - + with open(os.path.join(scene.info_dir, SCENE_PACKAGE_INFO_FILENAME), "w") as f: json.dump(scene_package_info.model_dump(), f, indent=4) + async def install_package(scene: "Scene", package_data: PackageData) -> PackageData: """ Install a package to the scene. - + Args: scene (Scene): The scene to install the package to. package_data (PackageData): The package data to install. """ - + scene_package_info = await get_scene_package_info(scene) - + if scene_package_info.has_package(package_data.registry): # already installed return package_data - + package_data.status = "installed" - + scene_package_info.packages.append(package_data) - + await save_scene_package_info(scene, scene_package_info) - + return package_data -async def update_package_properties(scene: "Scene", package_registry: str, package_properties: dict[str, PackageProperty]) -> PackageData | None: + +async def update_package_properties( + scene: "Scene", + package_registry: str, + package_properties: dict[str, PackageProperty], +) -> PackageData | None: """ Update the properties of a package. """ - + scene_package_info = await get_scene_package_info(scene) - + package_data = scene_package_info.get_package(package_registry) - + if not package_data: return - + for property_name, property_data in package_properties.items(): package_data.package_properties[property_name].value = property_data.value - + await save_scene_package_info(scene, scene_package_info) - + return package_data + async def uninstall_package(scene: "Scene", package_registry: str): """ Uninstall a package from the scene. (Removes the package from the scene package info) - + Args: scene (Scene): The scene to uninstall the package from. package_registry (str): The registry of the package to uninstall. """ - + scene_package_info = await get_scene_package_info(scene) - + if not scene_package_info.has_package(package_registry): # not installed return - + package_data = scene_package_info.get_package(package_registry) - - scene_package_info.packages = [p for p in scene_package_info.packages if p.registry != package_registry] - - scene_loop:SceneLoop | None = scene.active_node_graph + + scene_package_info.packages = [ + p for p in scene_package_info.packages if p.registry != package_registry + ] + + scene_loop: SceneLoop | None = scene.active_node_graph if scene_loop: for node_id in package_data.installed_nodes: scene_loop.nodes.pop(node_id, None) - + package_data.installed_nodes = [] - + await save_scene_package_info(scene, scene_package_info) + async def initialize_packages(scene: "Scene", scene_loop: SceneLoop): """ Initialize all installed packages into the scene loop. """ - + try: scene_package_info = await get_scene_package_info(scene) for package_data in scene_package_info.packages: - if not package_data.configured: log.warning("package is not configured", package=package_data.name) continue - + if package_data.errors: log.warning("package information has errors", package=package_data.name) continue - + await initialize_package(scene, scene_loop, package_data) - - except Exception as e: + + except Exception: log.error("initialize_packages failed", error=traceback.format_exc()) - + async def initialize_package( - scene: "Scene", - scene_loop: SceneLoop, + scene: "Scene", + scene_loop: SceneLoop, package_data: PackageData, ): """ Initialize an installed package into the scene loop. - + This is the logic that actually adds the package's nodes to the scene loop through the InstallNodeModule node(s) contained in the package module. - + Args: scene (Scene): The scene to install the package to. scene_loop (SceneLoop): The scene loop to install the package to. package_data (PackageData): The package data to install. """ - + try: for registry in package_data.install_nodes: install_node_cls = get_node(registry) - - node:Node = install_node_cls() + + node: Node = install_node_cls() scene_loop.add_node(node) - - for property_name, property_value in package_data.properties_for_node(registry).items(): + + for property_name, property_value in package_data.properties_for_node( + registry + ).items(): field = node.get_property_field(property_name) field.default = property_value node.properties[property_name] = property_value - log.debug("installed node", registry=registry, properties=package_data.properties_for_node(registry)) - except Exception as e: - log.error("initialize_package failed", error=traceback.format_exc(), package_data=package_data) + log.debug( + "installed node", + registry=registry, + properties=package_data.properties_for_node(registry), + ) + except Exception: + log.error( + "initialize_package failed", + error=traceback.format_exc(), + package_data=package_data, + ) # ------------------------------------------------------------------------------------------------ # NODES # ------------------------------------------------------------------------------------------------ + @register("util/packaging/Package", as_base_type=True) class Package(Graph): """ Configure node that helps managing node module packaging setup for easy scene installation. - + This graph expects node module packaging instructions via various packaging nodes. """ + _export_definition: ClassVar[bool] = False class Fields: @@ -421,40 +469,40 @@ class Package(Graph): name="package_name", description="The name of the node module", type="str", - default="" + default="", ) - + author = PropertyField( name="author", description="The author of the node module", type="str", - default="" + default="", ) - + description = PropertyField( name="description", description="The description of the node module", type="str", - default="" + default="", ) - + installable = PropertyField( name="installable", description="Whether the node module is installable to the scene", type="bool", - default=True + default=True, ) - + restart_scene_loop = PropertyField( name="restart_scene_loop", description="Whether the scene loop should be restarted after the package is installed", type="bool", - default=False + default=False, ) - + def __init__(self, title="Package", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("package_name", "") self.set_property("author", "") @@ -462,18 +510,17 @@ class Package(Graph): self.set_property("installable", True) self.set_property("restart_scene_loop", False) - + @register("util/packaging/InstallNodeModule") class InstallNodeModule(Node): - class Fields: node_registry = PropertyField( name="node_registry", description="The registry path of the node module to package", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: @@ -482,56 +529,56 @@ class InstallNodeModule(Node): title_color="#2e4657", icon="F01A6", ) - + def __init__(self, title="Install Node Module", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("node_registry", UNRESOLVED) - + @register("util/packaging/PromoteConfig") class PromoteConfig(Node): """ Promotes a single module property to be configurable through the scene once the package is installed. """ - + class Fields: node_registry = PropertyField( name="node_registry", description="The registry path of the node module to package", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + property_name = PropertyField( name="property_name", description="Property Name", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + exposed_property_name = PropertyField( name="exposed_property_name", description="Exposed Property Name", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + label = PropertyField( name="label", description="Label", type="str", default="", ) - + required = PropertyField( name="required", description="Whether the property is required", type="bool", default=False, ) - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: @@ -541,10 +588,10 @@ class PromoteConfig(Node): def __init__(self, title="Promote Config", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("node_registry", UNRESOLVED) self.set_property("property_name", UNRESOLVED) self.set_property("exposed_property_name", UNRESOLVED) self.set_property("required", False) - self.set_property("label", "") \ No newline at end of file + self.set_property("label", "") diff --git a/src/talemate/game/engine/nodes/prompt.py b/src/talemate/game/engine/nodes/prompt.py index bc61ad4b..1172df59 100644 --- a/src/talemate/game/engine/nodes/prompt.py +++ b/src/talemate/game/engine/nodes/prompt.py @@ -1,21 +1,18 @@ -import inspect import pydantic import structlog -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from talemate.game.engine.nodes.core import ( - Node, - register, - GraphState, - InputValueError, - PropertyField, + Node, + register, + GraphState, + InputValueError, + PropertyField, NodeVerbosity, NodeStyle, - UNRESOLVED, TYPE_CHOICES, ) from talemate.agents.registry import get_agent_types -from talemate.agents.base import Agent, set_processing -from talemate.instance import get_agent +from talemate.agents.base import Agent from talemate.prompts.base import Prompt, PrependTemplateDirectories from talemate.client.presets import make_kind from talemate.context import active_scene @@ -26,293 +23,298 @@ if TYPE_CHECKING: log = structlog.get_logger("talemate.game.engine.nodes.prompt") -TYPE_CHOICES.extend([ - "prompt", -]) +TYPE_CHOICES.extend( + [ + "prompt", + ] +) + @register("prompt/PromptFromTemplate") class PromptFromTemplate(Node): """ Loads a talemate template prompt - + Inputs: - + - template_file: The template file to load - variables: The variables to use in the template (optional) - + Properties: - + - scope: the template scope (choices of agents or scene) - + Outputs: - + - prompt: The Prompt instance """ - + class Fields: - scope = PropertyField( name="scope", type="str", generate_choices=lambda: ["scene"] + list(get_agent_types()), description="The template scope", - default="scene" + default="scene", ) - + template_file = PropertyField( name="template_file", type="str", description="The template to load", default="", ) - + template_text = PropertyField( name="template_text", type="text", description="The template text to use", default="", ) - + def __init__(self, title="Prompt From Template", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("template_file", socket_type="str", optional=True) self.add_input("template_text", socket_type="str", optional=True) self.add_input("variables", socket_type="dict", optional=True) - + self.set_property("scope", "scene") self.set_property("template_file", "") self.set_property("template_text", "") - + self.add_output("prompt", socket_type="prompt") - async def run(self, graph_state: GraphState): template_file = self.normalized_input_value("template_file") template_text = self.normalized_input_value("template_text") variables = self.normalized_input_value("variables") or {} scope = self.get_property("scope") - + if template_file and template_text: - raise InputValueError(self, "template_file", "Cannot provide both template_file and template_text") - + raise InputValueError( + self, + "template_file", + "Cannot provide both template_file and template_text", + ) + if template_file: if scope != "scene": template_uid = f"{scope}.{template_file}" else: template_uid = template_file - + prompt: Prompt = Prompt.get(template_uid, vars=variables) elif template_text: prompt: Prompt = Prompt.from_text(template_text, vars=variables) else: - raise InputValueError(self, "template_file", "Must provide either template_file or template_text") - - self.set_output_values({ - "prompt": prompt - }) - + raise InputValueError( + self, + "template_file", + "Must provide either template_file or template_text", + ) + + self.set_output_values({"prompt": prompt}) + + @register("prompt/RenderPrompt") class RenderPrompt(Node): """ Renders a prompt Input: - + - prompt: The prompt to render - + Outputs: - + - rendered: The rendered prompt - """ - + """ + def __init__(self, title="Render Prompt", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("prompt", socket_type="prompt") - + self.add_output("rendered", socket_type="str") - + async def run(self, graph_state: GraphState): - prompt:Prompt = self.require_input("prompt") + prompt: Prompt = self.require_input("prompt") rendered = prompt.render() - - self.set_output_values({ - "rendered": rendered, - }) - + + self.set_output_values( + { + "rendered": rendered, + } + ) + @register("prompt/TemplateVariables") class TemplateVariables(Node): """ Helper node that returns a pre defined set of common template variables - + Variables: - + - scene: The current scene - max_tokens: The maximum number of tokens in the response - + Inputs: - + - agent: The relevant agent - merge_with: A dictionary of variables to merge with the pre defined variables (optional) - + Outputs: - + - variables: A dictionary of variables - agent: The relevant agent """ - + def __init__(self, title="Template Variables", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("agent", socket_type="agent") self.add_input("merge_with", socket_type="dict", optional=True) self.add_output("variables", socket_type="dict") self.add_output("agent", socket_type="agent") - + async def run(self, graph_state: GraphState): - agent:Agent = self.require_input("agent") - merge_with:dict = self.normalized_input_value("merge_with") or {} + agent: Agent = self.require_input("agent") + merge_with: dict = self.normalized_input_value("merge_with") or {} if not hasattr(agent, "client"): raise InputValueError(self, "agent", "Agent does not have a client") - - + scene = active_scene.get() - + variables = { "scene": scene, "scene_title": scene.title or scene.name, "max_tokens": agent.client.max_token_length, } - + variables.update(merge_with) - - self.set_output_values({ - "variables": variables, - "agent": agent - }) + + self.set_output_values({"variables": variables, "agent": agent}) + @register("prompt/GenerateResponse") class GenerateResponse(Node): """ Sends a prompt to the agent and generates a response - + Inputs: - + - agent: The agent to send the prompt to - prompt: The prompt to send to the agent - + Properties - + - data_output: Output the response as data structure - attempts: The number of attempts to attempt (on empty response) - + Outputs: - + - response: The response from the agent - data_obj: The data structure of the response - rendered_prompt: The rendered prompt - agent: The agent that generated the response - + """ - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( node_color="#392c34", title_color="#572e44", - icon="F1719", #robot-happy + icon="F1719", # robot-happy ) - + class Fields: data_output = PropertyField( name="data_output", type="bool", default=False, - description="Output the response as a data structure" + description="Output the response as a data structure", ) - + attempts = PropertyField( name="attempts", type="int", description="The number of attempts (retry on empty response)", default=1, ) - + response_length = PropertyField( name="response_length", type="int", description="The maximum length of the response", default=256, ) - + action_type = PropertyField( name="action_type", type="str", description="Classification of the generated response", - choices=sorted([ - "conversation", - "narrate", - "create", - "scene_direction", - "analyze", - "edit", - "world_state", - "summarize", - ]), - default="scene_direction" + choices=sorted( + [ + "conversation", + "narrate", + "create", + "scene_direction", + "analyze", + "edit", + "world_state", + "summarize", + ] + ), + default="scene_direction", ) - - + def __init__(self, title="Generate Response", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("agent", socket_type="agent") self.add_input("prompt", socket_type="prompt") self.add_input("action_type", socket_type="str", optional=True) - + self.set_property("data_output", False) self.set_property("response_length", 256) self.set_property("action_type", "scene_direction") self.set_property("attempts", 1) - + self.add_output("response", socket_type="str") self.add_output("data_obj", socket_type="dict") self.add_output("rendered_prompt", socket_type="str") self.add_output("agent", socket_type="agent") - - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() - agent:Agent = self.require_input("agent") - prompt:Prompt = self.require_input("prompt") + scene: "Scene" = active_scene.get() + agent: Agent = self.require_input("agent") + prompt: Prompt = self.require_input("prompt") action_type = self.get_property("action_type") response_length = self.get_property("response_length") data_output = self.get_property("data_output") attempts = self.get_property("attempts") or 1 - + prompt.agent_type = agent.agent_type - + kind = make_kind( - action_type=action_type, - length=response_length, - expect_json=data_output + action_type=action_type, length=response_length, expect_json=data_output ) - + if data_output: prompt.data_response = True - + if state.verbosity >= NodeVerbosity.NORMAL: log.info(f"Sending prompt to agent {agent.agent_type} with kind {kind}") - + async def send_prompt(*args, **kwargs): prompt.vars.update( { @@ -320,71 +322,71 @@ class GenerateResponse(Node): } ) return await prompt.send(agent.client, kind=kind) - send_prompt.__name__= self.title.replace(" ", "_").lower() - + + send_prompt.__name__ = self.title.replace(" ", "_").lower() + with PrependTemplateDirectories(scene.template_dir): for _ in range(attempts): response = await agent.delegate(send_prompt) if response: break - + if isinstance(response, tuple): response, data_obj = response else: data_obj = None - - self.set_output_values({ - "response": response.strip(), - "data_obj": data_obj, - "rendered_prompt": prompt.prompt, - "agent": agent - }) - - + + self.set_output_values( + { + "response": response.strip(), + "data_obj": data_obj, + "rendered_prompt": prompt.prompt, + "agent": agent, + } + ) + + @register("prompt/CleanResponse") class CleanResponse(Node): """ Cleans a response - + Inputs: - + - response: The response to clean - + Properties: - + - partial_sentences: Strip partial sentences from the response - + Outputs: - + - cleaned: The cleaned response """ - + class Fields: strip_partial_sentences = PropertyField( name="partial_sentences", type="bool", description="Strip partial sentences from the response", - default=True + default=True, ) - + def __init__(self, title="Clean Response", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("response", socket_type="str") - + self.set_property("partial_sentences", True) - + self.add_output("cleaned", socket_type="str") - + async def run(self, graph_state: GraphState): response = self.require_input("response") partial_sentences = self.get_property("partial_sentences") - + if partial_sentences: response = strip_partial_sentences(response) - - self.set_output_values({ - "cleaned": response - }) - \ No newline at end of file + + self.set_output_values({"cleaned": response}) diff --git a/src/talemate/game/engine/nodes/raise_errors.py b/src/talemate/game/engine/nodes/raise_errors.py index 4721d8cf..bbd7f293 100644 --- a/src/talemate/game/engine/nodes/raise_errors.py +++ b/src/talemate/game/engine/nodes/raise_errors.py @@ -24,46 +24,47 @@ __all__ = [ log = structlog.get_logger("talemate.game.engine.nodes.core.raise") + @register("raise/ActedAsCharacter") class ActedAsCharacter(Node): """ Raises an ActedAsCharacter exception. - + This is used to communicate to the main scene loop that the user has performed an action as a specific character. - + Inputs: - + - state: The current graph state - character_name: The name of the character the user acted as """ - + def __init__(self, title="Acted As Character", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("character_name", socket_type="str") - + async def run(self, state: GraphState): character_name = self.get_input_value("character_name") - + raise exceptions.ActedAsCharacter(character_name) - + @register("raise/Stop") class Stop(Node): """ Raises the sepcified node / scene loop exception to stop execution of the current graph - + Inputs: - + - state: The current state - exception: The exception to raise - + Outputs: - + - state: The current state """ @@ -73,9 +74,9 @@ class Stop(Node): return NodeStyle( node_color="#401a1a", title_color="#111", - icon="F0028", #alert-circle + icon="F0028", # alert-circle ) - + class Fields: exception = PropertyField( name="exception", @@ -91,30 +92,27 @@ class Stop(Node): "ExitScene", "RestartSceneLoop", "ResetScene", - ] + ], ) - + def __init__(self, title="Stop", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("exception", type="str", optional=True) - + self.set_property("exception", "StopGraphExecution") - + self.add_output("state") - + async def run(self, state: GraphState): - exception = self.require_input("exception") - + # this will never be reached, but it's here to make sure the # that Stage nodes can be connected to this node - self.set_output_values({ - "state": self.get_input_value("state") - }) - + self.set_output_values({"state": self.get_input_value("state")}) + if exception == "StopGraphExecution": raise StopGraphExecution() elif exception == "StopModule": @@ -134,26 +132,26 @@ class Stop(Node): else: raise InputValueError(self, "exception", f"Unknown exception: {exception}") - + @register("raise/InputValueError") class InputValueErrorNode(Node): """ Raises an InputValueError exception. - + Inputs: - + - state: The current state - field: The field that caused the error - message: The message to raise the exception with """ - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( node_color="#401a1a", title_color="#111", - icon="F0028", #alert-circle + icon="F0028", # alert-circle ) class Fields: @@ -163,36 +161,33 @@ class InputValueErrorNode(Node): type="str", default="", ) - + field = PropertyField( name="field", description="The field that caused the error", type="str", default="", ) - + def __init__(self, title="Input Value Error", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("message", socket_type="str") self.add_input("field", socket_type="str") - + self.set_property("message", "") self.set_property("field", "") - + self.add_output("state") - + async def run(self, state: GraphState): message = self.require_input("message") field = self.require_input("field") # this will never be reached, but it's here to make sure the # that Stage nodes can be connected to this node - self.set_output_values({ - "state": self.get_input_value("state") - }) - + self.set_output_values({"state": self.get_input_value("state")}) + raise InputValueError(self, field, message) - \ No newline at end of file diff --git a/src/talemate/game/engine/nodes/registry.py b/src/talemate/game/engine/nodes/registry.py index 1f68e94d..ea206d12 100644 --- a/src/talemate/game/engine/nodes/registry.py +++ b/src/talemate/game/engine/nodes/registry.py @@ -6,7 +6,6 @@ import traceback from typing import TYPE_CHECKING from pathlib import Path from talemate.context import active_scene -from talemate.util.data import JSONEncoder from talemate.game.engine.nodes.base_types import base_node_type from talemate.game.engine.nodes import SEARCH_PATHS @@ -16,18 +15,18 @@ if TYPE_CHECKING: from talemate.tale_mate import Scene __all__ = [ - 'register', - 'get_node', - 'NODES', - 'export_node_definitions', - 'import_node_definitions', - 'import_node_definition', - 'import_scene_node_definitions', - 'import_talemate_node_definitions', - 'normalize_registry_name', - 'get_nodes_by_base_type', - 'validate_registry_path', - 'NodeNotFoundError', + "register", + "get_node", + "NODES", + "export_node_definitions", + "import_node_definitions", + "import_node_definition", + "import_scene_node_definitions", + "import_talemate_node_definitions", + "normalize_registry_name", + "get_nodes_by_base_type", + "validate_registry_path", + "NodeNotFoundError", ] log = structlog.get_logger("talemate.game.engine.nodes.registry") @@ -36,26 +35,28 @@ NODES = {} INITIAL_IMPORT_DONE = False + class NodeNotFoundError(ValueError): pass -def normalize_registry_name(name:str) -> str: + +def normalize_registry_name(name: str) -> str: """ Will normalize a registry name to a consistent format of camel case with no spaces or special characters. - + Arguments: name (str): Name to normalize - + Examples: - + - "My Node" -> "myNode" - "My-Node" -> "myNode" - - "My Other Node" -> "myOtherNode" + - "My Other Node" -> "myOtherNode" """ - + name = name.title() - + # first letter lowercase name = name[0].lower() + name[1:] @@ -64,162 +65,161 @@ def normalize_registry_name(name:str) -> str: return name + def get_node(name): if not name: return None - - scene:"Scene" = active_scene.get() - + + scene: "Scene" = active_scene.get() + SCENE_NODES = getattr(scene, "_NODE_DEFINITIONS", {}) - + if name in SCENE_NODES: return SCENE_NODES[name] - + cls = NODES.get(name) if not cls: raise NodeNotFoundError(f"Node type '{name}' not found") return cls -def get_nodes_by_base_type(base_type:str) -> list["NodeBase"]: + +def get_nodes_by_base_type(base_type: str) -> list["NodeBase"]: """ Returns a list of all nodes that have the given base type. - + Will check both the scene and the talemate node definitions. - + Scene nodes take priority if both register the same node. """ - - scene:"Scene" = active_scene.get() - + + scene: "Scene" = active_scene.get() + SCENE_NODES = getattr(scene, "_NODE_DEFINITIONS", {}) - + nodes = {} - - + for node_name, node_cls in NODES.items(): if node_cls._base_type == base_type: nodes[node_name] = node_cls - + for node_name, node_cls in SCENE_NODES.items(): if node_cls._base_type == base_type: nodes[node_name] = node_cls - + return list(nodes.values()) + class register: - def __init__(self, name, as_base_type:bool=False, container:dict|None=None): + def __init__(self, name, as_base_type: bool = False, container: dict | None = None): self.name = name self.as_base_type = as_base_type self.container = container - + if self.container is None: self.container = NODES - def __call__(self, cls): self.container[self.name] = cls cls._registry = self.name - + if self.as_base_type: base_node_type(self.name)(cls) return cls -def validate_registry_path(path:str, node_definitions:dict | None = None): + +def validate_registry_path(path: str, node_definitions: dict | None = None): """ Validates a registry path to ensure it is a valid path. - + Arguments: - + - path (str): The registry path to validate - node_definitions (dict): The node definitions to validate against - + Raises: - + - ValueError: if the path is invalid """ - + if not node_definitions: node_definitions = export_node_definitions() - + if not path: raise ValueError("Empty registry path") - + # path needs to have at least one / with two parts parts = path.split("/") if len(parts) < 2: - raise ValueError("Registry path must contain at least two parts (e.g., 'my/node')") - + raise ValueError( + "Registry path must contain at least two parts (e.g., 'my/node')" + ) + # the path can not be the prefix of an existing path # e.g., can't put a node where a path is already registered for existing_path in node_definitions["nodes"].keys(): if existing_path.startswith(path + "/"): raise ValueError(f"Registry path {path} is colliding with {existing_path}") - - + + def export_node_definitions() -> dict: - export = { - "nodes": [] - } - - scene:"Scene" = active_scene.get() - - libraries = { - **NODES - } - + export = {"nodes": []} + + scene: "Scene" = active_scene.get() + + libraries = {**NODES} + if hasattr(scene, "_NODE_DEFINITIONS"): - libraries.update(scene._NODE_DEFINITIONS) - + libraries.update(scene._NODE_DEFINITIONS) + for name, node_cls in libraries.items(): - try: - node:"NodeBase" = node_cls() - except ValueError as exc: - log.warning("export_node_definitions: failed to instantiate node class", name=name) + node: "NodeBase" = node_cls() + except ValueError: + log.warning( + "export_node_definitions: failed to instantiate node class", name=name + ) continue - + if not node._export_definition: continue - + field_defs = {} - + for prop_name in node.properties.keys(): field_defs[prop_name] = node.get_property_field(prop_name).model_dump() - + if hasattr(node, "module_properties"): for prop_name, prop_data in node.module_properties.items(): field_defs[prop_name] = prop_data.model_dump() - - exported_node = { - "fields": field_defs, - **node.model_dump() - } - + + exported_node = {"fields": field_defs, **node.model_dump()} + exported_node.pop("nodes", None) exported_node.pop("edges", None) - + export["nodes"].append(exported_node) - + # sort export["nodes"] = { n["registry"]: n for n in sorted(export["nodes"], key=lambda x: x["registry"]) } - - #with open("exported_nodes.json", "w") as file: + + # with open("exported_nodes.json", "w") as file: # json.dump(export, file, indent=2, cls=JSONEncoder) - + return export + def import_initial_node_definitions(): global INITIAL_IMPORT_DONE if INITIAL_IMPORT_DONE: return - + import_talemate_node_definitions() INITIAL_IMPORT_DONE = True + def import_talemate_node_definitions(): - retry = [] files = [] @@ -228,16 +228,16 @@ def import_talemate_node_definitions(): for path in base_path.rglob("*.json"): if path.is_file(): files.append(str(path)) - #log.debug("import_talemate_node_definitions: found node definition", path=path) - + # log.debug("import_talemate_node_definitions: found node definition", path=path) + for filepath in files: with open(filepath, "r") as file: data = json.load(file) try: node_cls = import_node_definition(data) node_cls._module_path = filepath - except Exception as exc: - retry.append( (data, filepath) ) + except Exception: + retry.append((data, filepath)) attempt_retry = True while retry and attempt_retry: @@ -246,49 +246,59 @@ def import_talemate_node_definitions(): try: node_cls = import_node_definition(data) node_cls._module_path = filepath - retry.remove( (data, filepath) ) + retry.remove((data, filepath)) attempt_retry = True - except Exception as exc: - log.error("import_talemate_node_definitions: failed to import node definition", data=data["registry"], exc=traceback.format_exc()) + except Exception: + log.error( + "import_talemate_node_definitions: failed to import node definition", + data=data["registry"], + exc=traceback.format_exc(), + ) pass -def import_scene_node_definitions(scene:"Scene"): + +def import_scene_node_definitions(scene: "Scene"): scene._NODE_DEFINITIONS = {} - + # loop files in scene.nodes_dir # and register the ones that have 'registry' specified # at the root level of the json file - + if not os.path.exists(scene.nodes_dir): return - + retries = [] - + for filename in os.listdir(scene.nodes_dir): - if not filename.endswith(".json"): continue - - #log.debug("import_scene_node_definitions: importing node definition", filename=filename) - + + # log.debug("import_scene_node_definitions: importing node definition", filename=filename) + if filename == scene.nodes_filename: - log.warning("import_scene_node_definitions: skipping scene nodes file", filename=filename) + log.warning( + "import_scene_node_definitions: skipping scene nodes file", + filename=filename, + ) continue - + filepath = os.path.join(scene.nodes_dir, filename) - + with open(filepath, "r") as file: data = json.load(file) - + if not data.get("registry"): - log.warning("import_scene_node_definitions: node definition missing registry, skipping", filename=filename) + log.warning( + "import_scene_node_definitions: node definition missing registry, skipping", + filename=filename, + ) continue try: node_cls = import_node_definition(data, scene._NODE_DEFINITIONS) node_cls._module_path = filepath - except Exception as exc: - retries.append( (data, filepath) ) - + except Exception: + retries.append((data, filepath)) + attempt_retry = True while retries and attempt_retry: attempt_retry = False @@ -296,26 +306,31 @@ def import_scene_node_definitions(scene:"Scene"): try: node_cls = import_node_definition(data, scene._NODE_DEFINITIONS) node_cls._module_path = filepath - retries.remove( (data, filepath) ) + retries.remove((data, filepath)) attempt_retry = True - except Exception as exc: - log.error("import_scene_node_definitions: failed to import node definition", data=data["registry"], exc=traceback.format_exc()) + except Exception: + log.error( + "import_scene_node_definitions: failed to import node definition", + data=data["registry"], + exc=traceback.format_exc(), + ) pass - - - -def import_node_definitions(data:dict): + + +def import_node_definitions(data: dict): for node_data in data["nodes"]: import_node_definition(node_data) -def import_node_definition(node_data:dict, registry=None, reimport:bool=False) -> "NodeBase": - + +def import_node_definition( + node_data: dict, registry=None, reimport: bool = False +) -> "NodeBase": """ Imports a node definition from a dictionary and registers it in the NODES registry as a class. - + Arguments: - + - node_data (dict): The node definition data - registry (dict): The registry to register the node class in - defaults to NODES - reimport (bool): If True, will reimport the node class if it already exists in the registry, removing the old one first. @@ -325,24 +340,24 @@ def import_node_definition(node_data:dict, registry=None, reimport:bool=False) - if registry is None: registry = NODES - + if reimport: registry.pop(node_data["registry"], None) - + try: node_cls = registry[node_data["registry"]] except KeyError: node_cls = dynamic_node_import(node_data, node_data["registry"], registry) - + node = node_cls() - + if "fields" in node_data: for prop_name, prop_data in node_data["fields"].items(): field = node.get_property_field(prop_name) field.model_validate(prop_data) - + node.model_validate(node_data) - + registry[node_data["registry"]] = node_cls - + return node_cls diff --git a/src/talemate/game/engine/nodes/run.py b/src/talemate/game/engine/nodes/run.py index 3fc36eb5..4aebe104 100644 --- a/src/talemate/game/engine/nodes/run.py +++ b/src/talemate/game/engine/nodes/run.py @@ -31,9 +31,12 @@ async_signals.register( "nodes_breakpoint", ) -TYPE_CHOICES.extend([ - "exception", -]) +TYPE_CHOICES.extend( + [ + "exception", + ] +) + @dataclasses.dataclass class BreakpointEvent: @@ -41,61 +44,63 @@ class BreakpointEvent: state: GraphState module_path: str = None + class FunctionWrapper: - def __init__(self, endpoint:Node, containing_graph:Graph, state:GraphState): + def __init__(self, endpoint: Node, containing_graph: Graph, state: GraphState): self.state = state self.containing_graph = containing_graph self.endpoint = endpoint async def __call__(self, **kwargs): - result = None - async def handle_result(state:GraphState): + result = None + + async def handle_result(state: GraphState): nonlocal result - result = state.data.get(f"__fn_result") + result = state.data.get("__fn_result") if state.verbosity >= NodeVerbosity.VERBOSE: - log.info(f"Function result", result=result, node=self.endpoint) - + log.info("Function result", result=result, node=self.endpoint) + if self.endpoint != self.containing_graph: - # endpoint is not the containing graph, but a subgraph # we need to find the function arguments and set their values # # only arguments connected to the endpoint node are considered - + argument_nodes = await self.containing_graph.get_nodes_connected_to( self.endpoint, fn_filter=lambda node: isinstance(node, FunctionArgument) ) - + await self.containing_graph.execute_to_node( - self.endpoint, self.state, callbacks=[handle_result], + self.endpoint, + self.state, + callbacks=[handle_result], state_values={ f"{arg.id}__fn_arg_value": kwargs.get(arg.get_property("name")) for arg in argument_nodes }, - execute_forks=True + execute_forks=True, ) else: # endpoint is the containing graph # we need to find the function arguments and set their values # # all arguments are considered - + argument_nodes = await self.containing_graph.get_nodes( fn_filter=lambda node: isinstance(node, FunctionArgument) ) - + await self.containing_graph.execute( - self.state, + self.state, callbacks=[handle_result], state_values={ f"{arg.id}__fn_arg_value": kwargs.get(arg.get_property("name")) for arg in argument_nodes }, ) - - + return result - + async def get_argument_nodes(self): if self.endpoint != self.containing_graph: return await self.containing_graph.get_nodes_connected_to( @@ -105,22 +110,23 @@ class FunctionWrapper: return await self.containing_graph.get_nodes( fn_filter=lambda node: isinstance(node, FunctionArgument) ) - + @register("core/functions/Argument") class FunctionArgument(Node): """ Represents an argument to a function. - + Properties: - + - type (str): The type of the argument - name (str): The name of the argument - + Outputs: - + - value: The value of the argument (during function execution) """ + class Fields: typ = PropertyField( type="str", @@ -134,7 +140,7 @@ class FunctionArgument(Node): "bool", ], ) - + name = PropertyField( type="str", name="name", @@ -148,73 +154,73 @@ class FunctionArgument(Node): return NodeStyle( node_color="#2d2c39", title_color="#312e57", - icon="F0AE7", #variable - auto_title="{name}" + icon="F0AE7", # variable + auto_title="{name}", ) - + def __init__(self, title="Argument", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("name", UNRESOLVED) self.set_property("typ", "str") self.add_output("value") - - async def run(self, state:GraphState): + + async def run(self, state: GraphState): value = state.data.get(f"{self.id}__fn_arg_value", UNRESOLVED) - + self.set_output_values({"value": value}) + @register("core/functions/Return") class FunctionReturn(Node): """ Represents the return value of a function. - + Inputs: - + - value: The value to return - + Outputs: - + - value: The value to return """ - + def __init__(self, title="Return", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("value") self.add_output("value") - - async def run(self, state:GraphState): + + async def run(self, state: GraphState): value = self.get_input_value("value") - + if value is UNRESOLVED: return - + self.set_output_values({"value": value}) - state.data[f"__fn_result"] = value - + state.data["__fn_result"] = value + if state.verbosity >= NodeVerbosity.VERBOSE: log.info(f"Function return: {self.id}", value=value) - + raise StopGraphExecution(f"Function return: {self.id}") - @register("core/functions/DefineFunction") class DefineFunction(Node): """ Does not define any outputs and is considered an isolated node. - + The correspinding GetFunction node will be used to retrieve the function object. - + Inputs: - + - nodes: The nodes to convert into a function - name: The name of the function """ - + _isolated: ClassVar[bool] = True class Fields: @@ -225,21 +231,19 @@ class DefineFunction(Node): default=UNRESOLVED, ) - @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( node_color="#392f2c", title_color="#573a2e", - icon="F0295", #function - auto_title="DEF {name}" + icon="F0295", # function + auto_title="DEF {name}", ) - def __init__(self, title="Define Function", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("nodes") self.add_input("name", socket_type="str") @@ -248,36 +252,36 @@ class DefineFunction(Node): @property def never_run(self) -> bool: return True - - async def run(self, state:GraphState): - return - - async def get_function(self, state:GraphState) -> FunctionWrapper: + async def run(self, state: GraphState): + return + + async def get_function(self, state: GraphState) -> FunctionWrapper: input_socket = self.get_input_socket("nodes") if not input_socket.source: raise ValueError("Nodes input not connected") - + input_node = input_socket.source.node - + return FunctionWrapper(input_node, state.graph, state) + @register("core/functions/GetFunction") class GetFunction(Node): """ Retrieves a function from the graph - + This has no inputs and will return the function wrapper for the function defined by the DefineFunction node. - + Properties: - + - name: The name of the function Outputs: - + - fn: The function wrapper """ @@ -295,22 +299,25 @@ class GetFunction(Node): return NodeStyle( node_color="#392f2c", title_color="#573a2e", - icon="F0295", #function - auto_title="FN {name}" + icon="F0295", # function + auto_title="FN {name}", ) def __init__(self, title="Get Function", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("name", UNRESOLVED) self.add_output("fn", socket_type="function") - async def run(self, state:GraphState): + async def run(self, state: GraphState): name = self.require_input("name") - graph:Graph = state.graph - define_function_node = await graph.get_node(fn_filter=lambda node: isinstance(node, DefineFunction) and node.get_property("name") == name) + graph: Graph = state.graph + define_function_node = await graph.get_node( + fn_filter=lambda node: isinstance(node, DefineFunction) + and node.get_property("name") == name + ) if not define_function_node: raise ValueError(f"Function {name} not found") @@ -318,25 +325,26 @@ class GetFunction(Node): fn_wrapper = await define_function_node.get_function(state) self.set_output_values({"fn": fn_wrapper}) - + return fn_wrapper + @register("core/functions/CallFunction") class CallFunction(Node): """ Takes a function wrapper input and a dict property to define arguments to pass to the function then calls the function - + Inputs: - + - fn: The function to call - args: The arguments to pass to the function - + Outputs: - + - result: The result of the function call """ - + class Fields: args = PropertyField( type="dict", @@ -344,51 +352,52 @@ class CallFunction(Node): description="The arguments to pass to the function", default={}, ) - + def __init__(self, title="Call Function", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("fn", socket_type="function") self.add_input("args", optional=True) self.set_property("args", {}) self.add_output("result") - - async def run(self, state:GraphState): + + async def run(self, state: GraphState): fn = self.get_input_value("fn") args = self.get_input_value("args") - + if not isinstance(fn, FunctionWrapper): raise ValueError("fn must be a FunctionWrapper instance") - + result = await fn(**args) - + self.set_output_values({"result": result}) + @register("core/functions/CallForEach") class CallForEach(Node): """ Calls the supplied function on each item in the input list - + The item is passed to the function as an argument - + Inputs: - + - state: The state of the graph - fn: The function to call - items: The list of items to iterate over - + Properties: - + - copy_items: Whether to copy the items list (default: False) - argument_name: The name of the argument to pass to the function (default: item) - + Outputs: - + - state: The state of the graph - results: The results of the function calls """ - + class Fields: copy_items = PropertyField( type="bool", @@ -396,50 +405,51 @@ class CallForEach(Node): description="Whether to copy the items list", default=False, ) - + argument_name = PropertyField( type="str", name="argument_name", description="The name of the argument to pass to the function", default="item", ) + def __init__(self, title="Call For Each", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("fn", socket_type="function") self.add_input("items", socket_type="list") - + self.set_property("copy_items", False) self.set_property("argument_name", "item") self.add_output("state") self.add_output("results", socket_type="list") - - async def run(self, state:GraphState): + + async def run(self, state: GraphState): fn = self.get_input_value("fn") items = self.get_input_value("items") argument_name = self.get_property("argument_name") copy_items = self.get_property("copy_items") - + if not argument_name: raise InputValueError(self, "argument_name", "Argument name is required") - + if not isinstance(fn, FunctionWrapper): raise InputValueError(self, "fn", "fn must be a FunctionWrapper instance") - + if not isinstance(items, list): raise InputValueError(self, "items", "items must be a list") - + results = [] - + if copy_items: items = items.copy() - + for item in items: result = await fn(**{argument_name: item}) results.append(result) - + self.set_output_values( { "state": self.get_input_value("state"), @@ -447,31 +457,31 @@ class CallForEach(Node): } ) + @base_node_type("core/functions/Function") class Function(Graph): """ A module graph that defines a function """ - + @pydantic.computed_field(description="Inputs") @property def inputs(self) -> list[Socket]: - """ Function graphs never have any direct inputs """ return [] - - @pydantic.computed_field(description="Outputs") + + @pydantic.computed_field(description="Outputs") @property def outputs(self) -> list[Socket]: """ Function graphs only have one output which is the function wrapper """ - + if hasattr(self, "_outputs"): return self._outputs - + self._outputs = [ Socket( name="fn", @@ -479,29 +489,27 @@ class Function(Graph): node=self, ) ] - - return self._outputs + + return self._outputs @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: - # If a style is defined in the graph it overrides the default defined_style = super().style if defined_style: return defined_style - + return NodeStyle( node_color="#392f2c", title_color="#573a2e", - icon="F0295", #function + icon="F0295", # function ) - - + async def run(self, state: GraphState): """ Executing the graph will return a FunctionWrapper object where - the endpoint node is an OutputSocket node + the endpoint node is an OutputSocket node """ wrapped = FunctionWrapper(self, self, state) self.set_output_values({"fn": wrapped}) @@ -511,48 +519,56 @@ class Function(Graph): class RunModule(Node): """ Provides a way to run a node module from memory - + Inputs: - module (optional) - + Outputs: - done: True if module was executed successfully - failed: Error message if module execution failed - cancelled: True if module execution was cancelled """ - + def __init__(self, title="Run Module", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("module") self.add_output("done", socket_type="bool") self.add_output("failed", socket_type="str") self.add_output("cancelled", socket_type="bool") - + async def run(self, state: GraphState): module = self.get_input_value("module") - + if state.verbosity >= NodeVerbosity.VERBOSE: log.debug("Running module") - + if not isinstance(module, Graph): raise ValueError("Module must be a Graph instance") - + if state.outer.data.get("_in_run_module") == module: - raise ValueError(f"Infinite loop detected. Running module from within itself: {self.title}") - + raise ValueError( + f"Infinite loop detected. Running module from within itself: {self.title}" + ) + task_key = f"__run_{module.id}" try: state.data["_in_run_module"] = module - + quaratined_state = GraphState() - quaratined_state.shared["creative_mode"] = state.shared.get("creative_mode", False) - quaratined_state.shared["nested_scene_loop"] = module.base_type == "scene/SceneLoop" + quaratined_state.shared["creative_mode"] = state.shared.get( + "creative_mode", False + ) + quaratined_state.shared["nested_scene_loop"] = ( + module.base_type == "scene/SceneLoop" + ) quaratined_state.stack = state.stack - - task = state.shared[task_key] = asyncio.create_task(module.run(quaratined_state)) - + + task = state.shared[task_key] = asyncio.create_task( + module.run(quaratined_state) + ) + try: await task self.set_output_values({"done": True}) @@ -580,14 +596,14 @@ class RunModule(Node): await task except (asyncio.CancelledError, Exception): pass # Ignore any errors during cleanup - - + + @register("core/functions/Breakpoint") class Breakpoint(Node): """ A node that will pause execution of the graph and allow for inspection """ - + class Fields: active = PropertyField( type="bool", @@ -595,147 +611,145 @@ class Breakpoint(Node): description="Whether the breakpoint is active", default=True, ) - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( title_color="#461515", - icon="F03C3", #octagon + icon="F03C3", # octagon ) - + def __init__(self, title="Breakpoint", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") - + self.set_property("active", True) - + self.add_output("state") - - async def run(self, state:GraphState): - + + async def run(self, state: GraphState): incoming_state = self.get_input_value("state") active = self.get_property("active") scene = active_scene.get() - + if scene.environment != "creative": active = False log.debug("Breakpoint disabled in non-creative environment", node=self.id) - + if not active: self.set_output_values({"state": incoming_state}) return - + scene = active_scene.get() - + state.shared["__breakpoint"] = self.id if state.verbosity >= NodeVerbosity.NORMAL: log.info("Breakpoint", node=self.id) - + await async_signals.get("nodes_breakpoint").send( BreakpointEvent( - node=self, - state=incoming_state, - module_path=state.graph._module_path + node=self, state=incoming_state, module_path=state.graph._module_path ) ) - + while state.shared.get("__breakpoint"): if scene and not scene.active: log.warning("Breakpoint cancelled", node=self.id) self.set_output_values({"state": state}) raise StopGraphExecution("Breakpoint cancelled") await asyncio.sleep(0.5) - + if state.verbosity >= NodeVerbosity.NORMAL: log.info("Breakpoint released", node=self.id) - + self.set_output_values({"state": incoming_state}) - + @register("core/ErrorHandler") class ErrorHandler(Node): """ A node that will catch unhandled errors in the graph and allow for custom error handling - + Inputs: - + - fn: The function to call when an error occurs """ - + _isolated: ClassVar[bool] = True - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( node_color="#2c0a0a", title_color="#461515", - icon="F05D6", #alert-circle-outline + icon="F05D6", # alert-circle-outline ) - + @property def never_run(self) -> bool: return True - + def __init__(self, title="Error Handler", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("fn", socket_type="function") - - async def catch(self, state:GraphState, exc:Exception): - - + + async def catch(self, state: GraphState, exc: Exception): log.info("Error caught", error=exc) fn_socket = self.get_input_socket("fn") - + fn_node = fn_socket.source.node - + fn = await fn_node.run(state) - + if not isinstance(fn, FunctionWrapper): log.error(f"fn must be a FunctionWrapper instance, got {fn} instead") return False - + exc_wrapper = ExceptionWrapper( name=exc.__class__.__name__, message=str(exc), ) - + caught = await fn(exc=exc_wrapper) - + log.debug("Error handler result", result=caught) - + return caught - + @register("core/functions/UnpackException") class UnpackException(Node): """ Unpacks an ExceptionWrapper instance into an description and message """ - + def __init__(self, title="Unpack Exception", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("exc", socket_type="exception") self.add_output("name") self.add_output("message") - - async def run(self, state:GraphState): + + async def run(self, state: GraphState): exc = self.get_input_value("exc") - + if not isinstance(exc, ExceptionWrapper): - log.error("Expected ExceptionWrapper instance, got %s instead", type(exc).__name__) + log.error( + "Expected ExceptionWrapper instance, got %s instead", type(exc).__name__ + ) return - - self.set_output_values({ - "name": exc.name, - "message": exc.message, - }) - \ No newline at end of file + + self.set_output_values( + { + "name": exc.name, + "message": exc.message, + } + ) diff --git a/src/talemate/game/engine/nodes/scene.py b/src/talemate/game/engine/nodes/scene.py index a1aa5eeb..9693440f 100644 --- a/src/talemate/game/engine/nodes/scene.py +++ b/src/talemate/game/engine/nodes/scene.py @@ -2,23 +2,21 @@ import structlog from typing import TYPE_CHECKING, ClassVar from .core import ( Loop, - Node, - Entry, - GraphState, - UNRESOLVED, - LoopBreak, - LoopContinue, - NodeVerbosity, - InputValueError, + Node, + GraphState, + UNRESOLVED, + LoopBreak, + LoopContinue, + NodeVerbosity, + InputValueError, PropertyField, Trigger, ) import dataclasses from .registry import register, get_nodes_by_base_type, get_node from .event import connect_listeners, disconnect_listeners -from .command import Command import talemate.events as events -from talemate.emit import wait_for_input, emit +from talemate.emit import wait_for_input from talemate.exceptions import ActedAsCharacter, AbortWaitForInput, GenerationCancelled from talemate.context import active_scene, InteractionState from talemate.instance import get_agent @@ -41,58 +39,60 @@ async_signals.register( "scene_loop_init_after", ) + @dataclasses.dataclass class SceneLoopEvent(events.Event): scene: "Scene" event_type: str + @register("scene/GetSceneState") class GetSceneState(Node): - """ Gets some basic information about the scene - + Outputs: - + - characters: A list of characters in the scene - active: Whether the scene is active - auto_save: Whether auto save is enabled - auto_progress: Whether auto progress is enabled - scene: The scene instance """ - + def __init__(self, title="Get Scene State", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): - # scene state self.add_output("characters", socket_type="list") - + # scene settings self.add_output("active", socket_type="bool") self.add_output("auto_save", socket_type="bool") self.add_output("auto_progress", socket_type="bool") self.add_output("scene", socket_type="scene") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() - self.set_output_values({ - "characters": scene.characters, - "active": scene.active, - "auto_save": scene.auto_save, - "auto_progress": scene.auto_progress, - "scene": scene - }) + scene: "Scene" = active_scene.get() + self.set_output_values( + { + "characters": scene.characters, + "active": scene.active, + "auto_save": scene.auto_save, + "auto_progress": scene.auto_progress, + "scene": scene, + } + ) @register("scene/MakeCharacter") class MakeCharacter(Node): """ Make a character - + Inputs: - + - name: The name of the character - description: The description of the character - color: The color of the character name @@ -100,9 +100,9 @@ class MakeCharacter(Node): - is_player: Whether the character is the player character - add_to_scene: Whether to add the character to the scene - is_active: Whether the character is active - + Properties: - + - name: The name of the character - description: The description of the character - color: The color of the character name @@ -110,66 +110,66 @@ class MakeCharacter(Node): - is_player: Whether the character is the player character - add_to_scene: Whether to add the character to the scene - is_active: Whether the character is active - + Outputs: - + - character: The character object - actor: The actor object """ - + class Fields: name = PropertyField( name="name", description="The name of the character", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + description = PropertyField( name="description", description="The description of the character", type="text", - default="" + default="", ) - + color = PropertyField( name="color", description="The color of the character name", type="color", - default=UNRESOLVED + default=UNRESOLVED, ) - + base_attributes = PropertyField( name="base_attributes", description="The base attributes of the character", type="dict", - default=UNRESOLVED + default=UNRESOLVED, ) - + is_player = PropertyField( name="is_player", description="Whether the character is the player character", type="bool", - default=False + default=False, ) - + add_to_scene = PropertyField( name="add_to_scene", description="Whether to add the character to the scene", type="bool", - default=True + default=True, ) - + is_active = PropertyField( name="is_active", description="Whether the character is active", type="bool", - default=True + default=True, ) - + def __init__(self, title="Make Character", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("name", socket_type="str") self.add_input("description", socket_type="text", optional=True) @@ -178,7 +178,7 @@ class MakeCharacter(Node): self.add_input("is_player", socket_type="bool", optional=True) self.add_input("add_to_scene", socket_type="bool", optional=True) self.add_input("is_active", socket_type="bool", optional=True) - + self.set_property("name", UNRESOLVED) self.set_property("description", "") self.set_property("color", UNRESOLVED) @@ -186,12 +186,12 @@ class MakeCharacter(Node): self.set_property("is_player", False) self.set_property("add_to_scene", True) self.set_property("is_active", True) - + self.add_output("character", socket_type="character") self.add_output("actor", socket_type="actor") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() name = self.require_input("name") description = self.normalized_input_value("description") color = self.normalized_input_value("color") @@ -199,183 +199,180 @@ class MakeCharacter(Node): is_player = self.normalized_input_value("is_player") add_to_scene = self.normalized_input_value("add_to_scene") is_active = self.normalized_input_value("is_active") - + if not color: color = random_color() - + character = scene.Character( name=name, description=description, color=color, base_attributes=base_attributes, - is_player=is_player + is_player=is_player, ) - + if state.verbosity >= NodeVerbosity.VERBOSE: log.debug("Make character", character=character) - - self.set_output_values({ - "character": character - }) - + + self.set_output_values({"character": character}) + if is_player: ActorCls = scene.Player else: ActorCls = scene.Actor - + actor = ActorCls(character, get_agent("conversation")) - + if add_to_scene: await scene.add_actor(actor) if not is_active: await deactivate_character(character) - - self.set_output_values({ - "actor": actor, - "character": character - }) + + self.set_output_values({"actor": actor, "character": character}) + @register("scene/GetCharacter") class GetCharacter(Node): - """ Returns a character object from the scene by name - + Inputs: - + - character_name: The name of the character - + Outputs: - + - character: The character object """ - + def __init__(self, title="Get Character", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("character_name", socket_type="str") self.add_output("character", socket_type="character") - + async def run(self, state: GraphState): character_name = self.get_input_value("character_name") - scene:"Scene" = active_scene.get() - + scene: "Scene" = active_scene.get() + character = scene.get_character(character_name) - - self.set_output_values({ - "character": character - }) + + self.set_output_values({"character": character}) + @register("scene/IsPlayerCharacter") class IsPlayerCharacter(Node): - """ Returns whether a character is the player character - + Inputs: - + - character: The character object - + Outputs: - + - yes: True if the character is the player character - no: True if the character is not the player character - character: The character object """ - + def __init__(self, title="Is Player Character", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("character", socket_type="character") self.add_output("yes", socket_type="bool") self.add_output("no", socket_type="bool") self.add_output("character", socket_type="character") - - + async def run(self, state: GraphState): - character:"Character" = self.get_input_value("character") + character: "Character" = self.get_input_value("character") self.outputs[0].deactivated = not character.is_player self.outputs[1].deactivated = character.is_player - + if state.verbosity >= NodeVerbosity.VERBOSE: - log.debug("Is player character", character=character, is_player=character.is_player) - - self.set_output_values({ - "yes": True if character.is_player else UNRESOLVED, - "no": True if not character.is_player else UNRESOLVED, - "character": character - }) - + log.debug( + "Is player character", + character=character, + is_player=character.is_player, + ) + + self.set_output_values( + { + "yes": True if character.is_player else UNRESOLVED, + "no": True if not character.is_player else UNRESOLVED, + "character": character, + } + ) + + @register("scene/GetPlayerCharacter") class GetPlayerCharacter(Node): - """ Get the main player character from the scene - + Outputs: - + - character: The player character """ - + def __init__(self, title="Get Player Character", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_output("character", socket_type="character") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() character = scene.get_player_character() - - self.set_output_values({ - "character": character - }) + + self.set_output_values({"character": character}) + @register("scene/UpdateCharacterData") class UpdateCharacterData(Node): """ Update the data of a character - + Inputs: - + - character: The character object - base_attributes: The base attributes dictionary - details: The details dictionary - description: The description string - color: The color of the character name - + Outputs: - + - character: The updated character object """ - + class Fields: description = PropertyField( name="description", description="The character description", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + name = PropertyField( name="name", description="The character name", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + color = PropertyField( name="color", description="The color of the character name", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) def __init__(self, title="Update Character Data", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("character", socket_type="character") self.add_input("base_attributes", socket_type="dict", optional=True) @@ -387,15 +384,15 @@ class UpdateCharacterData(Node): self.set_property("name", UNRESOLVED) self.set_property("color", UNRESOLVED) self.add_output("character", socket_type="character") - + async def run(self, state: GraphState): - character:"Character" = self.get_input_value("character") + character: "Character" = self.get_input_value("character") color = self.get_input_value("color") base_attributes = self.get_input_value("base_attributes") details = self.get_input_value("details") description = self.get_input_value("description") name = self.get_input_value("name") - + if self.is_set(base_attributes): character.update(base_attributes=base_attributes) if self.is_set(details): @@ -406,74 +403,74 @@ class UpdateCharacterData(Node): character.rename(name) if self.is_set(color): character.color = color - - self.set_output_values({ - "character": character - }) - - - + + self.set_output_values({"character": character}) + @register("scene/UnpackInteractionState") class UnpackInteractionState(Node): """ Will take an interaction state and unpack it into the individual fields - + Inputs - interaction_state `interaction_state` - + Outputs - act_as `str` - from_choice `str` - input `str` - reset_requested `bool` """ - + def __init__(self, title="Unpack Interaction State", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("interaction_state", socket_type="interaction_state") self.add_output("act_as", socket_type="str") self.add_output("from_choice", socket_type="str") self.add_output("input", socket_type="str") self.add_output("reset_requested", socket_type="bool") - + async def run(self, state: GraphState): - interaction_state:InteractionState = self.get_input_value("interaction_state") - + interaction_state: InteractionState = self.get_input_value("interaction_state") + if not isinstance(interaction_state, InteractionState): - raise InputValueError(self, "interaction_state", "Input is not an InteractionState instance") - - self.set_output_values({ - "act_as": interaction_state.act_as, - "from_choice": interaction_state.from_choice, - "input": interaction_state.input, - "reset_requested": interaction_state.reset_requested - }) - - + raise InputValueError( + self, "interaction_state", "Input is not an InteractionState instance" + ) + + self.set_output_values( + { + "act_as": interaction_state.act_as, + "from_choice": interaction_state.from_choice, + "input": interaction_state.input, + "reset_requested": interaction_state.reset_requested, + } + ) + + @register("scene/message/CharacterMessage") class CharacterMessage(Node): """ Creates a character message from a character and a message - + Inputs: - + - character: The character object - message: The message to send - source: The source of the message - player or ai, so whether the message is result of user input or AI generated - from_choice: For player messages this indicates that the message was generated from a choice selection, for ai sourced messages this indicates the instruction that was followed - + Properties: - + - source: The source of the message - + Outputs: - + - message: The message object (this is a scene_message.CharacterMessage instance) """ - + class Fields: source = PropertyField( name="source", @@ -483,65 +480,62 @@ class CharacterMessage(Node): choices=[ "player", "ai", - ] + ], ) - + def __init__(self, title="Character Message", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("character", socket_type="character") self.add_input("message", socket_type="str") self.add_input("source", socket_type="str", optional=True) self.add_input("from_choice", socket_type="str", optional=True) - + self.set_property("source", "player") - + self.add_output("message", socket_type="message_object") - + async def run(self, state: GraphState): - character:"Character" = self.get_input_value("character") + character: "Character" = self.get_input_value("character") message = self.get_input_value("message") source = self.get_input_value("source") from_choice = self.get_input_value("from_choice") - + extra = {} - + if isinstance(from_choice, str): extra["from_choice"] = from_choice - + # prefix name: if not already prefixed if not message.startswith(f"{character.name}: "): message = f"{character.name}: {message}" - - message = scene_message.CharacterMessage( - message, source=source, **extra - ) - - self.set_output_values({ - "message": message - }) + + message = scene_message.CharacterMessage(message, source=source, **extra) + + self.set_output_values({"message": message}) + @register("scene/message/NarratorMessage") class NarratorMessage(Node): """ Creates a narrator message - + Inputs: - + - message: The message to send - source: The source of the message - player or ai, so whether the message is result of user input or AI generated - meta: A dictionary of meta information to attach to the message. This will generally be arguments and function name that was called on the narrator agent to generate the message and will be used when regenerating the message. - + Properties: - + - source: The source of the message - + Outputs: - + - message: The message object (this is a scene_message.NarratorMessage instance) """ - + class Fields: source = PropertyField( name="source", @@ -551,61 +545,58 @@ class NarratorMessage(Node): choices=[ "player", "ai", - ] + ], ) - + def __init__(self, title="Narrator Message", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("message", socket_type="str") self.add_input("source", socket_type="str", optional=True) self.add_input("meta", socket_type="dict", optional=True) - + self.set_property("source", "ai") - + self.add_output("message", socket_type="message_object") - + async def run(self, state: GraphState): message = self.get_input_value("message") source = self.get_input_value("source") meta = self.get_input_value("meta") - + extra = {} - + if meta and isinstance(meta, dict): extra["meta"] = meta - - message = scene_message.NarratorMessage( - message, source=source, **extra - ) - - self.set_output_values({ - "message": message - }) + + message = scene_message.NarratorMessage(message, source=source, **extra) + + self.set_output_values({"message": message}) + @register("scene/message/DirectorMessage") class DirectorMessage(Node): """ Creates a director message - + Inputs: - + - message: The message to send - source: The source of the message - player or ai, so whether the message is result of user input or AI generated - meta: A dictionary of meta information to attach to the message. Can hold the character name that the message is related to. - character: The character object that the message is related to - + Properties: - + - source: The source of the message - action: Describes the director action - + Outputs: - + - message: The message object (this is a scene_message.DirectorMessage instance) """ - + class Fields: source = PropertyField( name="source", @@ -615,281 +606,288 @@ class DirectorMessage(Node): choices=[ "player", "ai", - ] + ], ) - + action = PropertyField( name="action", description="Describes the director action", type="str", - default="actor_instruction" + default="actor_instruction", ) - + def __init__(self, title="Director Message", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("message", socket_type="str") self.add_input("source", socket_type="str", optional=True) self.add_input("meta", socket_type="dict", optional=True) self.add_input("character", socket_type="character", optional=True) self.add_input("action", socket_type="str", optional=True) - + self.set_property("source", "ai") self.set_property("action", "actor_instruction") - + self.add_output("message", socket_type="message_object") - + async def run(self, state: GraphState): message = self.get_input_value("message") source = self.get_input_value("source") action = self.get_input_value("action") meta = self.get_input_value("meta") - character:"Character" = self.get_input_value("character") - + character: "Character" = self.get_input_value("character") + extra = {} - + if meta and isinstance(meta, dict): extra["meta"] = meta - + message = scene_message.DirectorMessage( message, source=source, action=action, **extra ) - + if character and character is not UNRESOLVED: message.set_meta(character=character.name) - - self.set_output_values({ - "message": message - }) + + self.set_output_values({"message": message}) + @register("scene/message/UnpackMeta") class UnpackMessageMeta(Node): """ Unpacks a message meta dictionary into arguments - + Inputs: - + - meta: The meta dictionary - + Outputs: - + - agent_name: The agent name - function_name: The function name - arguments: The arguments dictionary """ - + def __init__(self, title="Unpack Message Meta", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("meta", socket_type="dict") self.add_output("agent_name", socket_type="str") self.add_output("function_name", socket_type="str") self.add_output("arguments", socket_type="dict") - + async def run(self, state: GraphState): meta = self.get_input_value("meta") - - self.set_output_values({ - "agent_name": meta["agent"], - "function_name": meta["function"], - "arguments": meta.get("arguments", {}).copy() - }) + + self.set_output_values( + { + "agent_name": meta["agent"], + "function_name": meta["function"], + "arguments": meta.get("arguments", {}).copy(), + } + ) @register("scene/message/ToggleMessageContextVisibility") class ToggleMessageContextVisibility(Node): """ Hide or show a message. Hidden messages are not displayed to the AI. - + Inputs: - + - message: The message object - + Properties: - + - hidden: Whether the message is hidden - + Outputs: - + - message: The message object """ - + class Fields: hidden = PropertyField( name="hidden", description="Whether the message is hidden", type="bool", - default=False + default=False, ) - + def __init__(self, title="Toggle Message Context Visibility", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("message", socket_type="message_object") - + self.set_property("hidden", False) - + self.add_output("message", socket_type="message_object") - + async def run(self, state: GraphState): message = self.require_input("message") hidden = self.get_property("hidden") - + if hidden: message.hide() else: message.show() - - self.set_output_values({ - "message": message - }) - + + self.set_output_values({"message": message}) + @register("input/WaitForInput") class WaitForInput(Node): """ Get input from the user to interact with the scene. - + This node will wait for the user to input a message, and then return the message for processing. - + Inputs: - + - state: The current graph state - player_character: The player character - reason: The reason for the input - prefix: The prefix for the input message (similar to a cli prompt) - abort_condition: A condition to abort the input loop - + Properties - + - allow_commands: Allow commands to be executed, using the ! prefix - + Outputs: - + - input: The input message - interaction_state: The interaction state - character: The character object - + Abort Conditions: - + The chain of nodes connected to the abort_condition socket will be executed on each iteration of the input loop. If the chain resolves to a boolean value, the input loop will be aborted. - + You can use this to check for conditions that should abort the input loop. """ - + class Fields: allow_commands = PropertyField( name="allow_commands", description="Allow commands to be executed, using the ! prefix", type="bool", - default=True + default=True, ) prefix = PropertyField( name="prefix", description="The prefix for the input message (similar to a cli prompt)", type="str", - default="" + default="", ) reason = PropertyField( name="reason", description="The reason for the input", type="str", - default="talk" + default="talk", ) - + def __init__(self, title="Get Input", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("player_character", optional=True, socket_type="character") self.add_input("reason", optional=True, socket_type="str") self.add_input("prefix", optional=True, socket_type="str") self.add_input("abort_condition", optional=True, socket_type="any") - + self.set_property("reason", "talk") self.set_property("prefix", "") self.set_property("allow_commands", True) - + self.add_output("input", socket_type="str") self.add_output("interaction_state", socket_type="interaction_state") self.add_output("character", socket_type="character") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() - player_character:"Character" = self.get_input_value("player_character") + scene: "Scene" = active_scene.get() + player_character: "Character" = self.get_input_value("player_character") allow_commands = self.get_property("allow_commands") - + async def _abort_condition() -> bool: """ Logic that checks on whether the node connected to the abort_condition socket - resolves. - + resolves. + Once it does resolve, the input loop will be aborted, and an AbortWaitForInput exception will be propagated to a LoopContinue exception. """ - + socket = self.get_input_socket("abort_condition") - + # nothing connected, so return False if not socket.source: return False - + # run a new state of the graph to get the value of the connected node # this only runs the node connected to the abort_condition socket and any # ascending nodes it depends on - inner_state:GraphState = await state.graph.execute_to_node(socket.source.node, state) - + inner_state: GraphState = await state.graph.execute_to_node( + socket.source.node, state + ) + # get the value of the connected node - rv = inner_state.get_node_socket_value(socket.source.node, socket.source.name) - + rv = inner_state.get_node_socket_value( + socket.source.node, socket.source.name + ) + # if the value is a boolean, return it as is if isinstance(rv, bool): return rv - + # if the value is not None and not UNRESOLVED, return True (abort) - return (rv is not None and rv != UNRESOLVED) - + return rv is not None and rv != UNRESOLVED + # prepare the kwargs for wait_for_input wait_for_input_kwargs = { - "abort_condition": _abort_condition if self.get_input_socket("abort_condition").source else None, + "abort_condition": _abort_condition + if self.get_input_socket("abort_condition").source + else None, } - + # if the verbosity is verbose, set the sleep time to 1 so that the input loop # doesn't spam the console if state.verbosity == NodeVerbosity.VERBOSE: wait_for_input_kwargs["sleep_time"] = 1 - - - is_game_loop = ( - not state.shared.get("creative_mode") or - state.shared.get("nested_scene_loop", False) + + is_game_loop = not state.shared.get("creative_mode") or state.shared.get( + "nested_scene_loop", False ) - + try: if player_character and is_game_loop: - await async_signals.get("player_turn_start").send(events.PlayerTurnStartEvent( - scene=scene, - event_type="player_turn_start", - )) - + await async_signals.get("player_turn_start").send( + events.PlayerTurnStartEvent( + scene=scene, + event_type="player_turn_start", + ) + ) + input = await wait_for_input( self.get_input_value("prefix"), - character=player_character if player_character is not UNRESOLVED else None, + character=player_character + if player_character is not UNRESOLVED + else None, data={"reason": self.get_property("reason")}, return_struct=True, - **wait_for_input_kwargs + **wait_for_input_kwargs, ) except AbortWaitForInput: raise LoopContinue() - + text_message = input["message"] interaction_state = input["interaction"] @@ -898,34 +896,41 @@ class WaitForInput(Node): if not text_message: # input was empty, so continue the loop raise LoopContinue() - + if allow_commands and text_message.startswith("!"): command_state = {} node_cmd_executed = False - await scene.commands.execute(text_message, emit_on_unknown=False, state=command_state) + await scene.commands.execute( + text_message, emit_on_unknown=False, state=command_state + ) # no talemate command was executed, see if a matching node command exists - + talemate_cmd_executed = command_state.get("_commands_executed") - + if not talemate_cmd_executed: node_cmd_executed = await self.execute_node_command(state, text_message) - + if not node_cmd_executed and not talemate_cmd_executed: scene.commands.system_message(f"Unknown command: {text_message}") state.shared["signal_game_loop"] = False state.shared["skip_to_player"] = True raise LoopBreak() - - - log.debug("Wait for input", text_message=text_message, interaction_state=interaction_state) - - self.set_output_values({ - "input": text_message, - "interaction_state": interaction_state, - "character": player_character, - }) - async def execute_node_command(self, state:GraphState, command_name:str) -> bool: + log.debug( + "Wait for input", + text_message=text_message, + interaction_state=interaction_state, + ) + + self.set_output_values( + { + "input": text_message, + "interaction_state": interaction_state, + "character": player_character, + } + ) + + async def execute_node_command(self, state: GraphState, command_name: str) -> bool: """ Get a command node from the scene """ @@ -935,42 +940,46 @@ class WaitForInput(Node): except ValueError: command_name = command_name.strip() arg_str = "" - + args = arg_str.split(";", 1) - + # remove leading and trailing spaces from the command name command_name = command_name.strip() - + # remove ! from the command name command_name = command_name.lstrip("!") - + # get the command node from the scene - registry_name:str | None = state.data["_commands"].get(command_name) - + registry_name: str | None = state.data["_commands"].get(command_name) + if not registry_name: return False - + # turn args into dict with arg_{N} keys args_dict = {f"arg_{i}": arg for i, arg in enumerate(args)} - + command_node = get_node(registry_name) if not command_node: - log.error("Command node not found", command_name=command_name, registry_name=registry_name) + log.error( + "Command node not found", + command_name=command_name, + registry_name=registry_name, + ) return False - + await command_node().execute_command(state, **args_dict) return True + @register("scene/event/trigger/GameLoopActorIter") class TriggerGameLoopActorIter(Trigger): - """ Trigger the game loop actor iteration event. - + In a most basic setup you will trigger this everytime an actor has had a turn. - + Inputs: - + - actor: The actor that has had a turn """ @@ -983,50 +992,50 @@ class TriggerGameLoopActorIter(Trigger): def __init__(self, title="Game Loop Actor Iteration", **kwargs): super().__init__(title=title, **kwargs) - - def make_event_object(self, state:GraphState) -> events.GameLoopActorIterEvent: + + def make_event_object(self, state: GraphState) -> events.GameLoopActorIterEvent: return events.GameLoopActorIterEvent( scene=active_scene.get(), event_type="game_loop_actor_iter", actor=self.get_input_value("actor"), - game_loop=state.shared["game_loop"] + game_loop=state.shared["game_loop"], ) - + def setup_required_inputs(self): super().setup_required_inputs() self.add_input("actor", socket_type="actor") - + def setup_optional_inputs(self): return - + def setup_properties(self): return - - async def after(self, state:GraphState, event:events.GameLoopActorIterEvent): - + + async def after(self, state: GraphState, event: events.GameLoopActorIterEvent): new_event = events.GameLoopCharacterIterEvent( scene=active_scene.get(), event_type="game_loop_player_character_iter", character=event.actor.character, - game_loop=state.shared["game_loop"] + game_loop=state.shared["game_loop"], ) - + if event.actor.character.is_player: await self.signals.get("game_loop_player_character_iter").send(new_event) else: await self.signals.get("game_loop_ai_character_iter").send(new_event) + @register("scene/UnpackCharacter") class UnpackCharacter(Node): """ Unpack a character into its individual fields - + Inputs: - + - character: The character object - + Outputs: - + - name: The name of the character - is_player: Whether the character is the player character - description: The character description @@ -1035,13 +1044,13 @@ class UnpackCharacter(Node): - color: The character name color - actor: The actor instance tied to the character """ - + def __init__(self, title="Unpack Character", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("character", socket_type="character") - + self.add_output("name", socket_type="str") self.add_output("is_player", socket_type="bool") self.add_output("description", socket_type="str") @@ -1049,321 +1058,319 @@ class UnpackCharacter(Node): self.add_output("details", socket_type="dict") self.add_output("color", socket_type="str") self.add_output("actor", socket_type="actor") - + async def run(self, state: GraphState): - character:"Character" = self.get_input_value("character") - - self.set_output_values({ - "name": character.name, - "is_player": character.is_player, - "actor": character.actor, - "description": character.description, - "base_attributes": character.base_attributes, - "details": character.details, - "color": character.color - }) + character: "Character" = self.get_input_value("character") + + self.set_output_values( + { + "name": character.name, + "is_player": character.is_player, + "actor": character.actor, + "description": character.description, + "base_attributes": character.base_attributes, + "details": character.details, + "color": character.color, + } + ) + @register("scene/ActivateCharacter") class ActivateCharacter(Node): """ Activate a character - + Inputs: - + - character: The character to activate - + Outputs: - + - character: The activated character """ - + def __init__(self, title="Activate Character", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("character", socket_type="character") self.add_output("character", socket_type="character") - + async def run(self, state: GraphState): - character:"Character" = self.get_input_value("character") - + character: "Character" = self.get_input_value("character") + await activate_character(active_scene.get(), character) - - self.set_output_values({ - "character": character - }) - + + self.set_output_values({"character": character}) + + @register("scene/DeactivateCharacter") class DeactivateCharacter(Node): """ Deactivate a character - + Inputs: - + - character: The character to deactivate - + Outputs: - + - character: The deactivated character """ - + def __init__(self, title="Deactivate Character", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("character", socket_type="character") self.add_output("character", socket_type="character") - + async def run(self, state: GraphState): - character:"Character" = self.get_input_value("character") - + character: "Character" = self.get_input_value("character") + await deactivate_character(active_scene.get(), character) - - self.set_output_values({ - "character": character - }) - + + self.set_output_values({"character": character}) + + @register("scene/RemoveAllCharacters") class RemoveAllCharacters(Node): """ Remove all characters from the scene - + Inputs: - + - state: The graph state - - Outputs: - + + Outputs: + - state: The graph state """ - + def __init__(self, title="Remove All Characters", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_output("state") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() - + scene: "Scene" = active_scene.get() + characters = list(scene.characters) - + for character in characters: await scene.remove_character(character) - - self.set_output_values({ - "state": state - }) - + + self.set_output_values({"state": state}) + + @register("scene/RemoveCharacter") class RemoveCharacter(Node): """ Remove a character from the scene - + Inputs: - + - state: The graph state - character: The character to remove - + Outputs: - + - state: The graph state """ - + def __init__(self, title="Remove Character", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("character", socket_type="character") self.add_output("state") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() - character:"Character" = self.get_input_value("character") - + scene: "Scene" = active_scene.get() + character: "Character" = self.get_input_value("character") + await scene.remove_character(character) - - self.set_output_values({ - "state": state - }) - - + + self.set_output_values({"state": state}) + + # get the current scene loop state @register("scene/GetSceneLoopState") class GetSceneLoopState(Node): - """ Returns the current scene loop states - + Outputs: - + - state: The current node state, this is the state of the graph currently being processed - parent: The parent node state, this is the state of the graph that contains the current graph - shared: The shared state, this is the state shared between all graphs """ - + def __init__(self, title="Get Scene Loop State", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_output("state", socket_type="dict") self.add_output("parent", socket_type="dict") self.add_output("shared", socket_type="dict") - + async def run(self, state: GraphState): - self.set_output_values({ - "state": state.data, - "parent": state.outer.data if getattr(state, "outer", None) else {}, - "shared": state.shared - }) + self.set_output_values( + { + "state": state.data, + "parent": state.outer.data if getattr(state, "outer", None) else {}, + "shared": state.shared, + } + ) + @register("scene/Restore") class RestoreScene(Node): """ Restore the scene to its resore point - + Inputs: - + - state: The graph state - + Outputs: - + - state: The graph state """ - + def __init__(self, title="Restore Scene", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_output("state") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() await scene.restore() - - self.set_output_values({ - "state": state - }) + + self.set_output_values({"state": state}) + @register("scene/SetIntroduction") class SetIntroduction(Node): """ Set the introduction text for the scene - + Inputs: - + - state: The graph state - introduction: The introduction text - + Properties: - + - introduction: The introduction text - emit_history: Whether to re-emit the entire history of the scene - + Outputs: - + - state: The graph state """ - + class Fields: introduction = PropertyField( name="introduction", description="The introduction text", type="text", - default=UNRESOLVED + default=UNRESOLVED, ) - + emit_history = PropertyField( name="emit_history", description="Whether to re-emit the entire history of the scene", type="bool", - default=True + default=True, ) - + def __init__(self, title="Set Introduction", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("introduction", socket_type="str") self.set_property("introduction", UNRESOLVED) self.set_property("emit_history", True) self.add_output("state") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() introduction = self.require_input("introduction") emit_history = self.get_input_value("emit_history") - + scene.set_intro(introduction) - + if emit_history: await scene.emit_history() - - self.set_output_values({ - "state": self.get_input_value("state") - }) - - - -@register("scene/SceneLoop", as_base_type=True) + + self.set_output_values({"state": self.get_input_value("state")}) + + +@register("scene/SceneLoop", as_base_type=True) class SceneLoop(Loop): - """ The main scene loop node - + It will loop through the scene graph until the loop is broken. - + Properties: - + - trigger_game_loop: Whether to trigger the game loop event """ - + class Fields: trigger_game_loop = PropertyField( name="trigger_game_loop", description="Trigger the game loop event", type="bool", - default=True + default=True, ) - + _export_definition: ClassVar[bool] = False - + @property def scene_loop_event(self) -> SceneLoopEvent: - return SceneLoopEvent( - scene=active_scene.get(), - event_type="scene_loop" - ) - + return SceneLoopEvent(scene=active_scene.get(), event_type="scene_loop") + def __init__(self, title="Scene Loop", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("trigger_game_loop", True) - + async def on_loop_start(self, state: GraphState): - scene:"Scene" = state.outer.data["scene"] + scene: "Scene" = state.outer.data["scene"] await scene.ensure_memory_db() await scene.load_active_pins() - + connect_listeners(self, state, disconnect=True) - + if not state.data.get("_scene_loop_init"): state.data["_scene_loop_init"] = True state.data["_commands"] = {} await self.register_commands(scene, state) await async_signals.get("scene_loop_init").send(self.scene_loop_event) await async_signals.get("scene_loop_init_after").send(self.scene_loop_event) - + trigger_game_loop = self.get_property("trigger_game_loop") - + if state.verbosity >= NodeVerbosity.VERBOSE: - log.debug("TRIGGER GAME LOOP", id=self.id, trigger_game_loop=trigger_game_loop, signal_game_loop=state.shared.get("signal_game_loop"), skip_to_player=state.shared.get("skip_to_player")) - + log.debug( + "TRIGGER GAME LOOP", + id=self.id, + trigger_game_loop=trigger_game_loop, + signal_game_loop=state.shared.get("signal_game_loop"), + skip_to_player=state.shared.get("skip_to_player"), + ) + if trigger_game_loop: game_loop = events.GameLoopEvent( scene=self, event_type="game_loop", had_passive_narration=False @@ -1378,22 +1385,20 @@ class SceneLoop(Loop): _iteration = 0 state.shared["signal_game_loop"] = True - state.shared["scene_loop"] = { - "_iteration": _iteration + 1 - } + state.shared["scene_loop"] = {"_iteration": _iteration + 1} state.shared["creative_mode"] = scene.environment == "creative" - + await async_signals.get("scene_loop_start_cycle").send(self.scene_loop_event) - + async def on_loop_end(self, state: GraphState): - scene:"Scene" = state.outer.data["scene"] + scene: "Scene" = state.outer.data["scene"] if scene.auto_save: await scene.save(auto=True) - + scene.emit_status() - + await async_signals.get("scene_loop_end_cycle").send(self.scene_loop_event) - + async def execute(self, outer_state: GraphState): """ Execute the scene loop @@ -1405,31 +1410,32 @@ class SceneLoop(Loop): disconnect_listeners(self, outer_state) async def on_loop_error(self, state: GraphState, exc: Exception): - scene:"Scene" = state.outer.data["scene"] + scene: "Scene" = state.outer.data["scene"] if isinstance(exc, ActedAsCharacter): state.shared["signal_game_loop"] = False state.shared["acted_as_character"] = scene.get_character(exc.character_name) raise LoopBreak() - + elif isinstance(exc, GenerationCancelled): state.shared["skip_to_player"] = True state.shared["signal_game_loop"] = False raise LoopBreak() - + await async_signals.get("scene_loop_error").send(self.scene_loop_event) - - async def register_commands(self, scene:"Scene", state:GraphState): + + async def register_commands(self, scene: "Scene", state: GraphState): """ Will check the scene._NODE_DEFINITIONS for any command/Command nodes and register them as commands in the scene. - + This is used to register commands that are defined in the scene nodes directory. """ - + for node_cls in get_nodes_by_base_type("command/Command"): - _node = node_cls() + _node = node_cls() command_name = _node.get_property("name") state.data["_commands"][command_name] = _node.registry - log.info(f"Registered command", command=f"!{command_name}", module=_node.registry) - + log.info( + "Registered command", command=f"!{command_name}", module=_node.registry + ) diff --git a/src/talemate/game/engine/nodes/scene_intent.py b/src/talemate/game/engine/nodes/scene_intent.py index 7ddccd6e..5b0fe92d 100644 --- a/src/talemate/game/engine/nodes/scene_intent.py +++ b/src/talemate/game/engine/nodes/scene_intent.py @@ -1,33 +1,20 @@ import structlog -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING from .core import ( - Loop, - Node, - Entry, - GraphState, - UNRESOLVED, - LoopBreak, - LoopContinue, - NodeVerbosity, - InputValueError, + Node, + GraphState, + UNRESOLVED, + InputValueError, PropertyField, - Trigger, - TYPE_CHOICES + TYPE_CHOICES, ) -import dataclasses from .registry import register -from .event import connect_listeners, disconnect_listeners -import talemate.events as events -from talemate.emit import wait_for_input, emit -from talemate.exceptions import ActedAsCharacter, AbortWaitForInput -from talemate.context import active_scene, InteractionState -import talemate.scene_message as scene_message -import talemate.emit.async_signals as async_signals -from talemate.scene.schema import SceneIntent, ScenePhase, SceneType +from talemate.context import active_scene +from talemate.scene.schema import ScenePhase, SceneType from talemate.scene.intent import set_scene_phase if TYPE_CHECKING: - from talemate.tale_mate import Scene, Character + from talemate.tale_mate import Scene __all__ = [ "GetSceneIntent", @@ -40,20 +27,22 @@ __all__ = [ log = structlog.get_logger("talemate.game.engine.nodes.scene_intent") -TYPE_CHOICES.extend([ - "scene_intent/scene_intent", - "scene_intent/scene_phase", - "scene_intent/scene_type", -]) +TYPE_CHOICES.extend( + [ + "scene_intent/scene_intent", + "scene_intent/scene_phase", + "scene_intent/scene_type", + ] +) + @register("scene/intention/GetSceneIntent") class GetSceneIntent(Node): - """ Returns the intent state. - + Outputs: - + - intent (str) - the overall intent - phase (scene_intent/scene_phase) - the current phase - scene_type (scene_intent/scene_type) - the current scene type @@ -62,45 +51,46 @@ class GetSceneIntent(Node): def __init__(self, title="Get Scene Intent", **kwargs): super().__init__(title=title, **kwargs) - - + def setup(self): - self.add_output("intent", socket_type="str") self.add_output("phase", socket_type="scene_intent/scene_phase") self.add_output("scene_type", socket_type="scene_intent/scene_type") self.add_output("start", socket_type="int") - + async def run(self, state: GraphState): - - scene:"Scene" = active_scene.get() - - self.set_output_values({ - "intent": scene.intent_state.intent, - }) - - phase:ScenePhase = scene.intent_state.phase - + scene: "Scene" = active_scene.get() + + self.set_output_values( + { + "intent": scene.intent_state.intent, + } + ) + + phase: ScenePhase = scene.intent_state.phase + if phase: - self.set_output_values({ - "phase": phase, - "scene_type": scene.intent_state.current_scene_type, - "start": scene.intent_state.start, - }) - + self.set_output_values( + { + "phase": phase, + "scene_type": scene.intent_state.current_scene_type, + "start": scene.intent_state.start, + } + ) + + @register("scene/intention/SetSceneIntent") class SetSceneIntent(Node): - """ Updates the overall intent. - + Inputs: - + - state - graph state - intent (str) - the overall intent - + Outputs: - + - state - graph state - intent (str) - the overall intent """ @@ -115,45 +105,44 @@ class SetSceneIntent(Node): def __init__(self, title="Set Scene Intent", **kwargs): super().__init__(title=title, **kwargs) - - + def setup(self): self.add_input("state") self.add_input("intent", socket_type="str", optional=True) - + self.set_property("intent", "") - + self.add_output("state") self.add_output("intent", socket_type="str") - - + async def run(self, state: GraphState): - - scene:"Scene" = active_scene.get() - + scene: "Scene" = active_scene.get() + intent = self.get_input_value("intent") - + scene.intent_state.intent = intent - - self.set_output_values({ - "state": scene.intent_state, - "intent": intent, - }) - + + self.set_output_values( + { + "state": scene.intent_state, + "intent": intent, + } + ) + + @register("scene/intention/SetScenePhase") class SetScenePhase(Node): - """ Set a new scene phase. - + Inputs: - + - state - graph state - scene_type (str) - the type of scene (scene type id) - intent (str) - the phase intent - + Outputs: - + - state - graph state - phase (scene_intent/scene_phase) - the new phase - scene_type (scene_intent/scene_type) - the scene type of the new phase (object) @@ -175,46 +164,46 @@ class SetScenePhase(Node): def __init__(self, title="Set Scene Phase", **kwargs): super().__init__(title=title, **kwargs) - - + def setup(self): self.add_input("state") self.add_input("scene_type", socket_type="str", optional=True) self.add_input("intent", socket_type="str", optional=True) - + self.set_property("scene_type", "") self.set_property("intent", "") - + self.add_output("state") self.add_output("phase", socket_type="scene_intent/scene_phase") self.add_output("scene_type", socket_type="scene_intent/scene_type") - + async def run(self, state: GraphState): - - scene:"Scene" = active_scene.get() - + scene: "Scene" = active_scene.get() + scene_type = self.get_input_value("scene_type") intent = self.get_input_value("intent") - + phase = await set_scene_phase(scene, scene_type, intent) - + scene.emit_status() - self.set_output_values({ - "state": scene.intent_state, - "phase": phase, - "scene_type": scene.intent_state.current_scene_type, - }) - + self.set_output_values( + { + "state": scene.intent_state, + "phase": phase, + "scene_type": scene.intent_state.current_scene_type, + } + ) + + @register("scene/intention/UnpackScenePhase") class UnpackScenePhase(Node): - """ Inputs: - + - phhase (scene_intent/scene_phase) - + Outputs - + - intent - scene_type - scene_type_instructions @@ -222,52 +211,52 @@ class UnpackScenePhase(Node): - scene_type_name - scene_type_id """ - + def __init__(self, title="Unpack Scene Phase", **kwargs): super().__init__(title=title, **kwargs) - - + def setup(self): self.add_input("phase", socket_type="scene_intent/scene_phase") - + self.add_output("intent", socket_type="str") self.add_output("scene_type", socket_type="str") self.add_output("scene_type_instructions", socket_type="str") self.add_output("scene_type_description", socket_type="str") self.add_output("scene_type_name", socket_type="str") self.add_output("scene_type_id", socket_type="str") - + async def run(self, state: GraphState): - - scene:"Scene" = active_scene.get() - phase:ScenePhase = self.get_input_value("phase") - - scene_type:SceneType = scene.intent_state.scene_types[phase.scene_type] - - self.set_output_values({ - "intent": phase.intent, - "scene_type": phase.scene_type, - "scene_type_instructions": scene_type.instructions, - "scene_type_description": scene_type.description, - "scene_type_name": scene_type.name, - "scene_type_id": scene_type.id, - }) - + scene: "Scene" = active_scene.get() + phase: ScenePhase = self.get_input_value("phase") + + scene_type: SceneType = scene.intent_state.scene_types[phase.scene_type] + + self.set_output_values( + { + "intent": phase.intent, + "scene_type": phase.scene_type, + "scene_type_instructions": scene_type.instructions, + "scene_type_description": scene_type.description, + "scene_type_name": scene_type.name, + "scene_type_id": scene_type.id, + } + ) + + @register("scene/intention/MakeSceneType") class MakeSceneType(Node): - """ Create a new scene type object. - + Inputs: - + - id (str) - scene type ID - name (str) - scene type name - description (text) - scene type description - instructions (text) - scene type instructions - + Outputs: - + - scene_type (scene_intent/scene_type) - the new scene type object """ @@ -305,150 +294,150 @@ class MakeSceneType(Node): def __init__(self, title="Make Scene Type", **kwargs): super().__init__(title=title, **kwargs) - - + def setup(self): - self.add_input("scene_type_id", socket_type="str", optional=True) self.add_input("name", socket_type="str", optional=True) self.add_input("description", socket_type="text", optional=True) self.add_input("instructions", socket_type="text", optional=True) - + self.set_property("scene_type_id", UNRESOLVED) self.set_property("name", "") self.set_property("description", "") self.set_property("instructions", "") self.set_property("auto_append", True) - - + self.add_output("scene_type", socket_type="scene_intent/scene_type") - + async def run(self, state: GraphState): - auto_append = self.normalized_input_value("auto_append") - scene:"Scene" = active_scene.get() - + scene: "Scene" = active_scene.get() + scene_type = SceneType( id=self.require_input("scene_type_id"), name=self.require_input("name"), description=self.normalized_input_value("description"), instructions=self.normalized_input_value("instructions"), ) - + if auto_append: scene.intent_state.scene_types[scene_type.id] = scene_type - - self.set_output_values({ - "scene_type": scene_type, - }) - + + self.set_output_values( + { + "scene_type": scene_type, + } + ) + + @register("scene/intention/GetSceneType") class GetSceneType(Node): - """ Get a scene type object. - + Inputs: - + - id (str) - scene type ID - + Outputs: - + - scene_type (scene_intent/scene_type) - the scene type object """ - + def __init__(self, title="Get Scene Type", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("scene_type_id", socket_type="str") self.add_output("scene_type", socket_type="scene_intent/scene_type") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() scene_type_id = self.require_input("scene_type_id") - + try: scene_type = scene.intent_state.scene_types[scene_type_id] self.set_output_values({"scene_type": scene_type}) except KeyError: - raise InputValueError(self, "scene_type_id", f"Scene type not found: {scene_type_id}") - + raise InputValueError( + self, "scene_type_id", f"Scene type not found: {scene_type_id}" + ) + + @register("scene/intention/UnpackSceneType") class UnpackSceneType(Node): - """ Unpack a scene type object. - + Inputs: - + - scene_type (scene_intent/scene_type) - the scene type object - + Outputs: - + - id (str) - scene type ID - name (str) - scene type name - description (text) - scene type description - instructions (text) - scene type instructions """ - + def __init__(self, title="Unpack Scene Type", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("scene_type", socket_type="scene_intent/scene_type") - + self.add_output("scene_type_id", socket_type="str") self.add_output("name", socket_type="str") self.add_output("description", socket_type="str") self.add_output("instructions", socket_type="str") - + async def run(self, state: GraphState): - scene_type:SceneType = self.get_input_value("scene_type") - - self.set_output_values({ - "scene_type_id": scene_type.id, - "name": scene_type.name, - "description": scene_type.description, - "instructions": scene_type.instructions, - }) - + scene_type: SceneType = self.get_input_value("scene_type") + + self.set_output_values( + { + "scene_type_id": scene_type.id, + "name": scene_type.name, + "description": scene_type.description, + "instructions": scene_type.instructions, + } + ) + + @register("scene/intention/RemoveSceneType") class RemoveSceneType(Node): - """ Remove a scene type object. - + Inputs: - + - state - graph state - id (str) - scene type ID - + Outputs: - + - state - graph state """ - + def __init__(self, title="Remove Scene Type", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("scene_type_id", socket_type="str") - + self.add_output("state") - + async def run(self, state: GraphState): - scene:"Scene" = active_scene.get() - + scene: "Scene" = active_scene.get() + scene_type_id = self.require_input("scene_type_id") - + scene.intent_state.scene_types.pop(scene_type_id, None) - - self.set_output_values({ - "state": scene.intent_state, - }) - - - - \ No newline at end of file + + self.set_output_values( + { + "state": scene.intent_state, + } + ) diff --git a/src/talemate/game/engine/nodes/state.py b/src/talemate/game/engine/nodes/state.py index 81d31b04..252db369 100644 --- a/src/talemate/game/engine/nodes/state.py +++ b/src/talemate/game/engine/nodes/state.py @@ -1,4 +1,4 @@ -from typing import ClassVar, TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any import structlog import pydantic from talemate.context import active_scene @@ -13,12 +13,13 @@ from talemate.game.engine.nodes.core import ( PropertyField, InputValueError, ) + if TYPE_CHECKING: from talemate.tale_mate import Scene log = structlog.get_logger("talemate.game.engine.nodes.state") -def coerce_to_type(value:Any, type_name:str): +def coerce_to_type(value: Any, type_name: str): if type_name == "str": return str(value) elif type_name == "number": @@ -30,41 +31,41 @@ def coerce_to_type(value:Any, type_name:str): else: raise ValueError(f"Cannot coerce value to type {type_name}") + class StateManipulation(Node): - """ Base class for state manipulation nodes """ - + class Fields: scope = PropertyField( name="scope", description="Which scope to manipulate", type="str", default="local", - choices=["local", "parent", "shared", "scene loop", "game"] + choices=["local", "parent", "shared", "scene loop", "game"], ) name = PropertyField( name="name", description="The name of the variable to manipulate", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) def setup(self): self.add_input("name", socket_type="str", optional=True) - self.set_property("name", UNRESOLVED) + self.set_property("name", UNRESOLVED) self.set_property("scope", "local") - + self.add_output("name", socket_type="str") self.add_output("value") self.add_output("scope", socket_type="str") def get_state_container(self, state: GraphState): scope = self.get_property("scope") - + if scope == "local": return state.data elif scope == "parent": @@ -78,94 +79,87 @@ class StateManipulation(Node): container = {} return container elif scope == "game": - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() return scene.game_state else: raise InputValueError(self, "scope", f"Unknown scope: {scope}") - - - + @register("state/SetState") class SetState(StateManipulation): """ Set a variable in the graph state - + Inputs: - + - name: the name to set - value: the value to set - scope: which scope to set the variable in - + Outputs: - + - value: the value that was set - name: the name that was set - scope: the scope that was set """ - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( title_color="#2e4657", - icon="F01DA", # upload - auto_title="SET {scope}.{name}" + icon="F01DA", # upload + auto_title="SET {scope}.{name}", ) - + def __init__(self, title="Set State", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): super().setup() self.add_input("value") - + async def run(self, state: GraphState): name = self.require_input("name") value = self.require_input("value", none_is_set=True) scope = self.require_input("scope") - + if state.verbosity >= NodeVerbosity.VERBOSE: log.debug("Setting state variable", name=name, value=value, scope=scope) - - container = self.get_state_container(state) - - container[name] = value - - self.set_output_values({ - "name": name, - "value": value, - "scope": scope - }) - + container = self.get_state_container(state) + + container[name] = value + + self.set_output_values({"name": name, "value": value, "scope": scope}) + @register("state/GetState") class GetState(StateManipulation): """ Get a variable from the graph state - + Inputs: - + - name: the name to get - scope: which scope to get the variable from - + Outputs: - + - value: the value that was retrieved - name: the name that was retrieved - scope: the scope that was retrieved """ - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( title_color="#44552f", - icon="F0552", # download - auto_title="GET {scope}.{name}" + icon="F0552", # download + auto_title="GET {scope}.{name}", ) - + def __init__(self, title="Get State", **kwargs): super().__init__(title=title, **kwargs) @@ -176,122 +170,112 @@ class GetState(StateManipulation): async def run(self, state: GraphState): name = self.require_input("name") scope = self.require_input("scope") - + default = self.get_input_value("default") - + if default is UNRESOLVED: default = None - + container = self.get_state_container(state) - + value = container.get(name, default) - - self.set_output_values({ - "name": name, - "value": value, - "scope": scope - }) - + + self.set_output_values({"name": name, "value": value, "scope": scope}) + + @register("state/UnsetState") class UnsetState(StateManipulation): """ Unset a variable in the graph state - + Inputs: - + - name: the name to unset - scope: which scope to unset the variable in - + Outputs: - + - name: the name that was unset - scope: the scope that was unset - value: the value that was unset """ - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( title_color="#7f2e2e", - icon="F0683", # delete-circle - auto_title="UNSET {scope}.{name}" + icon="F0683", # delete-circle + auto_title="UNSET {scope}.{name}", ) - - + def __init__(self, title="Unset State", **kwargs): super().__init__(title=title, **kwargs) - + async def run(self, state: GraphState): name = self.require_input("name") scope = self.require_input("scope") - + container = self.get_state_container(state) - + value = container.pop(name, None) - - self.set_output_values({ - "name": name, - "value": value, - "scope": scope - }) - + + self.set_output_values({"name": name, "value": value, "scope": scope}) + + @register("state/HasState") class HasState(StateManipulation): """ Check if a variable exists in the graph state - + Inputs: - + - name: the name to check - scope: which scope to check the variable in - + Outputs: - + - name: the name that was checked - scope: the scope that was checked - exists: whether the variable exists (True) or not (False) """ - + def __init__(self, title="Has State", **kwargs): super().__init__(title=title, **kwargs) - + async def run(self, state: GraphState): name = self.require_input("name") scope = self.require_input("scope") - + container = self.get_state_container(state) - + exists = name in container - - self.set_output_values({ - "name": name, - "scope": scope, - "exists": exists - }) + + self.set_output_values({"name": name, "scope": scope, "exists": exists}) + @register("state/CounterState") class CounterState(StateManipulation): """ Counter node that increments a numeric value in the state and returns the new value. - + Inputs: - name: The key to the value to increment - scope: Which scope to use for the counter - reset: If true, the value will be reset to 0 - + Properties: - increment: The amount to increment the value by - name: The key to the value to increment - scope: Which scope to use for the counter - reset: If true, the value will be reset to 0 - + Outputs: - value: The new value - name: The key that was used - scope: The scope that was used """ - + class Fields(StateManipulation.Fields): increment = PropertyField( name="increment", @@ -299,122 +283,115 @@ class CounterState(StateManipulation): default=1, step=1, min=1, - description="The amount to increment the value by" + description="The amount to increment the value by", ) - + reset = PropertyField( name="reset", type="bool", default=False, - description="If true, the value will be reset to 0" + description="If true, the value will be reset to 0", ) - + @pydantic.computed_field(description="Node style") @property def style(self) -> NodeStyle: return NodeStyle( title_color="#2e4657", - icon="F0199", # counter - auto_title="COUNT {scope}.{name}" + icon="F0199", # counter + auto_title="COUNT {scope}.{name}", ) - + def __init__(self, title="State Counter", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_output("state") super().setup() self.add_input("reset", socket_type="bool", optional=True) - + self.set_property("increment", 1) self.set_property("reset", False) - + self.add_output("value") - + async def run(self, state: GraphState): name = self.require_input("name") scope = self.require_input("scope") reset = self.normalized_input_value("reset", bool) increment = self.get_input_value("increment") - + container = self.get_state_container(state) - + if reset: container[name] = 0 else: container[name] = container.get(name, 0) + increment - - self.set_output_values({ - "state": state, - "value": container[name], - "name": name, - "scope": scope - }) - - + + self.set_output_values( + {"state": state, "value": container[name], "name": name, "scope": scope} + ) + + @register("state/ConditionalSetState") class ConditionalSetState(SetState): """ Set a variable in the graph state - + Provides a required `state` input causing the node to only run when a state is provided """ - + def __init__(self, title="Set State (Conditional)", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_output("state") super().setup() - + async def run(self, state: GraphState): await super().run(state) - self.set_output_values({ - "state": self.get_input_value("state") - }) - + self.set_output_values({"state": self.get_input_value("state")}) + + @register("state/ConditionalUnsetState") class ConditionalUnsetState(UnsetState): """ Unset a variable in the graph state - + Provides a required `state` input causing the node to only run when a state is provided """ - + def __init__(self, title="Unset State (Conditional)", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_output("state") super().setup() - + async def run(self, state: GraphState): await super().run(state) - self.set_output_values({ - "state": self.get_input_value("state") - }) - + self.set_output_values({"state": self.get_input_value("state")}) + + @register("state/ConditionalCounterState") class ConditionalCounterState(CounterState): """ Counter node that increments a numeric value in the state and returns the new value. - + Provides a required `state` input causing the node to only run when a state is provided """ - + def __init__(self, title="Counter State (Conditional)", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_output("state") super().setup() - + async def run(self, state: GraphState): await super().run(state) - self.set_output_values({ - "state": self.get_input_value("state") - }) + self.set_output_values({"state": self.get_input_value("state")}) diff --git a/src/talemate/game/engine/nodes/string.py b/src/talemate/game/engine/nodes/string.py index 5adeb0aa..bf5193c3 100644 --- a/src/talemate/game/engine/nodes/string.py +++ b/src/talemate/game/engine/nodes/string.py @@ -1,290 +1,296 @@ -import re import structlog -from .core import Node, GraphState, UNRESOLVED, PropertyField, InputValueError +from .core import Node, GraphState, PropertyField, InputValueError from .registry import register log = structlog.get_logger("talemate.game.engine.nodes.string") + @register("data/string/Make") class MakeString(Node): """Creates a string - + Creates a string with the specified value. - + Properties: - + - value: The string value to create - + Outputs: - + - value: The created string value """ - + class Fields: value = PropertyField( name="value", description="The string value to create", type="str", - default="" + default="", ) - + def __init__(self, title="Make String", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.set_property("value", "") self.add_output("value", socket_type="str") - + async def run(self, state: GraphState): value = self.get_property("value") self.set_output_values({"value": value}) + @register("data/string/MakeText") class MakeText(MakeString): """ Same as make string but will be rendered with a multiline text editor - + Properties: - + - value: The string value to create - + Outputs: - + - value: The created string value """ + class Fields: value = PropertyField( name="value", description="The string value to create", type="text", - default="" + default="", ) - + def __init__(self, title="Make Text", **kwargs): super().__init__(title=title, **kwargs) - + @register("data/string/Split") class Split(Node): """Splits a string into a list based on a delimiter - + Divides a string into multiple parts using a specified delimiter. - + Inputs: - + - string: The string to split - delimiter: Character(s) to use as the split point (optional) - + Properties: - + - delimiter: Default delimiter to use when not provided via input - max_splits: Maximum number of splits to perform (-1 for all possible splits) - + Outputs: - + - parts: List of string parts after splitting """ - + class Fields: delimiter = PropertyField( name="delimiter", description="Character(s) to use as the split point", type="str", - default=" " + default=" ", ) max_splits = PropertyField( name="max_splits", description="Maximum number of splits to perform (-1 for all possible splits)", type="int", - default=-1 + default=-1, ) - + def setup(self): self.add_input("string", socket_type="str") self.add_input("delimiter", socket_type="str", optional=True) self.add_output("parts", socket_type="list") - + self.set_property("delimiter", " ") self.set_property("max_splits", -1) - + async def run(self, state: GraphState): - string:str = self.get_input_value("string") + string: str = self.get_input_value("string") delimiter = self.get_input_value("delimiter") max_splits = self.get_property("max_splits") - + # handle escaped newline delimiter if delimiter == "\\n": delimiter = "\n" - + parts = string.split(delimiter, maxsplit=max_splits) self.set_output_values({"parts": parts}) + @register("data/string/Join") class Join(Node): """Joins a list of strings with a delimiter - + Combines a list of strings into a single string with a specified delimiter between each element. - + Inputs: - + - strings: List of strings to join - delimiter: Character(s) to insert between each string (optional) - + Properties: - + - delimiter: Default delimiter to use when not provided via input - + Outputs: - + - result: The joined string """ - + class Fields: delimiter = PropertyField( name="delimiter", description="Character(s) to insert between each string", type="str", - default=" " + default=" ", ) - + def setup(self): self.add_input("strings", socket_type="list") self.add_input("delimiter", socket_type="str", optional=True) self.add_output("result", socket_type="str") - + self.set_property("delimiter", " ") - + async def run(self, state: GraphState): strings = self.get_input_value("strings") delimiter = self.get_input_value("delimiter") - + # handle escaped newline delimiter if delimiter == "\\n": delimiter = "\n" - + if not all(isinstance(s, str) for s in strings): raise InputValueError(self, "strings", "All items must be strings") - + result = delimiter.join(strings) self.set_output_values({"result": result}) + @register("data/string/Replace") class Replace(Node): """Replaces occurrences of a substring with another - + Searches for all occurrences of a substring and replaces them with a new string. - + Inputs: - + - string: The original string - old: Substring to find and replace - new: Replacement string - + Properties: - + - count: Maximum number of replacements to make (-1 for all occurrences) - + Outputs: - + - result: The string after replacements """ - + class Fields: count = PropertyField( name="count", description="Maximum number of replacements to make (-1 for all occurrences)", type="int", - default=-1 + default=-1, ) - + def setup(self): self.add_input("string", socket_type="str") self.add_input("old", socket_type="str") self.add_input("new", socket_type="str") self.add_output("result", socket_type="str") - + self.set_property("count", -1) # -1 means replace all - + async def run(self, state: GraphState): string = self.get_input_value("string") old = self.get_input_value("old") new = self.get_input_value("new") count = self.get_property("count") - + result = string.replace(old, new, count) self.set_output_values({"result": result}) + @register("data/string/Format") class Format(Node): """Python-style string formatting with variables - + Formats a template string by replacing placeholders with values from a variables dictionary, using Python's format() string method. - + Inputs: - + - template: A format string with placeholders (e.g., "Hello, {name}") - variables: Dictionary of variable names and values to insert - + Outputs: - + - result: The formatted string """ - + def setup(self): self.add_input("template", socket_type="str") self.add_input("variables", socket_type="dict") self.add_output("result", socket_type="str") - + async def run(self, state: GraphState): template = self.get_input_value("template") variables = self.get_input_value("variables") - + try: result = template.format(**variables) self.set_output_values({"result": result}) except (KeyError, ValueError) as e: raise InputValueError(self, "variables", f"Format error: {str(e)}") + @register("data/string/Case") class Case(Node): """Changes string case (upper, lower, title, capitalize) - + Converts a string to a different case format, such as uppercase, lowercase, title case, or capitalized. - + Inputs: - + - string: The string to transform - + Properties: - + - operation: Case operation to perform (upper, lower, title, capitalize) - + Outputs: - + - result: The transformed string """ - + class Fields: operation = PropertyField( name="operation", description="Case operation to perform", type="str", default="lower", - choices=["upper", "lower", "title", "capitalize"] + choices=["upper", "lower", "title", "capitalize"], ) - + def setup(self): self.add_input("string", socket_type="str") self.add_output("result", socket_type="str") - + self.set_property("operation", "lower") - + async def run(self, state: GraphState): string = self.get_input_value("string") operation = self.get_property("operation") - + if operation == "upper": result = string.upper() elif operation == "lower": @@ -293,95 +299,94 @@ class Case(Node): result = string.title() elif operation == "capitalize": result = string.capitalize() - + self.set_output_values({"result": result}) + @register("data/string/Trim") class Trim(Node): """Removes characters from start/end of string - + Removes specified characters from the beginning, end, or both ends of a string. By default, it removes whitespace if no specific characters are provided. - + Inputs: - + - string: The string to trim - chars: Character(s) to remove (optional, defaults to whitespace) - + Properties: - + - mode: Where to trim from (left, right, both) - chars: Default characters to remove when not provided via input - + Outputs: - + - result: The trimmed string """ - + class Fields: mode = PropertyField( name="mode", description="Trim mode", type="str", default="both", - choices=["left", "right", "both"] + choices=["left", "right", "both"], ) - + chars = PropertyField( - name="chars", - description="Character(s) to remove", - type="str", - default=None + name="chars", description="Character(s) to remove", type="str", default=None ) - + def setup(self): self.add_input("string", socket_type="str") self.add_input("chars", socket_type="str", optional=True) self.add_output("result", socket_type="str") - + self.set_property("mode", "both") self.set_property("chars", None) # None means whitespace - + async def run(self, state: GraphState): string = self.get_input_value("string") chars = self.get_input_value("chars") mode = self.get_property("mode") - + # handle escaped newline chars if chars and "\\n" in chars: chars = chars.replace("\\n", "\n") - + if mode == "left": result = string.lstrip(chars) elif mode == "right": result = string.rstrip(chars) else: result = string.strip(chars) - + self.set_output_values({"result": result}) + @register("data/string/Substring") class Substring(Node): """Extracts a portion of a string using indices - + Extracts a substring from the original string using start and end indices. - + Inputs: - + - string: The source string - start: Starting index (optional) - end: Ending index (optional) - + Properties: - + - start: Default starting index (0-based) - end: Default ending index (None means until the end of the string) - + Outputs: - + - result: The extracted substring """ - + class Fields: start = PropertyField( name="start", @@ -389,183 +394,178 @@ class Substring(Node): type="int", default=0, min=0, - step=1 + step=1, ) - + end = PropertyField( name="end", description="Ending index", type="int", default=None, min=0, - step=1 + step=1, ) - + def setup(self): self.add_input("string", socket_type="str") self.add_input("start", socket_type="int", optional=True) self.add_input("end", socket_type="int", optional=True) self.add_output("result", socket_type="str") - + self.set_property("start", 0) self.set_property("end", None) - + async def run(self, state: GraphState): string = self.get_input_value("string") start = self.get_input_value("start") end = self.get_input_value("end") - + result = string[start:end] self.set_output_values({"result": result}) + @register("data/string/Extract") class Extract(Node): """ Extracts a portion of a string using a left and right anchor - + Whatever is between the left and right anchors will be extracted. - + The first occurrence of the left anchor will be used. - + Inputs: - + - string: The string to extract from - left_anchor: The left anchor - right_anchor: The right anchor - + Properties: - + - left_anchor: The left anchor - right_anchor: The right anchor - trim: Whether to trim the result - + Outputs: - + - result: The extracted substring """ - + class Fields: left_anchor = PropertyField( - name="left_anchor", - description="The left anchor", - type="str", - default="" + name="left_anchor", description="The left anchor", type="str", default="" ) right_anchor = PropertyField( - name="right_anchor", - description="The right anchor", - type="str", - default="" + name="right_anchor", description="The right anchor", type="str", default="" ) trim = PropertyField( name="trim", description="Whether to trim the result", type="bool", - default=True + default=True, ) - + def __init__(self, title="Extract", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("string", socket_type="str") self.add_input("left_anchor", socket_type="str", optional=True) self.add_input("right_anchor", socket_type="str", optional=True) - + self.set_property("left_anchor", "") self.set_property("right_anchor", "") self.set_property("trim", True) - + self.add_output("result", socket_type="str") - + async def run(self, state: GraphState): string = self.get_input_value("string") left_anchor = self.normalized_input_value("left_anchor") or "" right_anchor = self.normalized_input_value("right_anchor") or "" trim = self.normalized_input_value("trim") - + parts = string.split(left_anchor, 1) if len(parts) > 1: result = parts[1].split(right_anchor, 1)[0] else: result = "" - + if trim: result = result.strip() - + self.set_output_values({"result": result}) - + @register("data/string/StringCheck") class StringCheck(Node): """Checks if a string starts with, ends with, or contains a substring - + Tests whether a string starts with, ends with, contains, or exactly equals a substring, with optional case sensitivity. - + Inputs: - + - string: The string to check - substring: The substring to look for - + Properties: - + - mode: Check operation to perform (startswith, endswith, contains, exact) - case_sensitive: Whether the check should be case-sensitive - substring: Default substring to check for when not provided via input - + Outputs: - + - result: Boolean result of the check """ - + class Fields: mode = PropertyField( name="mode", description="Check operation to perform", type="str", default="contains", - choices=["startswith", "endswith", "contains", "exact"] + choices=["startswith", "endswith", "contains", "exact"], ) case_sensitive = PropertyField( name="case_sensitive", description="Whether the check should be case-sensitive", type="bool", - default=True + default=True, ) substring = PropertyField( name="substring", description="Default substring to check for", type="str", - default="" + default="", ) - + def __init__(self, title="String Check", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("string", socket_type="str") self.add_input("substring", socket_type="str") self.add_output("result", socket_type="bool") - + self.set_property("substring", "") self.set_property("mode", "contains") self.set_property("case_sensitive", True) - + async def run(self, state: GraphState): string = self.get_input_value("string") substring = self.get_input_value("substring") mode = self.get_property("mode") case_sensitive = self.get_property("case_sensitive") - + if not string: self.set_output_values({"result": False}) return - + if not case_sensitive: string = string.lower() substring = substring.lower() - + if mode == "startswith": result = string.startswith(substring) elif mode == "endswith": @@ -574,5 +574,5 @@ class StringCheck(Node): result = string == substring else: # contains result = substring in string - - self.set_output_values({"result": result}) \ No newline at end of file + + self.set_output_values({"result": result}) diff --git a/src/talemate/game/engine/nodes/util.py b/src/talemate/game/engine/nodes/util.py index 21ac9ee6..4091f9dd 100644 --- a/src/talemate/game/engine/nodes/util.py +++ b/src/talemate/game/engine/nodes/util.py @@ -7,28 +7,29 @@ from .core import ( PropertyField, ) + @register("util/Counter") class Counter(Node): """ - Counter node that increments a numeric value inside a + Counter node that increments a numeric value inside a dict and returns the new value. - + Inputs: - state: The graph state - dict: The dict containing the value to increment - key: The key to the value to increment - reset: If true, the value will be reset to 0 - + Properties: - increment: The amount to increment the value by - key: The key to the value to increment - reset: If true, the value will be reset to 0 - + Outputs: - value: The new value - dict: The dict with the new value """ - + class Fields: increment = PropertyField( name="increment", @@ -36,55 +37,51 @@ class Counter(Node): default=1, step=1, min=1, - description="The amount to increment the value by" + description="The amount to increment the value by", ) - + key = PropertyField( name="key", type="str", default="counter", - description="The key to the value to increment" + description="The key to the value to increment", ) - + reset = PropertyField( name="reset", type="bool", default=False, - description="If true, the value will be reset to 0" + description="If true, the value will be reset to 0", ) - + def __init__(self, title="Counter", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_input("dict", socket_type="dict") self.add_input("key", socket_type="str", optional=True) self.add_input("reset", socket_type="bool", optional=True) - + self.set_property("increment", 1) self.set_property("key", "counter") self.set_property("reset", False) - + self.add_output("value") self.add_output("dict", socket_type="dict") - + async def run(self, state: GraphState): dict_ = self.get_input_value("dict") key = self.get_input_value("key") reset = self.get_input_value("reset") increment = self.get_property("increment") - + if increment is UNRESOLVED: raise InputValueError(self, "increment", "Increment value is required") - + if reset: dict_[key] = 0 else: dict_[key] = dict_.get(key, 0) + increment - - self.set_output_values({ - "value": dict_[key], - "dict": dict_ - }) - \ No newline at end of file + + self.set_output_values({"value": dict_[key], "dict": dict_}) diff --git a/src/talemate/game/engine/nodes/world_state.py b/src/talemate/game/engine/nodes/world_state.py index bb2c5b43..95086598 100644 --- a/src/talemate/game/engine/nodes/world_state.py +++ b/src/talemate/game/engine/nodes/world_state.py @@ -1,12 +1,6 @@ import structlog from typing import TYPE_CHECKING -from .core import ( - Node, - GraphState, - UNRESOLVED, - PropertyField, - TYPE_CHOICES -) +from .core import Node, GraphState, UNRESOLVED, PropertyField, TYPE_CHOICES from .registry import register from talemate.context import active_scene from talemate.world_state.manager import WorldStateManager @@ -18,263 +12,250 @@ if TYPE_CHECKING: log = structlog.get_logger("talemate.game.engine.nodes.scene") # extend TYPE_CHOICES with GenerationOptions -TYPE_CHOICES.extend([ - "generation_options", - "spices", - "writing_style" -]) +TYPE_CHOICES.extend(["generation_options", "spices", "writing_style"]) + class WorldStateManagerNode(Node): - """ Base class for world state manager nodes """ - + @property def world_state_manager(self) -> WorldStateManager: - scene:"Scene" = active_scene.get() + scene: "Scene" = active_scene.get() return scene.world_state_manager - + @register("scene/worldstate/SaveWorldEntry") class SaveWorldEntry(WorldStateManagerNode): - """ Saves the world entry - + Inputs: - + - entry_id: The id of the world entry - text: The text of the world entry - meta: The meta of the world entry - + Properties: - + - create_pin: Whether to create a pin for the entry - + Outputs: - + - entry_id: The id of the world entry - text: The text of the world entry - meta: The meta of the world entry """ - + class Fields: entry_id = PropertyField( name="entry_id", description="The id of the world entry", type="str", - default=UNRESOLVED + default=UNRESOLVED, ) - + text = PropertyField( name="text", description="The text of the world entry", type="text", - default=UNRESOLVED + default=UNRESOLVED, ) - + meta = PropertyField( name="meta", description="The meta of the world entry", type="dict", - default={} + default={}, ) - + create_pin = PropertyField( name="create_pin", description="Whether to create a pin for the entry", type="bool", - default=False + default=False, ) - + def __init__(self, title="Save World Entry", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("entry_id", socket_type="str", optional=True) self.add_input("text", socket_type="str", optional=True) self.add_input("meta", socket_type="dict", optional=True) - + self.set_property("entry_id", UNRESOLVED) self.set_property("text", UNRESOLVED) self.set_property("meta", {}) self.set_property("create_pin", False) - + self.add_output("entry_id", socket_type="str") self.add_output("text", socket_type="str") self.add_output("meta", socket_type="dict") - - + async def run(self, state: GraphState): - entry_id = self.require_input("entry_id") text = self.require_input("text") meta = self.get_input_value("meta") create_pin = self.get_property("create_pin") - - await self.world_state_manager.save_world_entry(entry_id, text, meta, create_pin) - - self.set_output_values({ - "entry_id": entry_id, - "text": text, - "meta": meta - }) - + + await self.world_state_manager.save_world_entry( + entry_id, text, meta, create_pin + ) + + self.set_output_values({"entry_id": entry_id, "text": text, "meta": meta}) + + # WORLD STATE TEMPLATES + @register("scene/worldstate/templates/Spices") class Spices(Node): """ Node that returns a Spices object - + Inputs: - + - spice_values: list of strings - + Outputs: - + - spices: list of strings """ - + class Fields: spice_values = PropertyField( name="spice_values", description="The list of spices", type="list", - default=[] + default=[], ) + def __init__(self, title="Spices", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("spice_values", socket_type="list", optional=True) - + self.set_property("spice_values", []) - + self.add_output("spices", socket_type="list") - + async def run(self, state: GraphState): spice_values = self.get_input_value("spice_values") - + spices = content.Spices(spices=spice_values) - - self.set_output_values({ - "spices": spices - }) - + + self.set_output_values({"spices": spices}) + + @register("scene/worldstate/templates/WritingStyle") class WritingStyle(Node): """ Node that returns a WritingStyle object - + Inputs: - + - instructions: Writing style instructions - + Outputs: - + - writing_style: The writing style to apply to the generation options """ - + class Fields: instructions = PropertyField( name="instructions", description="Writing style instructions", type="text", - default=UNRESOLVED + default=UNRESOLVED, ) - + def __init__(self, title="Writing Style", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("instructions", socket_type="str", optional=True) - + self.set_property("instructions", "") - + self.add_output("writing_style", socket_type="writing_style") - + async def run(self, state: GraphState): instructions = self.require_input("instructions") - + writing_style = content.WritingStyle(instructions=instructions) - - self.set_output_values({ - "writing_style": writing_style - }) + + self.set_output_values({"writing_style": writing_style}) + @register("scene/worldstate/templates/GenerationOptions") class GenerationOptions(Node): """ Node that returns a GenerationOptions object - + Inputs: - + - spices: The spices to apply to the generation options - spice_level: The spice level to apply to the generation options - writing_style: The writing style to apply to the generation options - + Properties: - + - spice_level: The spice level to apply to the generation options - writing_style: The writing style to apply to the generation options - + Outputs: - + - generation_options: The generation options """ - + class Fields: spices = PropertyField( name="spices", description="The spices to apply to the generation options", type="spices", - default=UNRESOLVED + default=UNRESOLVED, ) - + spice_level = PropertyField( name="spice_level", description="The spice level to apply to the generation options", type="number", default=0.0, - min=0.0, + min=0.0, max=1.0, - step=0.1 + step=0.1, ) - + writing_style = PropertyField( name="writing_style", description="The writing style to apply to the generation options", type="writing_style", - default=UNRESOLVED + default=UNRESOLVED, ) - + def __init__(self, title="Generation Options", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("spices", socket_type="generation_options", optional=True) self.add_input("spice_level", socket_type="number", optional=True) self.add_input("writing_style", socket_type="writing_style", optional=True) - + self.set_property("spice_level", 0.0) self.set_property("writing_style", "") - + self.add_output("generation_options", socket_type="generation_options") - + async def run(self, state: GraphState): spices = self.normalized_input_value("spices") spice_level = self.normalized_input_value("spice_level") writing_style = self.normalized_input_value("writing_style") - + generation_options = content.GenerationOptions( - spices=spices, - spice_level=spice_level, - writing_style=writing_style + spices=spices, spice_level=spice_level, writing_style=writing_style ) - - self.set_output_values({ - "generation_options": generation_options - }) - \ No newline at end of file + + self.set_output_values({"generation_options": generation_options}) diff --git a/src/talemate/game/engine/scene_loop.py b/src/talemate/game/engine/scene_loop.py index b79c85a3..49cf75fb 100644 --- a/src/talemate/game/engine/scene_loop.py +++ b/src/talemate/game/engine/scene_loop.py @@ -1,15 +1,13 @@ from typing import TYPE_CHECKING -from .nodes.core import Graph - if TYPE_CHECKING: from talemate.tale_mate import Scene - - -async def run_scene(scene: 'Scene'): + + +async def run_scene(scene: "Scene"): """ Talemate v2 scene loop based on node graph """ - - pass \ No newline at end of file + + pass diff --git a/src/talemate/game/focal/__init__.py b/src/talemate/game/focal/__init__.py index 9afca189..0aea270b 100644 --- a/src/talemate/game/focal/__init__.py +++ b/src/talemate/game/focal/__init__.py @@ -17,7 +17,7 @@ from talemate.util.data import ( extract_data, ) -from .schema import Argument, Call, Callback, State, ExampleCallbackArguments +from .schema import Argument, Call, Callback, State __all__ = [ "Argument", @@ -33,171 +33,175 @@ log = structlog.get_logger("talemate.game.focal") current_focal_context = ContextVar("current_focal_context", default=None) + class FocalContext: def __init__(self): self.hooks_before_call = [] self.hooks_after_call = [] self.value = {} - + def __enter__(self): self.token = current_focal_context.set(self) return self - + def __exit__(self, *args): current_focal_context.reset(self.token) - - async def process_hooks(self, call:Call): + + async def process_hooks(self, call: Call): for hook in self.hooks_after_call: await hook(call) + class Focal: - schema_formats: list[str] = ["json", "yaml"] - + def __init__( - self, + self, client: ClientBase, callbacks: list[Callback], max_calls: int = 5, retries: int = 0, schema_format: str = "json", - **kwargs + **kwargs, ): self.client = client self.context = kwargs self.max_calls = max_calls self.retries = retries self.state = State(schema_format=schema_format) - self.callbacks = { - callback.name: callback - for callback in callbacks - } - + self.callbacks = {callback.name: callback for callback in callbacks} + # set state on each callback for callback in self.callbacks.values(): callback.state = self.state - + def render_instructions(self) -> str: prompt = Prompt.get( "focal.instructions", { "max_calls": self.max_calls, "state": self.state, - } + }, ) return prompt.render() - + async def request( self, template_name: str, retry_state: dict | None = None, ) -> str: - - log.debug("focal.request", template_name=template_name, callbacks=self.callbacks) - + log.debug( + "focal.request", template_name=template_name, callbacks=self.callbacks + ) + # client preference for schema format if self.client.data_format: self.state.schema_format = self.client.data_format - + response = await Prompt.request( template_name, self.client, "analyze_long", vars={ - **self.context, + **self.context, "focal": self, - "max_tokens":self.client.max_token_length, + "max_tokens": self.client.max_token_length, "max_calls": self.max_calls, }, dedupe_enabled=False, ) - + response = response.strip() - + if not retry_state: retry_state = {"retries": self.retries} - + if not response: log.warning("focal.request.empty_response") - + log.debug("focal.request", template_name=template_name, response=response) - + if response: await self._execute(response, State()) - + # if no calls were made and we still have retries, try again if not self.state.calls and retry_state["retries"] > 0: - log.warning("focal.request - NO CALLS MADE - retrying", retries=retry_state["retries"]) + log.warning( + "focal.request - NO CALLS MADE - retrying", + retries=retry_state["retries"], + ) retry_state["retries"] -= 1 return await self.request(template_name, retry_state) - + return response - + async def _execute(self, response: str, state: State): try: calls: list[Call] = await self._extract(response) except Exception as e: log.error("focal.extract_error", error=str(e)) return - + focal_context = current_focal_context.get() - + calls_made = 0 - + for call in calls: - if calls_made >= self.max_calls: log.warning("focal.execute.max_calls_reached", max_calls=self.max_calls) break - + if call.name not in self.callbacks: log.warning("focal.execute.unknown_callback", name=call.name) continue - + callback = self.callbacks[call.name] - + try: - # if we have a focal context, process additional hooks (before call) if focal_context: await focal_context.process_hooks(call) - - log.debug(f"focal.execute - Calling {callback.name}", arguments=call.arguments) + + log.debug( + f"focal.execute - Calling {callback.name}", arguments=call.arguments + ) result = await callback.fn(**call.arguments) call.result = result call.called = True calls_made += 1 - + # if we have a focal context, process additional hooks (after call) if focal_context: await focal_context.process_hooks(call) - - except Exception as e: + + except Exception: log.error( "focal.execute.callback_error", callback=call.name, error=traceback.format_exc(), ) - + self.state.calls.append(call) - - async def _extract(self, response:str) -> list[Call]: - + + async def _extract(self, response: str) -> list[Call]: # first try to extract data from the response using tooling try: data = extract_data(response, self.state.schema_format) return [Call(**call) for call in data] except Exception as e: - log.warning("focal.extract.data FAILED - attempting to use AI to extract calls", error=str(e)) - + log.warning( + "focal.extract.data FAILED - attempting to use AI to extract calls", + error=str(e), + ) + # if there is no JSON structure in the response, there are no calls to extract # so we return an empty list if f"```{self.state.schema_format}" not in response: log.warning("focal.extract.no_json_structure") return [] - + log.debug("focal.extract", response=response) - + _, calls_json = await Prompt.request( "focal.extract_calls", self.client, @@ -212,36 +216,37 @@ class Focal: ) calls = [Call(**call) for call in calls_json.get("calls", [])] - + log.debug("focal.extract", calls=calls) - + return calls - - -def collect_calls(calls:list[Call], nested:bool=False, filter: Callable=None) -> list: - + + +def collect_calls( + calls: list[Call], nested: bool = False, filter: Callable = None +) -> list: """ Takes a list of calls and collects into a list. - + If nested is True and call result is a list of calls, it will also collect those. - + If a filter function is provided, it will be used to filter the results. """ - + results = [] - + for call in calls: - - result_is_list_of_calls = isinstance(call.result, list) and all([isinstance(result, Call) for result in call.result]) - + result_is_list_of_calls = isinstance(call.result, list) and all( + [isinstance(result, Call) for result in call.result] + ) + # we need to filter the results # but if nested is True, we need to collect nested results regardless - + if not filter or filter(call): results.append(call) - + if nested and result_is_list_of_calls: results.extend(collect_calls(call.result, nested=True, filter=filter)) - - - return results \ No newline at end of file + + return results diff --git a/src/talemate/game/focal/schema.py b/src/talemate/game/focal/schema.py index c92fe825..4ab90217 100644 --- a/src/talemate/game/focal/schema.py +++ b/src/talemate/game/focal/schema.py @@ -7,12 +7,12 @@ import yaml from talemate.prompts.base import Prompt __all__ = [ - "Argument", - "Call", - "Callback", - "State", + "Argument", + "Call", + "Callback", + "State", "InvalidCallbackArguments", - "ExampleCallbackArguments" + "ExampleCallbackArguments", ] YAML_OPTIONS = { @@ -23,52 +23,64 @@ YAML_OPTIONS = { "width": 100, } -YAML_PRESERVE_NEWLINES = "If there are newlines, they should be preserved by using | style." +YAML_PRESERVE_NEWLINES = ( + "If there are newlines, they should be preserved by using | style." +) + class InvalidCallbackArguments(ValueError): pass + class ExampleCallbackArguments(InvalidCallbackArguments): pass + class State(pydantic.BaseModel): - calls:list["Call"] = pydantic.Field(default_factory=list) + calls: list["Call"] = pydantic.Field(default_factory=list) schema_format: Literal["json", "yaml"] = "json" + class Argument(pydantic.BaseModel): name: str type: str preserve_newlines: bool = False - - def extra_instructions(self, state:State) -> str: + + def extra_instructions(self, state: State) -> str: if state.schema_format == "yaml" and self.preserve_newlines: return f" {YAML_PRESERVE_NEWLINES}" return "" - + + class Call(pydantic.BaseModel): - name: str = pydantic.Field(validation_alias=pydantic.AliasChoices('name', 'function')) + name: str = pydantic.Field( + validation_alias=pydantic.AliasChoices("name", "function") + ) arguments: dict[str, Any] = pydantic.Field(default_factory=dict) result: str | int | float | bool | dict | list | None = None uid: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4())) called: bool = False - - @pydantic.field_validator('arguments') + + @pydantic.field_validator("arguments") def check_for_schema_examples(cls, v: dict[str, Any]) -> dict[str, str]: valid_types = ["str", "int", "float", "bool", "dict", "list"] for key, value in v.items(): if isinstance(value, str): for type_name in valid_types: if value.startswith(f"{type_name} - "): - raise ExampleCallbackArguments(f"Argument '{key}' contains schema example: '{value}'. AI repeated the schema format.") + raise ExampleCallbackArguments( + f"Argument '{key}' contains schema example: '{value}'. AI repeated the schema format." + ) return v - - @pydantic.field_validator('arguments') + + @pydantic.field_validator("arguments") def join_string_lists(cls, v: dict[str, Any]) -> dict[str, str]: return { - key: '\n'.join(str(item) for item in value) if isinstance(value, list) else str(value) + key: "\n".join(str(item) for item in value) + if isinstance(value, list) + else str(value) for key, value in v.items() } - class Callback(pydantic.BaseModel): @@ -77,12 +89,12 @@ class Callback(pydantic.BaseModel): fn: Callable state: State = State() multiple: bool = True - + @property def pretty_name(self) -> str: return self.name.replace("_", " ").title() - - def render(self, usage:str, examples:list[dict]=None, **argument_usage) -> str: + + def render(self, usage: str, examples: list[dict] = None, **argument_usage) -> str: prompt = Prompt.get( "focal.callback", { @@ -93,52 +105,52 @@ class Callback(pydantic.BaseModel): "arguments": self.arguments, "state": self.state, "examples": examples or [], - } + }, ) - + return prompt.render() - + ## schema - + def _usage(self, argument_usage) -> dict: return { "function": self.name, "arguments": { argument.name: f"{argument.type} - {argument_usage.get(argument.name, '')}{argument.extra_instructions(self.state)}" for argument in self.arguments - } - } - - def _example(self, example:dict) -> dict: + }, + } + + def _example(self, example: dict) -> dict: return { "function": self.name, - "arguments": {k:v for k,v in example.items() if not k.startswith("_")}, + "arguments": {k: v for k, v in example.items() if not k.startswith("_")}, } - + def usage(self, argument_usage) -> str: - fmt:str = self.state.schema_format + fmt: str = self.state.schema_format text = getattr(self, f"{fmt}_usage")(argument_usage) text = text.rstrip() return f"```{fmt}\n{text}\n```" - - def example(self, example:dict) -> str: - fmt:str = self.state.schema_format + + def example(self, example: dict) -> str: + fmt: str = self.state.schema_format text = getattr(self, f"{fmt}_example")(example) text = text.rstrip() return f"```{fmt}\n{text}\n```" - + ## JSON - + def json_usage(self, argument_usage) -> str: return json.dumps(self._usage(argument_usage), indent=2) - - def json_example(self, example:dict) -> str: + + def json_example(self, example: dict) -> str: return json.dumps(self._example(example), indent=2) - + ## YAML - + def yaml_usage(self, argument_usage) -> str: return yaml.dump(self._usage(argument_usage), **YAML_OPTIONS) - - def yaml_example(self, example:dict) -> str: - return yaml.dump(self._example(example), **YAML_OPTIONS) \ No newline at end of file + + def yaml_example(self, example: dict) -> str: + return yaml.dump(self._example(example), **YAML_OPTIONS) diff --git a/src/talemate/game/scope.py b/src/talemate/game/scope.py index b8c7d9ef..37e6ed85 100644 --- a/src/talemate/game/scope.py +++ b/src/talemate/game/scope.py @@ -7,7 +7,6 @@ import structlog import talemate.game.engine.api as scoped_api from talemate.client.base import ClientBase from talemate.emit import emit -from talemate.instance import get_agent from talemate.exceptions import GenerationCancelled from talemate.context import handle_generation_cancelled @@ -48,7 +47,6 @@ class OpenScopedContext: class GameInstructionScope: - def __init__( self, director: "DirectorAgent", diff --git a/src/talemate/game/state.py b/src/talemate/game/state.py index b47126a0..b08664dc 100644 --- a/src/talemate/game/state.py +++ b/src/talemate/game/state.py @@ -49,28 +49,28 @@ class GameState(pydantic.BaseModel): @property def game_won(self) -> bool: - return self.variables.get("__game_won__") == True - + return self.variables.get("__game_won__") is True + def __getitem__(self, key: str) -> Any: return self.get_var(key) - + def __setitem__(self, key: str, value: Any): self.set_var(key, value) - + def __delitem__(self, key: str): return self.unset_var(key) - + def __contains__(self, key: str) -> bool: return self.has_var(key) - + def get(self, key: str, default: Any = None) -> Any: return self.get_var(key, default=default) - + def pop(self, key: str, default: Any = None) -> Any: value = self.get_var(key, default=default) self.unset_var(key) return value - + def set_var(self, key: str, value: Any, commit: bool = False): self.variables[key] = value if commit: diff --git a/src/talemate/history.py b/src/talemate/history.py index 7b34bb13..d6dfebec 100644 --- a/src/talemate/history.py +++ b/src/talemate/history.py @@ -13,20 +13,23 @@ import traceback import uuid import datetime import isodate -import math from talemate.emit import emit import talemate.emit.async_signals as async_signals from talemate.instance import get_agent from talemate.scene_message import SceneMessage -from talemate.util import iso8601_diff_to_human, iso8601_add, duration_to_timedelta, timedelta_to_duration +from talemate.util import ( + iso8601_diff_to_human, + iso8601_add, + duration_to_timedelta, +) from talemate.world_state.templates import GenerationOptions from talemate.exceptions import GenerationCancelled from talemate.context import handle_generation_cancelled from talemate.events import ArchiveEvent if TYPE_CHECKING: - from talemate.tale_mate import Scene, Character + from talemate.tale_mate import Scene __all__ = [ "history_with_relative_time", @@ -49,9 +52,11 @@ log = structlog.get_logger() async_signals.register("archive_add") + class UnregeneratableEntryError(Exception): pass + class ArchiveEntry(pydantic.BaseModel): text: str id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4())[:8]) @@ -59,10 +64,12 @@ class ArchiveEntry(pydantic.BaseModel): end: int | None = None ts: str = pydantic.Field(default_factory=lambda: "PT1S") + class LayeredArchiveEntry(ArchiveEntry): ts_start: str | None = None ts_end: str | None = None - + + class HistoryEntry(pydantic.BaseModel): text: str ts: str @@ -87,78 +94,97 @@ class SourceEntry(pydantic.BaseModel): ts: str | None = None ts_start: str | None = None ts_end: str | None = None - + def __str__(self): return self.text + async def emit_archive_add(scene: "Scene", entry: ArchiveEntry): """ Emits the archive_add signal for an archive entry """ await async_signals.get("archive_add").send( ArchiveEvent( - scene=scene, - event_type="archive_add", - text=entry.text, - ts=entry.ts, - memory_id=entry.id + scene=scene, + event_type="archive_add", + text=entry.text, + ts=entry.ts, + memory_id=entry.id, ) ) -def resolve_history_entry(scene: "Scene", entry: HistoryEntry) -> LayeredArchiveEntry | ArchiveEntry: + +def resolve_history_entry( + scene: "Scene", entry: HistoryEntry +) -> LayeredArchiveEntry | ArchiveEntry: """ Resolves a history entry in the scene's archived history """ - + if entry.layer == 0: return ArchiveEntry(**scene.archived_history[entry.index]) else: - return LayeredArchiveEntry(**scene.layered_history[entry.layer - 1][entry.index]) + return LayeredArchiveEntry( + **scene.layered_history[entry.layer - 1][entry.index] + ) -def entry_contained(scene: "Scene", entry_id: str, container: HistoryEntry | SourceEntry) -> bool: + +def entry_contained( + scene: "Scene", entry_id: str, container: HistoryEntry | SourceEntry +) -> bool: """ Checks if entry_id is contained in container through source entries, checking all the way up to the base layer """ messages = collect_source_entries(scene, container) - + for message in messages: if message.id == entry_id: return True - if not isinstance(message, SceneMessage) and entry_contained(scene, entry_id, message): + if not isinstance(message, SceneMessage) and entry_contained( + scene, entry_id, message + ): return True - + return False + def collect_source_entries(scene: "Scene", entry: HistoryEntry) -> list[SourceEntry]: """ Collects the source entries for a history entry """ - + if entry.start is None or entry.end is None: # entries that dont defien a start and end are not regeneratable return [] - + if entry.layer == 0: - # base layer + # base layer def include_message(message: SceneMessage) -> bool: - return message.typ not in ["director", "context_investigation", "reinforcement"] - + return message.typ not in [ + "director", + "context_investigation", + "reinforcement", + ] + result = [ SourceEntry( - text=str(source), - layer=-1, + text=str(source), + layer=-1, id=source.id, start=entry.start, end=entry.end, ts=source.ts, ts_start=source.ts_start, - ts_end=source.ts_end) for source in filter(include_message, scene.history[entry.start:entry.end+1] + ts_end=source.ts_end, + ) + for source in filter( + include_message, scene.history[entry.start : entry.end + 1] ) ] - + return result - + else: # layered history if entry.layer == 1: @@ -167,20 +193,20 @@ def collect_source_entries(scene: "Scene", entry: HistoryEntry) -> list[SourceEn else: source_layer_index = entry.layer - 1 source_layer = scene.layered_history[source_layer_index] - + return [ SourceEntry( - text=source["text"], - layer=source_layer_index, + text=source["text"], + layer=source_layer_index, id=source["id"], start=source.get("start", None), end=source.get("end", None), ts=source.get("ts", None), ts_start=source.get("ts_start", None), ts_end=source.get("ts_end", None), - ) for source in source_layer[entry.start:entry.end+1] + ) + for source in source_layer[entry.start : entry.end + 1] ] - def pop_history( @@ -219,7 +245,9 @@ def pop_history( history.remove(message) -def history_with_relative_time(history: list[str], scene_time: str, layer: int = 0) -> list[dict]: +def history_with_relative_time( + history: list[str], scene_time: str, layer: int = 0 +) -> list[dict]: """ Cycles through a list of Archived History entries and runs iso8601_diff_to_human @@ -240,14 +268,19 @@ def history_with_relative_time(history: list[str], scene_time: str, layer: int = ts_start=entry.get("ts_start", None), ts_end=entry.get("ts_end", None), time=iso8601_diff_to_human(scene_time, entry["ts"]), - time_start=iso8601_diff_to_human(scene_time, entry["ts_start"] if entry.get("ts_start") else None), - time_end=iso8601_diff_to_human(scene_time, entry["ts_end"] if entry.get("ts_end") else None), + time_start=iso8601_diff_to_human( + scene_time, entry["ts_start"] if entry.get("ts_start") else None + ), + time_end=iso8601_diff_to_human( + scene_time, entry["ts_end"] if entry.get("ts_end") else None + ), start=entry.get("start", None), end=entry.get("end", None), ).model_dump() for index, entry in enumerate(history) ] + async def purge_all_history_from_memory(): """ Removes all history from the memory agent @@ -255,6 +288,7 @@ async def purge_all_history_from_memory(): memory = get_agent("memory") await memory.delete({"typ": "history"}) + async def rebuild_history( scene: "Scene", callback: Callable | None = None, @@ -263,14 +297,13 @@ async def rebuild_history( """ rebuilds all history for a scene """ - memory = get_agent("memory") summarizer = get_agent("summarizer") # clear out archived history, but keep pre-established history scene.archived_history = [ ah for ah in scene.archived_history if ah.get("end") is None ] - + scene.layered_history = [] await purge_all_history_from_memory() @@ -284,7 +317,6 @@ async def rebuild_history( try: while True: - await asyncio.sleep(0.1) if not scene.active: @@ -317,107 +349,122 @@ async def rebuild_history( emit("status", message="Rebuilding of archive cancelled", status="info") handle_generation_cancelled(e) return - except Exception as e: + except Exception: log.error("Error rebuilding historical archive", error=traceback.format_exc()) emit("status", message="Error rebuilding historical archive", status="error") return scene.sync_time() await scene.commit_to_memory() - + if summarizer.layered_history_enabled: emit("status", message="Rebuilding layered history...", status="busy") await summarizer.summarize_to_layered_history() - + emit("status", message="Historical archive rebuilt", status="success") class CharacterActivity(pydantic.BaseModel): - none_have_acted:bool - characters:list + none_have_acted: bool + characters: list -async def character_activity(scene: "Scene", since_time_passage: bool = False) -> CharacterActivity: + +async def character_activity( + scene: "Scene", since_time_passage: bool = False +) -> CharacterActivity: """ Returns a CharacterActivity object containing a list of all active characters sorted by which were last active - + The most recently active character is first in the list. - + If no characters have acted, the none_have_acted flag will be set to True. - + If since_time_passage is True, the search will stop when a TimePassageMessage is found. """ - - activity:list = [] - + + activity: list = [] + character_names = scene.character_names - - for message in scene.collect_messages(typ="character", max_iterations=100, stop_on_time_passage=since_time_passage): - if message.character_name not in activity and message.character_name in character_names: + + for message in scene.collect_messages( + typ="character", max_iterations=100, stop_on_time_passage=since_time_passage + ): + if ( + message.character_name not in activity + and message.character_name in character_names + ): activity.append(message.character_name) - + # if all characters have been added, break if len(activity) == len(character_names): break - + none_have_acted = not activity - + # any characters in the activity list at this point have not spoken # and should be appended to the list for character in character_names: if character not in activity: activity.append(character) - + return CharacterActivity( none_have_acted=none_have_acted, - characters=[scene.get_character(character) for character in activity] + characters=[scene.get_character(character) for character in activity], ) - - -async def update_history_entry(scene: "Scene", entry: HistoryEntry) -> LayeredArchiveEntry | ArchiveEntry: + + +async def update_history_entry( + scene: "Scene", entry: HistoryEntry +) -> LayeredArchiveEntry | ArchiveEntry: """ Updates a history entry in the scene's archived history """ - + if entry.layer == 0: # base layer archive_entry = ArchiveEntry(**entry.model_dump()) - scene.archived_history[entry.index] = archive_entry.model_dump(exclude_none=True) + scene.archived_history[entry.index] = archive_entry.model_dump( + exclude_none=True + ) await emit_archive_add(scene, archive_entry) return archive_entry else: # layered history layered_entry = LayeredArchiveEntry(**entry.model_dump()) - scene.layered_history[entry.layer - 1][entry.index] = layered_entry.model_dump(exclude_none=True) + scene.layered_history[entry.layer - 1][entry.index] = layered_entry.model_dump( + exclude_none=True + ) return layered_entry - async def regenerate_history_entry( - scene: "Scene", - entry: HistoryEntry, + scene: "Scene", + entry: HistoryEntry, generation_options: GenerationOptions | None = None, ) -> LayeredArchiveEntry | ArchiveEntry: """ Regenerates a history entry in the scene's archived history """ - + summarizer = get_agent("summarizer") if entry.start is None or entry.end is None: # entries that dont defien a start and end are not regeneratable raise UnregeneratableEntryError("No start or end") entries = collect_source_entries(scene, entry) - + if not entries: raise UnregeneratableEntryError("No entries") - + try: - archive_entry: ArchiveEntry | LayeredArchiveEntry = resolve_history_entry(scene, entry) + archive_entry: ArchiveEntry | LayeredArchiveEntry = resolve_history_entry( + scene, entry + ) except IndexError: raise UnregeneratableEntryError("Entry not found") - + summarized = entry.text - + if isinstance(archive_entry, LayeredArchiveEntry): new_archive_entries = await summarizer.summarize_entries_to_layered_history( [entry.model_dump() for entry in entries], @@ -426,27 +473,28 @@ async def regenerate_history_entry( entry.end, generation_options=generation_options, ) - + if not new_archive_entries: raise UnregeneratableEntryError("Summarization produced no output") - + # if there is more than one entry, merge into first entry summarized = "\n\n".join(entry.text for entry in new_archive_entries) - + elif isinstance(archive_entry, ArchiveEntry): summarized = await summarizer.summarize( "\n".join(map(str, entries)), extra_context=await summarizer.previous_summaries(archive_entry), generation_options=generation_options, ) - + entry.text = summarized - + await update_history_entry(scene, entry) - + return entry -async def reimport_history(scene: "Scene", emit_status:bool = True): + +async def reimport_history(scene: "Scene", emit_status: bool = True): """ Reimports the history from the memory agent """ @@ -463,26 +511,28 @@ async def reimport_history(scene: "Scene", emit_status:bool = True): finally: if emit_status: emit("status", message="History reimported", status="success") - + async def validate_history(scene: "Scene") -> bool: - archived_history = scene.archived_history layered_history = scene.layered_history - + # if archived_history does not have memory_id set, we need to ensure # they are set and reimport to the memory agent - + any_missing_memory_id = any(entry.get("id") is None for entry in archived_history) - + invalid = any_missing_memory_id - + if invalid: - log.warning("History is invalid, fixing and reimporting", any_missing_memory_id=any_missing_memory_id) + log.warning( + "History is invalid, fixing and reimporting", + any_missing_memory_id=any_missing_memory_id, + ) await purge_all_history_from_memory() - + _archived_history = [] - + for entry in archived_history: try: _archived_history.append( @@ -492,25 +542,30 @@ async def validate_history(scene: "Scene") -> bool: log.error("Error validating history entry", error=e) log.error("Invalid entry", entry=entry) continue - + scene.archived_history = _archived_history - + # always send the archive_add signal for all entries # this ensures the entries are up to date in the memory database for entry in scene.archived_history: await emit_archive_add(scene, ArchiveEntry(**entry)) - + for layer_index, layer in enumerate(layered_history): for entry_index, entry in enumerate(layer): if not entry.get("id"): - log.warning("Layered history entry is missing id, generating one", layer=layer_index, index=entry_index) + log.warning( + "Layered history entry is missing id, generating one", + layer=layer_index, + index=entry_index, + ) entry["id"] = str(uuid.uuid4())[:8] # these entries also have their `end` value incorrectly offset by -1 so we need to fix it if entry.get("end") is not None: entry["end"] += 1 - + return not invalid + async def add_history_entry(scene: "Scene", text: str, offset: str) -> ArchiveEntry: """ Inserts a manual history entry into the base (archived) history. @@ -526,7 +581,7 @@ async def add_history_entry(scene: "Scene", text: str, offset: str) -> ArchiveEn Raises: ValueError: If the entry would not be older than the first summarized archive entry or if no summarized entry exists. """ - + is_first_entry = len(scene.archived_history) == 0 if is_first_entry: @@ -549,17 +604,27 @@ async def add_history_entry(scene: "Scene", text: str, offset: str) -> ArchiveEn break # Parse and convert to timedelta for arithmetic - scene_td = duration_to_timedelta(isodate.parse_duration(scene.ts)) - offset_td = duration_to_timedelta(isodate.parse_duration(offset)) + scene_td = duration_to_timedelta(isodate.parse_duration(scene.ts)) + offset_td = duration_to_timedelta(isodate.parse_duration(offset)) new_ts_td: datetime.timedelta = scene_td - offset_td - - log.debug("add_history_entry", is_first_entry=is_first_entry, scene_ts=scene.ts, offset=offset, scene_td=scene_td, offset_td=offset_td, new_ts_td=new_ts_td) + + log.debug( + "add_history_entry", + is_first_entry=is_first_entry, + scene_ts=scene.ts, + offset=offset, + scene_td=scene_td, + offset_td=offset_td, + new_ts_td=new_ts_td, + ) # If offset predates the current scene start, shift timeline earlier so # that the *relative* distance between existing events is preserved. if new_ts_td.total_seconds() < 0: - log.debug("offset is before scene start, shifting timeline", new_ts_td=new_ts_td) + log.debug( + "offset is before scene start, shifting timeline", new_ts_td=new_ts_td + ) # Amount we must shift the whole timeline forward so that the new # entry can be placed at PT0S. This is the *earliness* gap between # the requested offset and the current earliest timestamp. @@ -567,7 +632,7 @@ async def add_history_entry(scene: "Scene", text: str, offset: str) -> ArchiveEn # Since we already have the timedeltas, we can compute this directly shift_td = offset_td - scene_td # This will be positive shift_iso = isodate.duration_isoformat(shift_td) - + log.debug("shift_iso", shift_iso=shift_iso) # Shift everything forward by the calculated amount so that the @@ -576,14 +641,17 @@ async def add_history_entry(scene: "Scene", text: str, offset: str) -> ArchiveEn # After shifting, the new entry will sit at PT0S new_ts_td = datetime.timedelta(seconds=0) - - + if first_summary is not None: - first_summary_td = duration_to_timedelta(isodate.parse_duration(first_summary["ts"])) + first_summary_td = duration_to_timedelta( + isodate.parse_duration(first_summary["ts"]) + ) # New entry must be OLDER (i.e. smaller duration) than the first summary entry. if new_ts_td >= first_summary_td: - raise ValueError("New entry must be older than the first summarized history entry.") + raise ValueError( + "New entry must be older than the first summarized history entry." + ) # Build ArchiveEntry new_ts_str = isodate.duration_isoformat(new_ts_td) @@ -593,12 +661,16 @@ async def add_history_entry(scene: "Scene", text: str, offset: str) -> ArchiveEn inserted = False for idx, existing in enumerate(scene.archived_history): try: - existing_ts_td = duration_to_timedelta(isodate.parse_duration(existing.get("ts", "PT0S"))) + existing_ts_td = duration_to_timedelta( + isodate.parse_duration(existing.get("ts", "PT0S")) + ) except Exception: continue if new_ts_td < existing_ts_td: - scene.archived_history.insert(idx, archive_entry.model_dump(exclude_none=True)) + scene.archived_history.insert( + idx, archive_entry.model_dump(exclude_none=True) + ) inserted = True break @@ -611,11 +683,12 @@ async def add_history_entry(scene: "Scene", text: str, offset: str) -> ArchiveEn scene.sync_time() except Exception as e: log.error("add_history_entry.sync_time", error=e) - + await reimport_history(scene) return archive_entry + async def delete_history_entry(scene: "Scene", entry: HistoryEntry) -> ArchiveEntry: """ Deletes a manual base-layer history entry from the scene archives and removes it from the memory store. @@ -640,7 +713,7 @@ async def delete_history_entry(scene: "Scene", entry: HistoryEntry) -> ArchiveEn if existing.get("id") == entry.id: remove_idx = idx break - + is_oldest_entry = remove_idx == 0 if remove_idx is None: @@ -648,14 +721,15 @@ async def delete_history_entry(scene: "Scene", entry: HistoryEntry) -> ArchiveEn removed_raw = scene.archived_history.pop(remove_idx) removed_entry = ArchiveEntry(**removed_raw) - + if is_oldest_entry: # The removed first entry is always at 0s. We therefore need to shift # the timeline by the timestamp of **what is now** the first entry so # that it becomes ``PT0S``. shift_iso = ( (scene.archived_history[0].get("ts") or "PT0S") - if scene.archived_history else "PT0S" + if scene.archived_history + else "PT0S" ) # Apply the negative shift to the entire scene timeline. shift_scene_timeline(scene, f"-{shift_iso}") @@ -679,7 +753,11 @@ def _shift_entry_ts(entry: dict, shift_iso: str): entry[key] = iso8601_add(entry[key], shift_iso, clamp_non_negative=True) except Exception as e: # pragma: no cover – defensive only log.error( - "shift_entry_ts", error=e, key=key, value=entry.get(key), shift_iso=shift_iso + "shift_entry_ts", + error=e, + key=key, + value=entry.get(key), + shift_iso=shift_iso, ) @@ -703,7 +781,12 @@ def shift_scene_timeline(scene: "Scene", shift_iso: str): try: scene.ts = iso8601_add(scene.ts, shift_iso, clamp_non_negative=True) except Exception as e: # pragma: no cover – defensive only - log.error("shift_scene_timeline.scene_ts", error=e, scene_ts=scene.ts, shift_iso=shift_iso) + log.error( + "shift_scene_timeline.scene_ts", + error=e, + scene_ts=scene.ts, + shift_iso=shift_iso, + ) # 2) shift archived_history entries for entry in scene.archived_history: @@ -712,4 +795,4 @@ def shift_scene_timeline(scene: "Scene", shift_iso: str): # 3) shift layered history entries for layer in scene.layered_history: for entry in layer: - _shift_entry_ts(entry, shift_iso) \ No newline at end of file + _shift_entry_ts(entry, shift_iso) diff --git a/src/talemate/instance.py b/src/talemate/instance.py index 1c14787b..0b1699c5 100644 --- a/src/talemate/instance.py +++ b/src/talemate/instance.py @@ -50,7 +50,7 @@ def get_client(name: str, *create_args, **create_kwargs): if client: if create_kwargs: if system_prompts: - client.set_system_prompts(system_prompts) + client.set_system_prompts(system_prompts) client.reconfigure(**create_kwargs) return client @@ -58,10 +58,10 @@ def get_client(name: str, *create_args, **create_kwargs): typ = create_kwargs.get("type") cls = clients.get_client_class(typ) client = cls(name=name, *create_args, **create_kwargs) - + if system_prompts: client.set_system_prompts(system_prompts) - + set_client(name, client) return client diff --git a/src/talemate/load.py b/src/talemate/load.py index afc339f3..aaf7e0b9 100644 --- a/src/talemate/load.py +++ b/src/talemate/load.py @@ -4,7 +4,6 @@ import os import structlog -import talemate.events as events import talemate.instance as instance from talemate import Actor, Character, Player, Scene from talemate.instance import get_agent @@ -88,22 +87,20 @@ async def load_scene(scene, file_path, conv_client, reset: bool = False): def identify_import_spec(data: dict) -> ImportSpec: - if data.get("spec") == "chara_card_v3": return ImportSpec.chara_card_v3 - + if data.get("spec") == "chara_card_v2": return ImportSpec.chara_card_v2 if data.get("spec") == "chara_card_v1": return ImportSpec.chara_card_v1 - - + if "first_mes" in data: # original chara card didnt specify a spec, # if the first_mes key exists, we can assume it's a v0 chara card return ImportSpec.chara_card_v0 - + if "first_mes" in data.get("data", {}): # this can also serve as a fallback for future chara card versions # as they are supposed to be backwards compatible @@ -117,7 +114,7 @@ async def load_scene_from_character_card(scene, file_path): """ Load a character card (tavern etc.) from the given file path. """ - + director = get_agent("director") LOADING_STEPS = 5 if director.auto_direct_enabled: @@ -129,9 +126,9 @@ async def load_scene_from_character_card(scene, file_path): file_ext = os.path.splitext(file_path)[1].lower() image_format = file_ext.lstrip(".") image = False - + await handle_no_player_character(scene) - + # If a json file is found, use Character.load_from_json instead if file_ext == ".json": character = load_character_from_json(file_path) @@ -257,7 +254,6 @@ async def load_scene_from_data( loading_status = LoadingStatus(1) reset_message_id() - memory = scene.get_helper("memory").agent scene.description = scene_data.get("description", "") @@ -273,7 +269,7 @@ async def load_scene_from_data( scene.writing_style_template = scene_data.get("writing_style_template", "") scene.nodes_filename = scene_data.get("nodes_filename", "") scene.creative_nodes_filename = scene_data.get("creative_nodes_filename", "") - + import_scene_node_definitions(scene) if not reset: @@ -329,7 +325,12 @@ async def load_scene_from_data( await scene.add_actor(actor) # if there is nio player character, add the default player character - await handle_no_player_character(scene, add_default_character=scene.config.get("game", {}).get("general", {}).get("add_default_character", True)) + await handle_no_player_character( + scene, + add_default_character=scene.config.get("game", {}) + .get("general", {}) + .get("add_default_character", True), + ) # the scene has been saved before (since we just loaded it), so we set the saved flag to True # as long as the scene has a memory_id. @@ -389,27 +390,29 @@ async def transfer_character(scene, scene_json_path, character_name): return scene -async def handle_no_player_character(scene: Scene, add_default_character: bool = True) -> None: +async def handle_no_player_character( + scene: Scene, add_default_character: bool = True +) -> None: """ Handle the case where there is no player character in the scene. """ - + existing_player = scene.get_player_character() - + if existing_player: return - + if add_default_character: player = default_player_character() else: player = None - + if not player: # force scene into creative mode scene.environment = "creative" log.warning("No player character found, forcing scene into creative mode") return - + await scene.add_actor(player) @@ -422,7 +425,7 @@ def load_character_from_image(image_path: str, file_format: str) -> Character: """ metadata = extract_metadata(image_path, file_format) spec = identify_import_spec(metadata) - + log.debug("load_character_from_image", spec=spec) if spec == ImportSpec.chara_card_v2 or spec == ImportSpec.chara_card_v3: @@ -520,11 +523,11 @@ def default_player_character() -> Player | None: load_config().get("game", {}).get("default_player_character", {}) ) name = default_player_character.get("name") - + if not name: # We don't have a valid default player character, so we return None return None - + color = default_player_character.get("color", "cyan") description = default_player_character.get("description", "") @@ -552,7 +555,6 @@ def _load_history(history): return _history - def _prepare_history(entry): typ = entry.pop("typ", "scene_message") entry.pop("id", None) @@ -563,12 +565,12 @@ def _prepare_history(entry): cls = MESSAGES.get(typ, SceneMessage) msg = cls(**entry) - + if isinstance(msg, (NarratorMessage, ReinforcementMessage)): msg = msg.migrate_source_to_meta() elif isinstance(msg, DirectorMessage): msg = msg.migrate_message_to_meta() - + return msg diff --git a/src/talemate/prompts/__init__.py b/src/talemate/prompts/__init__.py index 190c60d3..6d896a09 100644 --- a/src/talemate/prompts/__init__.py +++ b/src/talemate/prompts/__init__.py @@ -1 +1 @@ -from .base import LoopedPrompt, Prompt +from .base import LoopedPrompt, Prompt # noqa: F401 diff --git a/src/talemate/prompts/base.py b/src/talemate/prompts/base.py index 67025cc6..cb6c9428 100644 --- a/src/talemate/prompts/base.py +++ b/src/talemate/prompts/base.py @@ -2,7 +2,7 @@ Base prompt loader The idea is to be able to specify prompts for the various agents in a way that is -changeable and extensible. +changeable and extensible. """ import asyncio @@ -53,12 +53,14 @@ log = structlog.get_logger("talemate") prepended_template_dirs = ContextVar("prepended_template_dirs", default=[]) + class PydanticJsonEncoder(json.JSONEncoder): def default(self, obj): if hasattr(obj, "model_dump"): return obj.model_dump() return super().default(obj) + class PrependTemplateDirectories: def __init__(self, prepend_dir: list): if isinstance(prepend_dir, str): @@ -208,7 +210,6 @@ class LoopedPrompt: class JoinableList(list): - def join(self, separator: str = "\n"): return separator.join(self) @@ -230,8 +231,8 @@ class Prompt: # prompt text prompt: str = None - - # template text + + # template text template: str | None = None # prompt variables @@ -278,7 +279,7 @@ class Prompt: return prompt @classmethod - def from_text(cls, text:str, vars: dict = None): + def from_text(cls, text: str, vars: dict = None): return cls( uid="", agent_type="", @@ -326,12 +327,8 @@ class Prompt: os.path.join( dir_path, "..", "..", "..", "templates", "prompts", self.agent_type ), - os.path.join( - dir_path, "..", "..", "..", "templates", "prompts", "common" - ), - os.path.join( - dir_path, "..", "..", "..", "templates", "modules" - ), + os.path.join(dir_path, "..", "..", "..", "templates", "prompts", "common"), + os.path.join(dir_path, "..", "..", "..", "templates", "modules"), os.path.join(dir_path, "templates", self.agent_type), os.path.join(dir_path, "templates", "common"), ] @@ -377,7 +374,9 @@ class Prompt: "thematic_generator": thematic_generators.ThematicGenerator(), "regeneration_context": regeneration_context.get(), "active_agent": active_agent.get(), - "agent_context_state": active_agent.get().state if active_agent.get() else {}, + "agent_context_state": active_agent.get().state + if active_agent.get() + else {}, } env.globals["render_template"] = self.render_template @@ -414,12 +413,14 @@ class Prompt: env.globals["make_list"] = lambda: JoinableList() env.globals["make_dict"] = lambda: {} env.globals["join"] = lambda x, y: y.join(x) - env.globals["data_format_type"] = lambda: getattr(self.client, "data_format", None) or self.data_format_type + env.globals["data_format_type"] = ( + lambda: getattr(self.client, "data_format", None) or self.data_format_type + ) env.globals["count_tokens"] = lambda x: count_tokens( dedupe_string(x, debug=False) ) env.globals["print"] = lambda x: print(x) - env.globals["json"]= lambda x: json.dumps(x, indent=2, cls=PydanticJsonEncoder) + env.globals["json"] = lambda x: json.dumps(x, indent=2, cls=PydanticJsonEncoder) env.globals["emit_status"] = self.emit_status env.globals["emit_system"] = lambda status, message: emit( "system", status=status, message=message @@ -432,7 +433,6 @@ class Prompt: env.filters["condensed"] = condensed env.filters["no_chapters"] = no_chapters ctx.update(self.vars) - if "decensor" not in ctx: ctx["decensor"] = False @@ -448,7 +448,7 @@ class Prompt: # Render the template with the prompt variables self.eval_context = {} - #self.dedupe_enabled = True + # self.dedupe_enabled = True try: self.prompt = template.render(ctx) if not sectioning_handler: @@ -549,7 +549,7 @@ class Prompt: return "\n".join( [ f"Question: {query}", - f"Answer: " + "Answer: " + loop.run_until_complete( narrator.narrate_query( query, at_the_end=at_the_end, as_narrative=as_narrative @@ -574,13 +574,15 @@ class Prompt: if not as_question_answer: return loop.run_until_complete( - world_state.analyze_text_and_answer_question(text, query, response_length=10 if short else 512) + world_state.analyze_text_and_answer_question( + text, query, response_length=10 if short else 512 + ) ) return "\n".join( [ f"Question: {query}", - f"Answer: " + "Answer: " + loop.run_until_complete( world_state.analyze_text_and_answer_question( text, query, response_length=10 if short else 512 @@ -618,7 +620,7 @@ class Prompt: ) ) - def instruct_text(self, instruction: str, text: str, as_list:bool=False): + def instruct_text(self, instruction: str, text: str, as_list: bool = False): loop = asyncio.get_event_loop() world_state = instance.get_agent("world_state") instruction = instruction.format(**self.vars) @@ -629,7 +631,7 @@ class Prompt: response = loop.run_until_complete( world_state.analyze_and_follow_instruction(text, instruction) ) - + if as_list: return extract_list(response) else: @@ -644,9 +646,8 @@ class Prompt: return loop.run_until_complete( world_state.analyze_text_and_extract_context("\n".join(lines), goal=goal) ) - - def agent_config(self, config_path: str): + def agent_config(self, config_path: str): try: agent_name, action_name, config_name = config_path.split(".") agent = instance.get_agent(agent_name) @@ -673,37 +674,33 @@ class Prompt: return "" return iso8601_diff_to_human(iso8601_time, scene.ts) - def text_to_chunks(self, text:str, chunk_size:int=512) -> list[str]: + def text_to_chunks(self, text: str, chunk_size: int = 512) -> list[str]: """ Takes a text string and splits it into chunks based length of the text. - + Arguments: - + - text: The text to split into chunks. - chunk_size: number of characters in each chunk. """ - + chunks = [] - + for i, line in enumerate(text.split("\n")): - # dont push empty lines into empty chunks if not line.strip() and (not chunks or not chunks[-1]): continue - + if not chunks: chunks.append([line]) continue - + if len("\n".join(chunks[-1])) + len(line) < chunk_size: chunks[-1].append(line) else: chunks.append([line]) - return ["\n\n".join(chunk) for chunk in chunks] - - def set_prepared_response(self, response: str, prepend: str = ""): """ @@ -749,52 +746,68 @@ class Prompt: ): """ Prepares for a data response in the client's preferred format (YAML or JSON) - + Args: initial_object (dict): The data structure to serialize instruction (str): Optional instruction/schema comment cutoff (int): Number of lines to trim from the end """ # Always use client data format if available - data_format_type = getattr(self.client, "data_format", None) or self.data_format_type - + data_format_type = ( + getattr(self.client, "data_format", None) or self.data_format_type + ) + self.data_format_type = data_format_type self.data_response = True - + if data_format_type == "yaml": if yaml is None: - raise ImportError("PyYAML is required for YAML support. Please install it with 'pip install pyyaml'.") - + raise ImportError( + "PyYAML is required for YAML support. Please install it with 'pip install pyyaml'." + ) + # Serialize to YAML - prepared_response = yaml.safe_dump(initial_object, sort_keys=False).split("\n") - + prepared_response = yaml.safe_dump(initial_object, sort_keys=False).split( + "\n" + ) + # For list structures, ensure we stop after the key with a colon - if isinstance(initial_object, dict) and any(isinstance(v, list) for v in initial_object.values()): + if isinstance(initial_object, dict) and any( + isinstance(v, list) for v in initial_object.values() + ): # Find the first key that has a list value and stop there for i, line in enumerate(prepared_response): - if line.strip().endswith(':'): # Found a key that might have a list + if line.strip().endswith(":"): # Found a key that might have a list # Look ahead to see if next line has a dash (indicating it's a list) - if i+1 < len(prepared_response) and prepared_response[i+1].strip().startswith('- '): + if i + 1 < len(prepared_response) and prepared_response[ + i + 1 + ].strip().startswith("- "): # Keep only up to the key with colon, drop the list items - prepared_response = prepared_response[:i+1] + prepared_response = prepared_response[: i + 1] break # For nested dictionary structures, keep only the top-level keys - elif isinstance(initial_object, dict) and any(isinstance(v, dict) for v in initial_object.values()): + elif isinstance(initial_object, dict) and any( + isinstance(v, dict) for v in initial_object.values() + ): # Find keys that have dictionary values for i, line in enumerate(prepared_response): - if line.strip().endswith(':'): # Found a key that might have a nested dict + if line.strip().endswith( + ":" + ): # Found a key that might have a nested dict # Look ahead to see if next line is indented (indicating nested structure) - if i+1 < len(prepared_response) and prepared_response[i+1].startswith(' '): + if i + 1 < len(prepared_response) and prepared_response[ + i + 1 + ].startswith(" "): # Keep only up to the key with colon, drop the nested content - prepared_response = prepared_response[:i+1] + prepared_response = prepared_response[: i + 1] break elif cutoff > 0: # For other structures, just remove last lines prepared_response = prepared_response[:-cutoff] - + if instruction: prepared_response.insert(0, f"# {instruction}") - + cleaned = "\n".join(prepared_response) # Wrap in markdown code block for YAML, but do not close the code block # Add an extra newline to ensure the model's response starts on a new line @@ -810,12 +823,16 @@ class Prompt: cleaned = re.sub(r"\s+", " ", cleaned) return self.set_prepared_response(cleaned) - def set_json_response(self, initial_object: dict, instruction: str = "", cutoff: int = 3): + def set_json_response( + self, initial_object: dict, instruction: str = "", cutoff: int = 3 + ): """ Prepares for a json response """ self.data_format_type = "json" - return self.set_data_response(initial_object, instruction=instruction, cutoff=cutoff) + return self.set_data_response( + initial_object, instruction=instruction, cutoff=cutoff + ) def set_question_eval( self, question: str, trigger: str, counter: str, weight: float = 1.0 @@ -839,8 +856,10 @@ class Prompt: Parse a YAML response from the LLM. """ if yaml is None: - raise ImportError("PyYAML is required for YAML support. Please install it with 'pip install pyyaml'.") - + raise ImportError( + "PyYAML is required for YAML support. Please install it with 'pip install pyyaml'." + ) + # Extract YAML from markdown code blocks if "```yaml" in response and "```" in response.split("```yaml", 1)[1]: yaml_block = response.split("```yaml", 1)[1].split("```", 1)[0] @@ -849,7 +868,7 @@ class Prompt: yaml_block = response.split("```", 1)[1].split("```", 1)[0] else: yaml_block = response - + try: return yaml.safe_load(yaml_block) except Exception as e: @@ -858,7 +877,7 @@ class Prompt: f"{self.name} - Error parsing YAML response: {e}", model_name=self.client.model_name if self.client else "unknown", ) - + async def parse_data_response(self, response): """ Parse response based on configured data format @@ -870,7 +889,7 @@ class Prompt: return await self.parse_yaml_response(response) else: raise ValueError(f"Unsupported data format: {self.data_format_type}") - + async def parse_json_response(self, response, ai_fix: bool = True): # strip comments try: @@ -882,7 +901,7 @@ class Prompt: try: response = json.loads(response) return response - except json.decoder.JSONDecodeError as e: + except json.decoder.JSONDecodeError: pass response = response.replace("True", "true").replace("False", "false") response = "\n".join( @@ -1017,27 +1036,35 @@ class Prompt: pad = " " if self.pad_prepended_response else "" response = self.prepared_response.rstrip() + pad + response.strip() else: - format_type = getattr(self.client, "data_format", None) or self.data_format_type - + format_type = ( + getattr(self.client, "data_format", None) or self.data_format_type + ) + json_start = response.lstrip().startswith("{") yaml_block = response.lstrip().startswith("```yaml") - + # If response doesn't start with expected format markers, prepend the prepared response - if (format_type == "json" and not json_start) or (format_type == "yaml" and not yaml_block): + if (format_type == "json" and not json_start) or ( + format_type == "yaml" and not yaml_block + ): pad = " " if self.pad_prepended_response else "" if format_type == "yaml": if self.client.can_be_coerced: response = self.prepared_response + response.rstrip() else: - response = self.prepared_response.rstrip() + "\n " + response.rstrip() + response = ( + self.prepared_response.rstrip() + "\n " + response.rstrip() + ) else: response = self.prepared_response.rstrip() + pad + response.strip() - + if self.eval_response: return await self.evaluate(response) if self.data_response: - log.debug("data_response", format_type=self.data_format_type, response=response) + log.debug( + "data_response", format_type=self.data_format_type, response=response + ) return response, await self.parse_data_response(response) response = clean_response(response) @@ -1189,7 +1216,7 @@ def titles_prompt_sectioning(prompt: Prompt) -> str: def html_prompt_sectioning(prompt: Prompt) -> str: return _prompt_sectioning( prompt, - lambda section_name: f"<{section_name.capitalize().replace(' ','')}>", - lambda section_name: f"", + lambda section_name: f"<{section_name.capitalize().replace(' ', '')}>", + lambda section_name: f"", strip_empty_lines=True, ) diff --git a/src/talemate/prompts/overrides.py b/src/talemate/prompts/overrides.py index 340a01d9..cce69d82 100644 --- a/src/talemate/prompts/overrides.py +++ b/src/talemate/prompts/overrides.py @@ -5,6 +5,7 @@ from typing import List, Optional from talemate.prompts.base import prepended_template_dirs + @dataclass class TemplateOverride: template_name: str @@ -13,43 +14,44 @@ class TemplateOverride: age_difference: str # Human readable time difference override_newer: bool + def get_template_overrides(agent_type: str) -> List[TemplateOverride]: """ Identifies template files that are being overridden and calculates age differences between override and default templates. - + Args: agent_type (str): The type of agent to check templates for - + Returns: List[TemplateOverride]: List of template overrides with their details """ # Get the directory of the current file (assuming this is in the same dir as base_prompt.py) dir_path = os.path.dirname(os.path.realpath(__file__)) - + # Define template directories as in the Prompt class default_template_dirs = [ os.path.join(dir_path, "..", "..", "..", "templates", "prompts", agent_type), os.path.join(dir_path, "templates", agent_type), ] - + template_dirs = prepended_template_dirs.get() + default_template_dirs overrides = [] - + # Helper function to get file modification time def get_file_mtime(filepath: str) -> Optional[datetime]: try: return datetime.fromtimestamp(os.path.getmtime(filepath)) except (OSError, ValueError): return None - + # Helper function to calculate human readable time difference def get_time_difference(time1: datetime, time2: datetime) -> str: diff = abs(time1 - time2) days = diff.days hours = diff.seconds // 3600 minutes = (diff.seconds % 3600) // 60 - + parts = [] if days > 0: parts.append(f"{days} days") @@ -57,53 +59,55 @@ def get_template_overrides(agent_type: str) -> List[TemplateOverride]: parts.append(f"{hours} hours") elif minutes > 0: parts.append(f"{minutes} minutes") - + return ", ".join(parts) if parts else "less than a minute" # Build a map of template names to their locations template_locations = {} - + for template_dir in template_dirs: if not os.path.exists(template_dir): continue - + for root, _, files in os.walk(template_dir): for filename in files: - if not filename.endswith('.jinja2'): + if not filename.endswith(".jinja2"): continue - + filepath = os.path.join(root, filename) rel_path = os.path.relpath(root, template_dir) template_name = os.path.join(rel_path, filename) - + if template_name not in template_locations: template_locations[template_name] = [] template_locations[template_name].append(filepath) - + # Analyze overrides for template_name, locations in template_locations.items(): if len(locations) < 2: continue - + # The first location is the override, the last is the default override_path = locations[0] default_path = locations[-1] - + override_time = get_file_mtime(override_path) default_time = get_file_mtime(default_path) - + if not override_time or not default_time: continue - + age_diff = get_time_difference(default_time, override_time) override_newer = override_time > default_time - - overrides.append(TemplateOverride( - template_name=template_name, - override_path=override_path, - default_path=default_path, - age_difference=age_diff, - override_newer=override_newer - )) - - return overrides \ No newline at end of file + + overrides.append( + TemplateOverride( + template_name=template_name, + override_path=override_path, + default_path=default_path, + age_difference=age_diff, + override_newer=override_newer, + ) + ) + + return overrides diff --git a/src/talemate/regenerate.py b/src/talemate/regenerate.py index 7856add7..3e125164 100644 --- a/src/talemate/regenerate.py +++ b/src/talemate/regenerate.py @@ -9,8 +9,6 @@ from talemate.scene_message import ( SceneMessage, CharacterMessage, NarratorMessage, - DirectorMessage, - TimePassageMessage, ReinforcementMessage, ContextInvestigationMessage, ) @@ -27,81 +25,98 @@ __all__ = [ log = structlog.get_logger("talemate.regenerate") -async def regenerate_character_message(message: CharacterMessage, scene:"Scene") -> CharacterMessage: - character:"Character | None" = scene.get_character(message.character_name) - +async def regenerate_character_message( + message: CharacterMessage, scene: "Scene" +) -> CharacterMessage: + character: "Character | None" = scene.get_character(message.character_name) + if not character: - log.error("regenerate_character_message: Could not find character", message=message) + log.error( + "regenerate_character_message: Could not find character", message=message + ) return message - + agent = get_agent("conversation") - + if message.source == "player" and not message.from_choice: - log.warning("regenerate_character_message: Static user message, no regeneration possible", message=message) + log.warning( + "regenerate_character_message: Static user message, no regeneration possible", + message=message, + ) return - + messages = await agent.converse(character.actor, instruction=message.from_choice) - + for message in messages: scene.push_history(message) emit("character", message=message, character=character) - - return messages - -async def regenerate_message(message: SceneMessage, scene:"Scene") -> list[SceneMessage] | None: + return messages + + +async def regenerate_message( + message: SceneMessage, scene: "Scene" +) -> list[SceneMessage] | None: """ Will regenerate the message, using the meta information """ - + if isinstance(message, CharacterMessage): # character messages need specific handling messages = await regenerate_character_message(message, scene) else: # all other message types - + try: agent = get_agent(message.meta.get("agent")) except Exception as e: - log.error(f"regenerate_message: Could not find agent", message=message, error=e) + log.error( + "regenerate_message: Could not find agent", message=message, error=e + ) return - + if not agent: - log.error(f"regenerate_message: Could not find agent", message=message) + log.error("regenerate_message: Could not find agent", message=message) return - + function_name = message.meta.get("function") fn = getattr(agent, function_name, None) - + if not fn: - log.error(f"regenerate_message: Could not find agent function", message=message) + log.error( + "regenerate_message: Could not find agent function", message=message + ) return arguments = message.meta.get("arguments", {}).copy() - + # if `character` is set and a string, convert it to a Character if arguments.get("character") and isinstance(arguments.get("character"), str): arguments["character"] = scene.get_character(arguments.get("character")) - - log.debug(f"regenerate_message: Calling agent function", function=function_name, arguments=arguments) - + + log.debug( + "regenerate_message: Calling agent function", + function=function_name, + arguments=arguments, + ) + new_message = await fn(**arguments) - + if not new_message: - log.error(f"regenerate_message: No new message generated", message=message) + log.error("regenerate_message: No new message generated", message=message) return - + if isinstance(new_message, str): new_message = message.__class__(new_message) new_message.meta = message.meta.copy() - + if isinstance(message, ContextInvestigationMessage): new_message.sub_type = message.sub_type - + if not isinstance(new_message, (ReinforcementMessage)): scene.push_history(new_message) emit(new_message.typ, message=new_message) - + messages = [new_message] for message in messages: @@ -109,15 +124,17 @@ async def regenerate_message(message: SceneMessage, scene:"Scene") -> list[Scene events.RegenerateGeneration( scene=scene, message=message, - character=scene.get_character(message.character_name) if isinstance(message, CharacterMessage) else None, + character=scene.get_character(message.character_name) + if isinstance(message, CharacterMessage) + else None, event_type=f"regenerate.msg.{message.typ}", ) ) - + return messages -async def regenerate(scene:"Scene", idx:int=-1) -> list[SceneMessage]: +async def regenerate(scene: "Scene", idx: int = -1) -> list[SceneMessage]: """ Regenerate the most recent AI response, remove their previous message from the history, and call talk() for the most recent AI Character. @@ -127,7 +144,7 @@ async def regenerate(scene:"Scene", idx:int=-1) -> list[SceneMessage]: message = scene.history[idx] except IndexError: return - + regenerated_messages = [] # while message type is ReinforcementMessage, keep going back in history @@ -156,28 +173,29 @@ async def regenerate(scene:"Scene", idx:int=-1) -> list[SceneMessage]: if current_regeneration_context: current_regeneration_context.message = message.message - if not isinstance(message, (CharacterMessage, NarratorMessage, ContextInvestigationMessage)): + if not isinstance( + message, (CharacterMessage, NarratorMessage, ContextInvestigationMessage) + ): log.warning("Cannot regenerate message", message=message) return regenerated_messages - + scene.history.pop() emit("remove_message", "", id=message.id) new_messages = await regenerate_message(message, scene) - + if not new_messages: log.error("No new messages generated", message=message) scene.push_history(message) for message in reversed(popped_reinforcement_messages): scene.push_history(message) return regenerated_messages - - + if new_messages: regenerated_messages.extend(new_messages) - + for message in popped_reinforcement_messages: new_messages = await regenerate_message(message, scene) if new_messages: regenerated_messages.extend(new_messages) - return regenerated_messages \ No newline at end of file + return regenerated_messages diff --git a/src/talemate/save.py b/src/talemate/save.py index ef376e86..bb9cf3b9 100644 --- a/src/talemate/save.py +++ b/src/talemate/save.py @@ -13,16 +13,18 @@ if TYPE_CHECKING: log = structlog.get_logger("talemate.save") + def combine_paths(absolute, relative): # Split paths into components rel_parts = os.path.normpath(relative).split(os.sep) - + # Get just the filename/last component from relative path rel_end = rel_parts[-1] - + # Join absolute path with just the final component return os.path.join(absolute, rel_end) + class SceneEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, SceneMessage): @@ -30,11 +32,12 @@ class SceneEncoder(json.JSONEncoder): return super().default(obj) -async def save_node_module(scene:"Scene", graph:"Graph", filename:str = None, set_as_main:bool = False) -> str: +async def save_node_module( + scene: "Scene", graph: "Graph", filename: str = None, set_as_main: bool = False +) -> str: if not os.path.exists(scene.nodes_dir): os.makedirs(scene.nodes_dir) - - + if isinstance(graph, SceneLoop) and set_as_main: scene.nodes_filename = filename or "scene-loop.json" log.debug("saving scene nodes", filename=scene.nodes_filepath) @@ -43,12 +46,12 @@ async def save_node_module(scene:"Scene", graph:"Graph", filename:str = None, se else: if not filename: raise ValueError("filename is required for non SceneLoop nodes") - + # filename make contain relative path # scenes.node_dir is the base path (absolute) - + save_to_path = combine_paths(scene.nodes_dir, filename) - + log.debug("saving nodes", filename=save_to_path) await save_graph(graph, save_to_path) return save_to_path diff --git a/src/talemate/scene/intent.py b/src/talemate/scene/intent.py index bef8781f..95e59d29 100644 --- a/src/talemate/scene/intent.py +++ b/src/talemate/scene/intent.py @@ -5,7 +5,7 @@ Intention of the story or a sub-scene. What are the expectations of the user and What is the overarching intention of the story? -This is probably an abstract description of what type of experiences will be relayed by the story through individual scenes. +This is probably an abstract description of what type of experiences will be relayed by the story through individual scenes. ## Individual Scene Intent @@ -25,25 +25,28 @@ from .schema import ( if TYPE_CHECKING: from talemate.tale_mate import Scene - + __all__ = [ "set_scene_phase", ] -async def set_scene_phase(scene:"Scene", scene_type_id:str, intent:str) -> ScenePhase: + +async def set_scene_phase( + scene: "Scene", scene_type_id: str, intent: str +) -> ScenePhase: """ Set the scene phase. """ - - scene_intent:SceneIntent = scene.intent_state - + + scene_intent: SceneIntent = scene.intent_state + if scene_type_id not in scene_intent.scene_types: raise ValueError(f"Invalid scene type: {scene_type_id}") - + scene_intent.phase = ScenePhase( - scene_type=scene_type_id, + scene_type=scene_type_id, intent=intent, - start=scene.history[-1].id if scene.history else 0 + start=scene.history[-1].id if scene.history else 0, ) - - return scene_intent.phase \ No newline at end of file + + return scene_intent.phase diff --git a/src/talemate/scene/schema.py b/src/talemate/scene/schema.py index b355e88f..37ac3d7b 100644 --- a/src/talemate/scene/schema.py +++ b/src/talemate/scene/schema.py @@ -1,19 +1,11 @@ -from typing import TYPE_CHECKING import pydantic from talemate.world_state import WorldState from talemate.game.state import GameState -if TYPE_CHECKING: - from talemate.tale_mate import Scene +__all__ = ["SceneType", "ScenePhase", "SceneIntent", "SceneState"] -__all__ = [ - 'SceneType', - 'ScenePhase', - 'SceneIntent', - 'SceneState' -] def make_default_types() -> list["SceneType"]: return { @@ -23,45 +15,50 @@ def make_default_types() -> list["SceneType"]: description="Freeform dialogue between one or more characters with occasional narration.", ) } - + + def make_default_phase() -> "ScenePhase": default_type = make_default_types().get("roleplay") - return ScenePhase( - scene_type=default_type.id - ) - + return ScenePhase(scene_type=default_type.id) + + class SceneType(pydantic.BaseModel): id: str name: str description: str instructions: str | None = None + class ScenePhase(pydantic.BaseModel): scene_type: str intent: str | None = None - + + class SceneIntent(pydantic.BaseModel): - scene_types: dict[str, SceneType] | None = pydantic.Field(default_factory=make_default_types) + scene_types: dict[str, SceneType] | None = pydantic.Field( + default_factory=make_default_types + ) intent: str | None = None phase: ScenePhase | None = pydantic.Field(default_factory=make_default_phase) start: int = 0 - + @property def current_scene_type(self) -> SceneType: return self.scene_types[self.phase.scene_type] - + @property def active(self) -> bool: - return (self.intent or self.phase) - + return self.intent or self.phase + def get_scene_type(self, scene_type_id: str) -> SceneType: return self.scene_types[scene_type_id] - + + class SceneState(pydantic.BaseModel): world_state: "WorldState | None" = None game_state: "GameState | None" = None agent_state: dict | None = None intent_state: SceneIntent | None = None - + def model_dump(self, **kwargs): - return super().model_dump(exclude_none=True) \ No newline at end of file + return super().model_dump(exclude_none=True) diff --git a/src/talemate/scene/state_editor.py b/src/talemate/scene/state_editor.py index 7b4e3d29..09a6976c 100644 --- a/src/talemate/scene/state_editor.py +++ b/src/talemate/scene/state_editor.py @@ -6,6 +6,7 @@ Allows the in-memory editing of a scene's states. - game state - intent state """ + from typing import TYPE_CHECKING from .schema import SceneState @@ -13,37 +14,35 @@ from .schema import SceneState if TYPE_CHECKING: from talemate.tale_mate import Scene -__all__ = [ - "SceneStateEditor" -] +__all__ = ["SceneStateEditor"] + class SceneStateEditor: - def __init__(self, scene: "Scene"): self.scene = scene - + def dump(self) -> dict: scene: "Scene" = self.scene self.state = SceneState( world_state=scene.world_state, game_state=scene.game_state, agent_state=scene.agent_state, - intent_state=scene.intent_state + intent_state=scene.intent_state, ) return self.state.model_dump() - + def load(self, data: dict): state: SceneState = SceneState(**data) scene: "Scene" = self.scene - + if "world_state" in data: scene.world_state = state.world_state - + if "game_state" in data: scene.game_state = state.game_state - + if "agent_state" in data: scene.agent_state.update(state.agent_state) - + if "intent_state" in data: - scene.intent_state = state.intent_state \ No newline at end of file + scene.intent_state = state.intent_state diff --git a/src/talemate/scene_assets.py b/src/talemate/scene_assets.py index 01518632..3ca25bbf 100644 --- a/src/talemate/scene_assets.py +++ b/src/talemate/scene_assets.py @@ -3,7 +3,7 @@ from __future__ import annotations import base64 import hashlib import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import pydantic @@ -198,7 +198,7 @@ class SceneAssets: """ asset_path = self.asset_path(asset_id) - + if not asset_path: log.debug("asset_path not found", asset_id=asset_id) return None diff --git a/src/talemate/scene_message.py b/src/talemate/scene_message.py index f8c31027..40132655 100644 --- a/src/talemate/scene_message.py +++ b/src/talemate/scene_message.py @@ -54,7 +54,7 @@ class SceneMessage: # the source of the message (e.g. "ai", "progress_story", "director") source: str = "" - + meta: dict | None = None flags: Flags = Flags.NONE @@ -84,10 +84,10 @@ class SceneMessage: "source": self.source, "flags": int(self.flags), } - + if self.meta: rv["meta"] = self.meta - + return rv def __iter__(self): @@ -113,7 +113,7 @@ class SceneMessage: @property def hidden(self): return self.flags & Flags.HIDDEN - + @property def fingerprint(self) -> str: """ @@ -124,11 +124,11 @@ class SceneMessage: @property def source_agent(self) -> str | None: return (self.meta or {}).get("agent", None) - + @property def source_function(self) -> str | None: return (self.meta or {}).get("function", None) - + @property def source_arguments(self) -> dict: return (self.meta or {}).get("arguments", {}) @@ -147,19 +147,20 @@ class SceneMessage: if format == "movie_script": return self.message.rstrip("\n") + "\n" return self.message - + def set_source(self, agent: str, function: str, **kwargs): if not self.meta: self.meta = {} self.meta["agent"] = agent self.meta["function"] = function self.meta["arguments"] = kwargs - + def set_meta(self, **kwargs): if not self.meta: self.meta = {} self.meta.update(kwargs) + @dataclass class CharacterMessage(SceneMessage): typ = "character" @@ -180,7 +181,7 @@ class CharacterMessage(SceneMessage): @property def raw(self): return self.message.split(":", 1)[1].replace('"', "").replace("*", "").strip() - + @property def without_name(self) -> str: return self.message.split(":", 1)[1] @@ -198,7 +199,10 @@ class CharacterMessage(SceneMessage): try: message = self.message.split(":", 1)[1].strip() except IndexError: - log.warning("character_message_as_movie_script failed to parse correct format", msg=self.message) + log.warning( + "character_message_as_movie_script failed to parse correct format", + msg=self.message, + ) message = self.message return f"\n{self.character_name.upper()}\n{message}\nEND-OF-LINE\n" @@ -222,7 +226,6 @@ class NarratorMessage(SceneMessage): source: str = "ai" typ = "narrator" - def source_to_meta(self) -> dict: source = self.source action_name, *args = source.split(":") @@ -247,11 +250,7 @@ class NarratorMessage(SceneMessage): elif action_name == "narrate_after_dialogue": parameters["character"] = args[0] - return { - "agent": "narrator", - "function": action_name, - "arguments": parameters - } + return {"agent": "narrator", "function": action_name, "arguments": parameters} def migrate_source_to_meta(self): if self.source and not self.meta: @@ -262,6 +261,7 @@ class NarratorMessage(SceneMessage): return self + @dataclass class DirectorMessage(SceneMessage): action: str = "actor_instruction" @@ -278,7 +278,6 @@ class DirectorMessage(SceneMessage): @property def as_inner_monologue(self): - # instructions may be written referencing the character as you, your etc., # so we need to replace those to fit a first person perspective @@ -302,29 +301,28 @@ class DirectorMessage(SceneMessage): @property def as_story_progression(self): return f"{self.character_name}'s next action: {self.instructions}" - + @property def as_director_action(self) -> str: if not self.character_name: return f"{self.message}\n{self.action}" - #Become aggressive towards Elmer as you no longer recognize the man. + # Become aggressive towards Elmer as you no longer recognize the man. def migrate_message_to_meta(self): if self.message.startswith("Director instructs"): parts = self.message.split(":", 1) character_name = parts[0].replace("Director instructs ", "").strip() instructions = parts[1].strip() - + self.set_source( - "director", - "actor_instruction", + "director", + "actor_instruction", character=character_name, ) self.message = instructions self.source = "player" - + return self - def __dict__(self) -> dict: rv = super().__dict__() @@ -341,10 +339,9 @@ class DirectorMessage(SceneMessage): return self.as_format("chat") def as_format(self, format: str, **kwargs) -> str: - if not self.instructions.strip(): return "" - + mode = kwargs.get("mode", "direction") if format == "movie_script": if mode == "internal_monologue": @@ -369,6 +366,7 @@ class TimePassageMessage(SceneMessage): rv["ts"] = self.ts return rv + @dataclass class ReinforcementMessage(SceneMessage): typ = "reinforcement" @@ -383,9 +381,7 @@ class ReinforcementMessage(SceneMessage): return self.source_arguments.get("question", "question") def __str__(self): - return ( - f"# Internal note for {self.character_name} - {self.question}\n{self.message}" - ) + return f"# Internal note for {self.character_name} - {self.question}\n{self.message}" def as_format(self, format: str, **kwargs) -> str: if format == "movie_script": @@ -398,7 +394,9 @@ class ReinforcementMessage(SceneMessage): try: self.source_to_meta() except Exception as e: - log.warning("migrate_reinforcement_source_to_meta", error=e, msg=self.id) + log.warning( + "migrate_reinforcement_source_to_meta", error=e, msg=self.id + ) return self @@ -408,6 +406,7 @@ class ReinforcementMessage(SceneMessage): parameters = {"character": args[1], "question": args[0]} self.set_source("world_state", "update_reinforcement", **parameters) + @dataclass class ContextInvestigationMessage(SceneMessage): typ = "context_investigation" @@ -417,25 +416,25 @@ class ContextInvestigationMessage(SceneMessage): @property def character(self) -> str: return self.source_arguments.get("character", "character") - + @property def query(self) -> str: return self.source_arguments.get("query", "query") - + @property def title(self) -> str: """ The title will differ based on sub_type - + Current sub_types: - + - visual-character - visual-scene - query - + A natural language title will be generated based on the sub_type """ - + if self.sub_type == "visual-character": return f"Visual description of {self.character} in the current moment" elif self.sub_type == "visual-scene": @@ -443,17 +442,15 @@ class ContextInvestigationMessage(SceneMessage): elif self.sub_type == "query": return f"Query: {self.query}" return "Internal note" - + def __str__(self): - return ( - f"# {self.title}: {self.message}" - ) + return f"# {self.title}: {self.message}" def __dict__(self) -> dict: rv = super().__dict__() rv["sub_type"] = self.sub_type return rv - + def as_format(self, format: str, **kwargs) -> str: if format == "movie_script": message = str(self)[2:] diff --git a/src/talemate/server/api.py b/src/talemate/server/api.py index 8123164c..66f85c88 100644 --- a/src/talemate/server/api.py +++ b/src/talemate/server/api.py @@ -1,6 +1,5 @@ import asyncio import json -import os import traceback import starlette.websockets @@ -19,6 +18,7 @@ from talemate.game.engine.nodes.registry import import_initial_node_definitions log = structlog.get_logger("talemate") + async def websocket_endpoint(websocket): # Create a queue for outgoing messages message_queue = asyncio.Queue() @@ -26,13 +26,13 @@ async def websocket_endpoint(websocket): scene_task = None log.info("frontend connected") - + import_initial_node_definitions() async def frontend_disconnect(exc): nonlocal scene_task log.warning(f"frontend disconnected: {exc}") - + main_task.cancel() send_messages_task.cancel() send_status_task.cancel() @@ -56,7 +56,6 @@ async def websocket_endpoint(websocket): message = await message_queue.get() await websocket.send(json.dumps(message, cls=JSONEncoder)) - # Create a task to send regular client status updates async def send_status(): while True: @@ -77,7 +76,6 @@ async def websocket_endpoint(websocket): ) await asyncio.sleep(15) - # task to test connection async def test_connection(): while True: @@ -86,9 +84,7 @@ async def websocket_endpoint(websocket): except Exception as e: await frontend_disconnect(e) await asyncio.sleep(1) - - - + # main loop task async def handle_messages(): nonlocal scene_task @@ -187,10 +183,10 @@ async def websocket_endpoint(websocket): handler.scene.interrupt() elif action_type == "request_app_config": log.info("request_app_config") - + config = load_config() config.update(system_prompt_defaults=SYSTEM_PROMPTS_CACHE) - + await message_queue.put( { "type": "app_config", @@ -201,7 +197,6 @@ async def websocket_endpoint(websocket): else: log.info("Routing to sub-handler", action_type=action_type) await handler.route(data) - # handle disconnects except ( @@ -211,11 +206,16 @@ async def websocket_endpoint(websocket): ) as exc: await frontend_disconnect(exc) - main_task = asyncio.create_task(handle_messages()) send_messages_task = asyncio.create_task(send_messages()) send_status_task = asyncio.create_task(send_status()) send_client_bootstraps_task = asyncio.create_task(send_client_bootstraps()) test_connection_task = asyncio.create_task(test_connection()) - - await asyncio.gather(main_task, send_messages_task, send_status_task, send_client_bootstraps_task, test_connection_task) \ No newline at end of file + + await asyncio.gather( + main_task, + send_messages_task, + send_status_task, + send_client_bootstraps_task, + test_connection_task, + ) diff --git a/src/talemate/server/assistant.py b/src/talemate/server/assistant.py index b40a063a..1abd187d 100644 --- a/src/talemate/server/assistant.py +++ b/src/talemate/server/assistant.py @@ -13,6 +13,7 @@ class ForkScenePayload(pydantic.BaseModel): message_id: int save_name: str | None = None + class AssistantPlugin: router = "assistant" @@ -36,14 +37,15 @@ class AssistantPlugin: async def handle_contextual_generate(self, data: dict): payload = ContentGenerationContext(**data) creator = get_agent("creator") - + if payload.computed_context[0] == "acting_instructions": content = await creator.determine_character_dialogue_instructions( - self.scene.get_character(payload.character), instructions=payload.instructions + self.scene.get_character(payload.character), + instructions=payload.instructions, ) else: content = await creator.contextual_generate(payload) - + self.websocket_handler.queue_put( { "type": self.router, @@ -63,7 +65,6 @@ class AssistantPlugin: context_type, context_name = data.computed_context if context_type == "dialogue": - if not data.character: character = self.scene.get_player_character() else: @@ -87,7 +88,9 @@ class AssistantPlugin: data.length = 35 log.info("Running autocomplete for contextual generation", args=data) completion = await creator.contextual_generate(data) - log.info("Autocomplete for contextual generation complete", completion=completion) + log.info( + "Autocomplete for contextual generation complete", completion=completion + ) completion = ( completion.replace(f"{context_name}: {data.partial}", "") .lstrip(".") @@ -95,25 +98,24 @@ class AssistantPlugin: ) emit("autocomplete_suggestion", completion) - except Exception as e: + except Exception: log.error("Error running autocomplete", error=traceback.format_exc()) emit("autocomplete_suggestion", "") - async def handle_fork_new_scene(self, data: dict): """ Allows to fork a new scene from a specific message in the current scene. - + All content after the message will be removed and the context database will be re imported ensuring a clean state. - + All state reinforcements will be reset to their most recent state before the message. """ - + payload = ForkScenePayload(**data) - + creator = get_agent("creator") - - await creator.fork_scene(payload.message_id, payload.save_name) \ No newline at end of file + + await creator.fork_scene(payload.message_id, payload.save_name) diff --git a/src/talemate/server/character_importer.py b/src/talemate/server/character_importer.py index ead86b4c..f39684a9 100644 --- a/src/talemate/server/character_importer.py +++ b/src/talemate/server/character_importer.py @@ -1,6 +1,5 @@ import asyncio import json -import os import pydantic import structlog diff --git a/src/talemate/server/config.py b/src/talemate/server/config.py index bb5dd8d9..d368c2a5 100644 --- a/src/talemate/server/config.py +++ b/src/talemate/server/config.py @@ -41,6 +41,7 @@ class ToggleClientPayload(pydantic.BaseModel): class DeleteScenePayload(pydantic.BaseModel): path: str + class ConfigPlugin: router = "config" @@ -220,7 +221,6 @@ class ConfigPlugin: await emit_clients_status() - async def handle_remove_scene_from_recents(self, data): payload = DeleteScenePayload(**data) @@ -243,11 +243,11 @@ class ConfigPlugin: }, } ) - + self.websocket_handler.queue_put( {"type": "app_config", "data": load_config(), "version": VERSION} ) - + async def handle_delete_scene(self, data): payload = DeleteScenePayload(**data) @@ -271,4 +271,4 @@ class ConfigPlugin: self.websocket_handler.queue_put( {"type": "app_config", "data": load_config(), "version": VERSION} - ) \ No newline at end of file + ) diff --git a/src/talemate/server/devtools.py b/src/talemate/server/devtools.py index e94880ae..3bc17ffc 100644 --- a/src/talemate/server/devtools.py +++ b/src/talemate/server/devtools.py @@ -1,7 +1,5 @@ import pydantic import structlog -from talemate.game.state import GameState -from talemate.world_state import WorldState from talemate.scene.state_editor import SceneStateEditor from talemate.scene.schema import SceneState from talemate.server.websocket_plugin import Plugin @@ -16,9 +14,10 @@ class TestPromptPayload(pydantic.BaseModel): client_name: str kind: str + class SetSceneStatePayload(pydantic.BaseModel): state: SceneState - + def ensure_number(v): """ @@ -70,39 +69,31 @@ class DevToolsPlugin(Plugin): }, } ) - + async def handle_get_scene_state(self, data): scene = self.scene editor = SceneStateEditor(scene) state = editor.dump() - + self.websocket_handler.queue_put( - { - "type": "devtools", - "action": "scene_state", - "data": state - } + {"type": "devtools", "action": "scene_state", "data": state} ) - + async def handle_update_scene_state(self, data): scene = self.scene editor = SceneStateEditor(scene) - + try: payload = SetSceneStatePayload(**data) editor.load(payload.model_dump().get("state")) except Exception as exc: await self.signal_operation_failed(str(exc)) return - + emit("status", message="Scene state updated", status="success") - + self.websocket_handler.queue_put( - { - "type": "devtools", - "action": "scene_state_updated", - "data": editor.dump() - } + {"type": "devtools", "action": "scene_state_updated", "data": editor.dump()} ) - - await self.signal_operation_done() \ No newline at end of file + + await self.signal_operation_done() diff --git a/src/talemate/server/node_editor.py b/src/talemate/server/node_editor.py index a0803950..02248f4c 100644 --- a/src/talemate/server/node_editor.py +++ b/src/talemate/server/node_editor.py @@ -1,51 +1,62 @@ import pydantic import structlog import os -import json import asyncio from functools import wraps -from talemate.context import interaction - - -from talemate.game.engine.nodes.core import Graph, Loop, GraphState, Listen, graph_state, PASSTHROUGH_ERRORS, dynamic_node_import, load_extended_components +from talemate.game.engine.nodes.core import ( + GraphState, + PASSTHROUGH_ERRORS, +) from talemate.game.engine.nodes.scene import SceneLoop from talemate.game.engine.nodes.base_types import BASE_TYPES from talemate.game.engine.nodes.registry import ( - export_node_definitions, - import_node_definition, - normalize_registry_name, - get_node, + export_node_definitions, + import_node_definition, + normalize_registry_name, validate_registry_path, ) -from talemate.game.engine.nodes.layout import normalize_node_filename, export_flat_graph, import_flat_graph, load_graph, list_node_files, PathInfo +from talemate.game.engine.nodes.layout import ( + normalize_node_filename, + export_flat_graph, + import_flat_graph, + load_graph, + list_node_files, + PathInfo, +) from talemate.game.engine.nodes.run import BreakpointEvent from talemate.save import save_node_module import talemate.emit.async_signals as signals from .websocket_plugin import Plugin + log = structlog.get_logger("talemate.server.node_editor") TALEMATE_BASE_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "..") + def requires_creative_environment(fn): @wraps(fn) async def wrapper(self, data): if self.scene.environment != "creative": - return await self.signal_operation_failed("Cannot update scene nodes in non-creative environment") - + return await self.signal_operation_failed( + "Cannot update scene nodes in non-creative environment" + ) + return await fn(self, data) - + return wrapper class RequestNodeLibrary(pydantic.BaseModel): - pass + pass + class RequestNodeModule(pydantic.BaseModel): path: str - + + class RequestCreateNodeModule(pydantic.BaseModel): name: str registry: str @@ -54,34 +65,38 @@ class RequestCreateNodeModule(pydantic.BaseModel): module_type: str = "graph" nodes: dict | None = None + class ExportNodeModule(pydantic.BaseModel): node_definitions: dict graph: dict path_info: PathInfo + class RequestUpdateNodeModule(pydantic.BaseModel): path: str graph: dict set_as_main: bool = False - + + class RequestDeleteNodeModule(pydantic.BaseModel): path: str - + + class RequestTestRun(pydantic.BaseModel): graph: dict - + class NodeEditorPlugin(Plugin): router = "node_editor" - + def connect(self): signals.get("nodes_node_state").connect(self.handle_node_state) signals.get("nodes_breakpoint").connect(self.handle_breakpoint) - + def disconnect(self): signals.get("nodes_node_state").disconnect(self.handle_node_state) signals.get("nodes_breakpoint").disconnect(self.handle_breakpoint) - + async def handle_node_state(self, state: GraphState): self.websocket_handler.queue_put( { @@ -90,15 +105,13 @@ class NodeEditorPlugin(Plugin): "data": state.flattened, } ) - - async def handle_breakpoint(self, breakpoint:BreakpointEvent): - - + + async def handle_breakpoint(self, breakpoint: BreakpointEvent): absolute_module_path = breakpoint.module_path base_dir_path = os.path.abspath(TALEMATE_BASE_DIR) - + relative_module_path = os.path.relpath(absolute_module_path, base_dir_path) - + self.websocket_handler.queue_put( { "type": self.router, @@ -106,36 +119,35 @@ class NodeEditorPlugin(Plugin): "data": { "node": breakpoint.node.model_dump(), "module_path": relative_module_path, - } + }, } ) - + async def handle_request_node_library(self, data: dict): - files = list_node_files(search_paths=[self.scene.nodes_dir]) scene = self.scene - + # Define scene loop path scene_dir = scene.save_dir - + # Add separator to scene_dir for proper directory matching if not scene_dir.endswith(os.path.sep): scene_dir_prefix = scene_dir + os.path.sep else: scene_dir_prefix = scene_dir - + # Define sorting key function def sort_key(file_path): filename = os.path.basename(file_path) - + if file_path == scene_dir or file_path.startswith(scene_dir_prefix): return (1, filename.lower()) else: return (2, filename.lower()) - + # Apply sorting files = sorted(files, key=sort_key) - + self.websocket_handler.queue_put( { "type": self.router, @@ -143,120 +155,135 @@ class NodeEditorPlugin(Plugin): "data": files, } ) - + async def handle_create_mode_module(self, data: dict): - request = RequestCreateNodeModule(**data) - + log.debug("creating node", request=request) - + filename = normalize_node_filename(request.name) - + registry = request.registry.replace("$N", normalize_registry_name(request.name)) node_definitions = export_node_definitions() try: validate_registry_path(registry, node_definitions) except ValueError as e: return await self.signal_operation_failed(str(e)) - + if request.nodes and not request.copy_from and not request.extend_from: # Create a module from selected nodes - + if registry in node_definitions["nodes"]: - return await self.signal_operation_failed(f"Cannot create blank module at existing path: {registry}. If you intend to override it, create it as a copy.") - + return await self.signal_operation_failed( + f"Cannot create blank module at existing path: {registry}. If you intend to override it, create it as a copy." + ) + title = request.name.title() - + new_graph_cls = await self._create_module_from_nodes( self.scene, request.nodes, title, registry, request.module_type, - filename + filename, ) - + if not new_graph_cls: return await self.signal_operation_failed("Invalid module type") - + elif not request.copy_from and not request.extend_from: # create a new module from scratch - + if registry in node_definitions["nodes"]: - return await self.signal_operation_failed(f"Cannot create new module at existing path: {registry}. If you intend to override it, create it as a copy.") - + return await self.signal_operation_failed( + f"Cannot create new module at existing path: {registry}. If you intend to override it, create it as a copy." + ) + graph_cls = BASE_TYPES.get(request.module_type) if not graph_cls: return await self.signal_operation_failed("Invalid module type") - + title = request.name.title() - + graph_def = graph_cls(title=title).model_dump() graph_def["registry"] = registry - - new_graph_cls = import_node_definition(graph_def, self.scene._NODE_DEFINITIONS) + + new_graph_cls = import_node_definition( + graph_def, self.scene._NODE_DEFINITIONS + ) new_graph_cls._module_path = await save_node_module( self.scene, new_graph_cls(), filename, set_as_main=False ) - + elif request.extend_from: # extend from a node module (inheritance) - + extend_from = request.extend_from - extend_graph, _ = load_graph(extend_from, search_paths=[self.scene.nodes_dir]) - + extend_graph, _ = load_graph( + extend_from, search_paths=[self.scene.nodes_dir] + ) + if not extend_graph: - return await self.signal_operation_failed("Cannot extend from non-existent node") - + return await self.signal_operation_failed( + "Cannot extend from non-existent node" + ) + base_type = extend_graph.base_type - + graph_cls = BASE_TYPES.get(base_type) - + if not graph_cls: return await self.signal_operation_failed("Invalid module type") - - graph_def = graph_cls(title=request.name.title(), extends=extend_from).model_dump() + + graph_def = graph_cls( + title=request.name.title(), extends=extend_from + ).model_dump() graph_def["registry"] = registry - - new_graph_cls = import_node_definition(graph_def, self.scene._NODE_DEFINITIONS) - + + new_graph_cls = import_node_definition( + graph_def, self.scene._NODE_DEFINITIONS + ) + new_graph_cls._module_path = await save_node_module( self.scene, new_graph_cls(), filename, set_as_main=False ) - + elif request.copy_from: # copy from a node module - + copy_from = request.copy_from - + graph, _ = load_graph(copy_from, search_paths=[self.scene.nodes_dir]) graph.title = request.name.title() - + graph_def = graph.model_dump() graph_def["registry"] = registry - - new_graph_cls = import_node_definition(graph_def, self.scene._NODE_DEFINITIONS) - + + new_graph_cls = import_node_definition( + graph_def, self.scene._NODE_DEFINITIONS + ) + # if the scene NODE_DEFINITIONS does not currently have scene/SceneLoop base # type module in it and the incoming graph is a scene/SceneLoop, then set_as_main # to True set_as_main = False - + if isinstance(graph, SceneLoop): set_as_main = True for scene_node in self.scene._NODE_DEFINITIONS.values(): if scene_node.base_type == "scene/SceneLoop": set_as_main = False break - + new_graph_cls._registry = registry graph.registry = registry - + new_graph_cls._module_path = await save_node_module( self.scene, graph, filename, set_as_main=set_as_main ) - + self.websocket_handler.queue_put( { "type": self.router, @@ -264,13 +291,15 @@ class NodeEditorPlugin(Plugin): "data": filename, } ) - + await self.handle_request_node_module({"path": filename}) - - async def _create_module_from_nodes(self, scene, nodes_data, title, registry, module_type, filename): + + async def _create_module_from_nodes( + self, scene, nodes_data, title, registry, module_type, filename + ): """ Create a new module using selected nodes from the editor. - + Args: scene: The current scene context nodes_data: JSON data from convertSelectedGraphToJSON including nodes and connections @@ -278,24 +307,24 @@ class NodeEditorPlugin(Plugin): registry: Registry path for the new module module_type: Type of module to create (e.g., "graph", "command/Command") filename: Filename for the new module - + Returns: The created module class """ graph_cls = BASE_TYPES.get(module_type) if not graph_cls: return None - + # Create basic graph definition graph_def = graph_cls(title=title).model_dump() graph_def["registry"] = registry - + # Import the node definition to create the class new_graph_cls = import_node_definition(graph_def, scene._NODE_DEFINITIONS) - + # Create an instance of the graph new_graph = new_graph_cls() - + # Create a flat data structure compatible with import_flat_graph flat_data = { "nodes": nodes_data.get("nodes", []), @@ -306,35 +335,34 @@ class NodeEditorPlugin(Plugin): "registry": registry, "base_type": module_type, "title": title, - "extends": None + "extends": None, } - + # Use import_flat_graph to properly create the graph populated_graph = import_flat_graph(flat_data, new_graph) - + # Save the module new_graph_cls._module_path = await save_node_module( scene, populated_graph, filename, set_as_main=False ) - - return new_graph_cls - + + return new_graph_cls + async def handle_request_node_module(self, data: dict): - request = RequestNodeModule(**data) - + graph, path_info = load_graph(request.path, search_paths=[self.scene.nodes_dir]) - + export_nodes = ExportNodeModule( - graph = export_flat_graph(graph), - node_definitions = export_node_definitions(), - path_info = path_info, + graph=export_flat_graph(graph), + node_definitions=export_node_definitions(), + path_info=path_info, ) - - #with open("exported_nodes.json", "w") as file: + + # with open("exported_nodes.json", "w") as file: # import json # json.dump(export_nodes, file, indent=2) - + self.websocket_handler.queue_put( { "type": self.router, @@ -342,36 +370,37 @@ class NodeEditorPlugin(Plugin): "data": export_nodes.model_dump(), } ) - + try: await self.handle_node_state(self.scene.nodegraph_state) except AttributeError: pass - + @requires_creative_environment async def handle_update_node_module(self, data: dict): import_nodes = RequestUpdateNodeModule(**data) graph = import_flat_graph(import_nodes.graph) - - + # ensure absolute path base_dir = os.path.abspath(TALEMATE_BASE_DIR) if not import_nodes.path.startswith(base_dir): import_nodes.path = os.path.join(base_dir, import_nodes.path) - + log.debug("updating nodes", path=import_nodes.path) - + if graph.registry: - node_cls = import_node_definition(graph.model_dump(), self.scene._NODE_DEFINITIONS, reimport=True) + node_cls = import_node_definition( + graph.model_dump(), self.scene._NODE_DEFINITIONS, reimport=True + ) node_cls._module_path = import_nodes.path - + await save_node_module(self.scene, graph, import_nodes.path) - + if graph.base_type == "scene/SceneLoop": if import_nodes.set_as_main: self.scene.nodes_filename = import_nodes.path self.scene.saved = False - + self.websocket_handler.queue_put( { "type": self.router, @@ -379,33 +408,34 @@ class NodeEditorPlugin(Plugin): "data": {}, } ) - + @requires_creative_environment async def handle_delete_node_module(self, data: dict): request = RequestDeleteNodeModule(**data) - + # only scene nodes can be deleted # check by checking against the scene's save_dir property path = os.path.join(TALEMATE_BASE_DIR, request.path) - + path = os.path.abspath(path) - + log.debug("deleting", path=path, reqest=request) - + if not path.startswith(self.scene.save_dir): - return await self.signal_operation_failed("Cannot delete node module outside of scene directory") - + return await self.signal_operation_failed( + "Cannot delete node module outside of scene directory" + ) + try: os.remove(path) except FileNotFoundError: pass - + for scene_node in list(self.scene._NODE_DEFINITIONS.values()): if scene_node._module_path == path: self.scene._NODE_DEFINITIONS.pop(scene_node._registry, None) break - - + self.websocket_handler.queue_put( { "type": self.router, @@ -413,22 +443,21 @@ class NodeEditorPlugin(Plugin): "path": request.path, } ) - + await self.handle_request_node_library({}) - + @requires_creative_environment async def handle_test_run(self, data: dict): """ Loads a graph from json and runs it. """ - - scene = self.scene + payload = RequestTestRun(**data) - + graph = import_flat_graph(payload.graph) - + await self._start_test_with_graph(graph) - + self.websocket_handler.queue_put( { "type": self.router, @@ -436,19 +465,19 @@ class NodeEditorPlugin(Plugin): "data": payload.model_dump(), } ) - + @requires_creative_environment async def handle_test_run_scene_loop(self, data: dict): """ Loads the scene's main loop and runs it. """ scene = self.scene - + # Load the scene's main loop graph, _ = load_graph(scene.nodes_filename, search_paths=[scene.nodes_dir]) - + await self._start_test_with_graph(graph) - + self.websocket_handler.queue_put( { "type": self.router, @@ -456,16 +485,16 @@ class NodeEditorPlugin(Plugin): "data": {}, } ) - + async def _start_test_with_graph(self, graph): """ Common logic for starting a test with a loaded graph """ - active_graph_state:GraphState = self.scene.nodegraph_state - - async def on_error(state:GraphState, error:Exception): + active_graph_state: GraphState = self.scene.nodegraph_state + + async def on_error(state: GraphState, error: Exception): if isinstance(error, PASSTHROUGH_ERRORS): - return + return self.websocket_handler.queue_put( { "type": self.router, @@ -475,10 +504,10 @@ class NodeEditorPlugin(Plugin): }, } ) - + await self.handle_test_stop({}) - - async def on_done(state:GraphState): + + async def on_done(state: GraphState): self.websocket_handler.queue_put( { "type": self.router, @@ -486,29 +515,27 @@ class NodeEditorPlugin(Plugin): "data": {}, } ) - + await self.handle_test_stop({}) - + graph.callbacks.append(on_done) graph.error_handlers.append(on_error) - + if isinstance(graph, SceneLoop): graph.properties["trigger_game_loop"] = True - + active_graph_state.shared["__test_module"] = graph - + @requires_creative_environment async def handle_test_restart(self, data: dict): await self._stop_test() await asyncio.sleep(1) await self.handle_test_run(data) - @requires_creative_environment async def handle_test_stop(self, data: dict): - await self._stop_test() - + self.websocket_handler.queue_put( { "type": self.router, @@ -516,10 +543,10 @@ class NodeEditorPlugin(Plugin): "data": {}, } ) - + @requires_creative_environment async def handle_release_breakpoint(self, data: dict): - active_graph_state:GraphState = self.scene.nodegraph_state + active_graph_state: GraphState = self.scene.nodegraph_state active_graph_state.shared.pop("__breakpoint", None) self.websocket_handler.queue_put( { @@ -530,12 +557,12 @@ class NodeEditorPlugin(Plugin): ) async def _stop_test(self): - active_graph_state:GraphState = self.scene.nodegraph_state + active_graph_state: GraphState = self.scene.nodegraph_state module = active_graph_state.shared.pop("__test_module", None) - + if not module: return - + task = active_graph_state.shared.pop(f"__run_{module.id}", None) if task: - task.cancel() \ No newline at end of file + task.cancel() diff --git a/src/talemate/server/package_manager.py b/src/talemate/server/package_manager.py index 361f1ca9..b6c8ffc9 100644 --- a/src/talemate/server/package_manager.py +++ b/src/talemate/server/package_manager.py @@ -6,7 +6,6 @@ from .websocket_plugin import Plugin from talemate.game.engine.nodes.packaging import ( list_packages, apply_scene_package_info, - get_scene_package_info, install_package, get_package_by_registry, uninstall_package, @@ -16,12 +15,15 @@ from talemate.game.engine.nodes.packaging import ( log = structlog.get_logger(__name__) + class InstallPackageRequest(pydantic.BaseModel): package_registry: str + class UninstallPackageRequest(pydantic.BaseModel): package_registry: str + class SavePackagePropertiesRequest(pydantic.BaseModel): package_registry: str package_properties: dict[str, PackageProperty] @@ -29,47 +31,46 @@ class SavePackagePropertiesRequest(pydantic.BaseModel): class PackageManagerPlugin(Plugin): router = "package_manager" - + def connect(self): pass - + def disconnect(self): pass - + async def handle_request_package_list(self, data: dict): packages = await list_packages() - + scene = self.scene - + await apply_scene_package_info(scene, packages) - - + # sort by stats, then name (installed first) packages.sort(key=lambda x: (x.status == "installed", x.name)) - + self.websocket_handler.queue_put( { "type": self.router, "action": "package_list", - "data": [ - package.model_dump() for package in packages - ], + "data": [package.model_dump() for package in packages], } ) - + async def handle_install_package(self, data: dict): request = InstallPackageRequest(**data) - + scene = self.scene - + package = await get_package_by_registry(request.package_registry) - + if not package: - await self.signal_operation_failed(f"Package with registry {request.package_registry} not found") + await self.signal_operation_failed( + f"Package with registry {request.package_registry} not found" + ) return - + await install_package(scene, package) - + self.websocket_handler.queue_put( { "type": self.router, @@ -79,18 +80,18 @@ class PackageManagerPlugin(Plugin): }, } ) - + await self.handle_request_package_list(data) - + async def handle_uninstall_package(self, data: dict): request = UninstallPackageRequest(**data) - + scene = self.scene - + package = await get_package_by_registry(request.package_registry) - + await uninstall_package(scene, request.package_registry) - + self.websocket_handler.queue_put( { "type": self.router, @@ -100,18 +101,22 @@ class PackageManagerPlugin(Plugin): }, } ) - + await self.handle_request_package_list(data) - + async def handle_save_package_properties(self, data: dict): request = SavePackagePropertiesRequest(**data) - + scene = self.scene - - package = await update_package_properties(scene, request.package_registry, request.package_properties) - + + package = await update_package_properties( + scene, request.package_registry, request.package_properties + ) + if not package: - await self.signal_operation_failed(f"Package with registry {request.package_registry} not found") + await self.signal_operation_failed( + f"Package with registry {request.package_registry} not found" + ) return self.websocket_handler.queue_put( @@ -124,4 +129,4 @@ class PackageManagerPlugin(Plugin): } ) - await self.handle_request_package_list(data) \ No newline at end of file + await self.handle_request_package_list(data) diff --git a/src/talemate/server/quick_settings.py b/src/talemate/server/quick_settings.py index c1616933..10784fda 100644 --- a/src/talemate/server/quick_settings.py +++ b/src/talemate/server/quick_settings.py @@ -1,10 +1,9 @@ -import uuid -from typing import Any, Union +from typing import Any import pydantic import structlog -from talemate.config import load_config, save_config +from talemate.config import save_config log = structlog.get_logger("talemate.server.quick_settings") diff --git a/src/talemate/server/run.py b/src/talemate/server/run.py index ca5274e9..712d7498 100644 --- a/src/talemate/server/run.py +++ b/src/talemate/server/run.py @@ -11,9 +11,7 @@ import signal import sys import websockets -import re -import talemate.config from talemate.server.api import websocket_endpoint from talemate.version import VERSION @@ -38,21 +36,23 @@ STARTUP_TEXT = f""" v{VERSION} """ + async def install_punkt(): import nltk - + log.info("Downloading NLTK punkt tokenizer") await asyncio.get_event_loop().run_in_executor(None, nltk.download, "punkt") await asyncio.get_event_loop().run_in_executor(None, nltk.download, "punkt_tab") log.info("Download complete") + async def log_stream(stream, log_func): while True: line = await stream.readline() if not line: break decoded_line = line.decode().strip() - + # Check if the original line started with "INFO:" (Uvicorn startup messages) if decoded_line.startswith("INFO:"): # Use info level for Uvicorn startup messages @@ -61,6 +61,7 @@ async def log_stream(stream, log_func): # Use the provided log_func for other messages log_func("uvicorn", message=decoded_line) + async def run_frontend(host: str = "localhost", port: int = 8080): if sys.platform == "win32": activate_cmd = ".\\.venv\\Scripts\\activate.bat" @@ -68,23 +69,28 @@ async def run_frontend(host: str = "localhost", port: int = 8080): else: frontend_cmd = f"/bin/bash -c 'source .venv/bin/activate && uvicorn --host {host} --port {port} frontend_wsgi:application'" frontend_cwd = None - + process = await asyncio.create_subprocess_shell( frontend_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, cwd=frontend_cwd, shell=True, - preexec_fn=os.setsid if sys.platform != "win32" else None + preexec_fn=os.setsid if sys.platform != "win32" else None, ) - - - log.info(f"talemate frontend started", host=host, port=port, server="uvicorn", process=process.pid) - + + log.info( + "talemate frontend started", + host=host, + port=port, + server="uvicorn", + process=process.pid, + ) + try: stdout_task = asyncio.create_task(log_stream(process.stdout, log.info)) stderr_task = asyncio.create_task(log_stream(process.stderr, log.error)) - + await asyncio.gather(stdout_task, stderr_task) await process.wait() finally: @@ -95,11 +101,13 @@ async def run_frontend(host: str = "localhost", port: int = 8080): os.killpg(os.getpgid(process.pid), signal.SIGTERM) await process.wait() + async def cancel_all_tasks(loop): tasks = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task()] [task.cancel() for task in tasks] await asyncio.gather(*tasks, return_exceptions=True) + def run_server(args): """ Run the talemate web server using the provided arguments. @@ -115,19 +123,18 @@ def run_server(args): from talemate.prompts.overrides import get_template_overrides import talemate.client.system_prompts as system_prompts from talemate.emit.base import emit - + # import node libraries import talemate.game.engine.nodes.load_definitions - config = talemate.config.cleanup() if config.game.world_state.templates.state_reinforcement: Collection.create_from_legacy_config(config) - + # pre-cache system prompts system_prompts.cache_all() - + for agent_type in get_agent_types(): template_overrides = get_template_overrides(agent_type) for template_override in template_overrides: @@ -156,30 +163,32 @@ def run_server(args): websocket_endpoint, args.host, args.port, - max_size=2 ** 23, + max_size=2**23, ) # Start the websocket server and keep a reference so we can shut it down websocket_server = loop.run_until_complete(_start_websocket_server()) - + # start task to unstall punkt loop.create_task(install_punkt()) - + if not args.backend_only: - frontend_task = loop.create_task(run_frontend(args.frontend_host, args.frontend_port)) + frontend_task = loop.create_task( + run_frontend(args.frontend_host, args.frontend_port) + ) else: frontend_task = None log.info("talemate backend started", host=args.host, port=args.port) emit("talemate_started", data=config.model_dump()) - + try: loop.run_forever() except KeyboardInterrupt: pass finally: log.info("Shutting down...") - + try: if frontend_task: frontend_task.cancel() @@ -195,11 +204,14 @@ def run_server(args): except KeyboardInterrupt: # If the user hits Ctrl+C again during shutdown, exit quickly without # another traceback. - log.warning("Forced termination requested during shutdown - exiting immediately") + log.warning( + "Forced termination requested during shutdown - exiting immediately" + ) finally: loop.close() log.info("Shutdown complete") + def main(): parser = argparse.ArgumentParser(description="talemate server") subparser = parser.add_subparsers(dest="command") @@ -210,14 +222,20 @@ def main(): ) runserver_parser.add_argument("--host", default="localhost", help="Hostname") runserver_parser.add_argument("--port", type=int, default=6000, help="Port") - runserver_parser.add_argument("--backend-only", action="store_true", help="Run the backend only") + runserver_parser.add_argument( + "--backend-only", action="store_true", help="Run the backend only" + ) # frontend host and port - runserver_parser.add_argument("--frontend-host", default="localhost", help="Frontend Hostname") - runserver_parser.add_argument("--frontend-port", type=int, default=8080, help="Frontend Port") + runserver_parser.add_argument( + "--frontend-host", default="localhost", help="Frontend Hostname" + ) + runserver_parser.add_argument( + "--frontend-port", type=int, default=8080, help="Frontend Port" + ) args = parser.parse_args() - + # wipe screen if backend only mode is not enabled # reason: backend only is run usually in dev mode and may be worth keeping the console output if not args.backend_only: @@ -234,4 +252,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/talemate/server/websocket_plugin.py b/src/talemate/server/websocket_plugin.py index 0ec07c24..86cf5943 100644 --- a/src/talemate/server/websocket_plugin.py +++ b/src/talemate/server/websocket_plugin.py @@ -18,10 +18,10 @@ class Plugin: def __init__(self, websocket_handler): self.websocket_handler = websocket_handler self.connect() - + def connect(self): pass - + def disconnect(self): pass @@ -35,13 +35,14 @@ class Plugin: ) if emit_status: emit("status", message=message, status="error") - - async def signal_operation_done(self, signal_only: bool = False, allow_auto_save: bool = True): + async def signal_operation_done( + self, signal_only: bool = False, allow_auto_save: bool = True + ): self.websocket_handler.queue_put( {"type": self.router, "action": "operation_done", "data": {}} ) - + if signal_only: return @@ -51,11 +52,10 @@ class Plugin: self.scene.saved = False self.scene.emit_status() - async def handle(self, data: dict): log.info(f"{self.router} action", action=data.get("action")) fn = getattr(self, f"handle_{data.get('action')}", None) if fn is None: return - await fn(data) \ No newline at end of file + await fn(data) diff --git a/src/talemate/server/websocket_server.py b/src/talemate/server/websocket_server.py index 50e2d61d..f6de03a7 100644 --- a/src/talemate/server/websocket_server.py +++ b/src/talemate/server/websocket_server.py @@ -11,7 +11,7 @@ from talemate.client.base import ClientBase from talemate.client.registry import CLIENT_CLASSES from talemate.client.system_prompts import RENDER_CACHE as SYSTEM_PROMPTS_CACHE from talemate.config import SceneAssetUpload, load_config, save_config -from talemate.context import ActiveScene, active_scene +from talemate.context import ActiveScene from talemate.emit import Emission, Receiver, abort_wait_for_input, emit from talemate.files import list_scenes_directory from talemate.load import load_scene @@ -68,7 +68,9 @@ class WebsocketHandler(Receiver): ), devtools.DevToolsPlugin.router: devtools.DevToolsPlugin(self), node_editor.NodeEditorPlugin.router: node_editor.NodeEditorPlugin(self), - package_manager.PackageManagerPlugin.router: package_manager.PackageManagerPlugin(self), + package_manager.PackageManagerPlugin.router: package_manager.PackageManagerPlugin( + self + ), } # unconveniently named function, this `connect` method is called @@ -86,7 +88,6 @@ class WebsocketHandler(Receiver): # instance.emit_clients_status() def set_agent_routers(self): - for agent_type, agent in instance.AGENTS.items(): handler_cls = getattr(agent, "websocket_handler", None) if not handler_cls or handler_cls.router in self.routes: @@ -104,7 +105,7 @@ class WebsocketHandler(Receiver): memory_agent = instance.get_agent("memory") if memory_agent and self.scene: memory_agent.close_db(self.scene) - + for plugin in self.routes.values(): if hasattr(plugin, "disconnect"): plugin.disconnect() @@ -140,7 +141,7 @@ class WebsocketHandler(Receiver): for agent_typ, agent_config in self.agents.items(): try: client = self.llm_clients.get(agent_config.get("client"))["client"] - except TypeError as e: + except TypeError: client = None if not client or not client.enabled: @@ -217,13 +218,15 @@ class WebsocketHandler(Receiver): with ActiveScene(scene): try: scene = await load_scene( - scene, path_or_data, conversation_helper.agent.client, reset=reset + scene, + path_or_data, + conversation_helper.agent.client, + reset=reset, ) except MemoryAgentError as e: emit("status", message=str(e), status="error") log.error("load_scene", error=str(e)) return - self.scene = scene @@ -399,7 +402,7 @@ class WebsocketHandler(Receiver): **emission.kwargs, } ) - except Exception as e: + except Exception: log.error("emission passthrough", error=traceback.format_exc()) def handle_system(self, emission: Emission): @@ -483,15 +486,23 @@ class WebsocketHandler(Receiver): ), } ) - + def handle_context_investigation(self, emission: Emission): self.queue_put( { "type": "context_investigation", - "sub_type": emission.message_object.sub_type if emission.message_object else None, - "source_agent": emission.message_object.source_agent if emission.message_object else None, - "source_function": emission.message_object.source_function if emission.message_object else None, - "source_arguments": emission.message_object.source_arguments if emission.message_object else None, + "sub_type": emission.message_object.sub_type + if emission.message_object + else None, + "source_agent": emission.message_object.source_agent + if emission.message_object + else None, + "source_function": emission.message_object.source_function + if emission.message_object + else None, + "source_arguments": emission.message_object.source_arguments + if emission.message_object + else None, "message": emission.message, "id": emission.id, "flags": ( @@ -543,9 +554,8 @@ class WebsocketHandler(Receiver): ) def handle_config_saved(self, emission: Emission): - emission.data.update(system_prompt_defaults=SYSTEM_PROMPTS_CACHE) - + self.queue_put( { "type": "app_config", @@ -583,7 +593,6 @@ class WebsocketHandler(Receiver): "data": emission.data, "max_token_length": client.max_token_length if client else 8192, "api_url": getattr(client, "api_url", None) if client else None, - "api_url": getattr(client, "api_url", None) if client else None, "api_key": getattr(client, "api_key", None) if client else None, } ) @@ -671,9 +680,6 @@ class WebsocketHandler(Receiver): self.queue_put({"type": "processing_input"}) return - player_character = self.scene.get_player_character() - player_character_name = player_character.name if player_character else "" - self.queue_put( { "type": "processing_input", @@ -748,7 +754,7 @@ class WebsocketHandler(Receiver): "media_type": scene_assets.get_asset(asset_id).media_type, } ) - except Exception as exc: + except Exception: log.error("request_scene_assets", error=traceback.format_exc()) def request_assets(self, assets: list[dict]): @@ -763,7 +769,7 @@ class WebsocketHandler(Receiver): for asset_dict in assets: try: asset_id, asset = self._asset(**asset_dict) - except Exception as exc: + except Exception: log.error("request_assets", error=traceback.format_exc(), **asset_dict) continue _assets[asset_id] = asset @@ -828,7 +834,6 @@ class WebsocketHandler(Receiver): self.scene.delete_message(message_id) def edit_message(self, message_id, new_text): - message = self.scene.get_message(message_id) editor = instance.get_agent("editor") @@ -843,7 +848,6 @@ class WebsocketHandler(Receiver): self.scene.edit_message(message_id, new_text) def handle_character_card_upload(self, image_data_url: str, filename: str) -> str: - image_type = image_data_url.split(";")[0].split(":")[1] image_data = base64.b64decode(image_data_url.split(",")[1]) characters_path = os.path.join("./scenes", "characters") diff --git a/src/talemate/server/world_state_manager/__init__.py b/src/talemate/server/world_state_manager/__init__.py index 8ae2baa2..c0028450 100644 --- a/src/talemate/server/world_state_manager/__init__.py +++ b/src/talemate/server/world_state_manager/__init__.py @@ -178,15 +178,13 @@ class GenerateSuggestionPayload(pydantic.BaseModel): generation_options: world_state_templates.GenerationOptions | None = None instructions: str | None = None + class SuggestionPayload(pydantic.BaseModel): id: str proposal_uid: str | None = None -class WorldStateManagerPlugin( - SceneIntentMixin, - HistoryMixin, - Plugin -): + +class WorldStateManagerPlugin(SceneIntentMixin, HistoryMixin, Plugin): router = "world_state_manager" @property @@ -516,10 +514,11 @@ class WorldStateManagerPlugin( None, payload.question, payload.reset ) - _, reinforcement = ( - await self.world_state_manager.world_state.find_reinforcement( - payload.question, None - ) + ( + _, + reinforcement, + ) = await self.world_state_manager.world_state.find_reinforcement( + payload.question, None ) if not reinforcement: @@ -797,7 +796,6 @@ class WorldStateManagerPlugin( await self.signal_operation_done() async def handle_delete_template_group(self, data): - payload = DeleteWorldStateTemplateGroupPayload(**data) group = payload.group @@ -867,7 +865,7 @@ class WorldStateManagerPlugin( } ) - await self.scene.remove_character(character) + await self.scene.remove_character(character) await self.signal_operation_done() await self.handle_get_character_list({}) self.scene.emit_status() @@ -1009,7 +1007,7 @@ class WorldStateManagerPlugin( payload = SaveScenePayload(**data) log.debug("Save scene", copy=payload.save_as, project_name=payload.project_name) - + if not self.scene.filename: # scene has never been saved before # specify project name (directory name) @@ -1019,12 +1017,12 @@ class WorldStateManagerPlugin( self.scene.emit_status() # Suggestions - + async def handle_request_suggestions(self, data): """ Request current suggestions from the world state. """ - + world_state_dict = self.scene.world_state.model_dump() suggestions = world_state_dict.get("suggestions", []) self.websocket_handler.queue_put( @@ -1034,13 +1032,15 @@ class WorldStateManagerPlugin( "data": suggestions, } ) - + async def handle_remove_suggestion(self, data): payload = SuggestionPayload(**data) if not payload.proposal_uid: await self.world_state_manager.remove_suggestion(payload.id) else: - await self.world_state_manager.remove_suggestion_proposal(payload.id, payload.proposal_uid) + await self.world_state_manager.remove_suggestion_proposal( + payload.id, payload.proposal_uid + ) self.websocket_handler.queue_put( { @@ -1049,38 +1049,36 @@ class WorldStateManagerPlugin( "data": payload.model_dump(), } ) - async def handle_generate_suggestions(self, data): """ Generate's suggestions for character development. """ - + world_state = get_agent("world_state") - world_state_manager:WorldStateManager = self.scene.world_state_manager + world_state_manager: WorldStateManager = self.scene.world_state_manager payload = GenerateSuggestionPayload(**data) - + log.debug("Generate suggestions", payload=payload) - - async def send_suggestion(call:focal.Call): + + async def send_suggestion(call: focal.Call): await world_state_manager.add_suggestion( Suggestion( name=payload.name, type=payload.suggestion_type, id=f"{payload.suggestion_type}-{payload.name}", - proposals=[call] + proposals=[call], ) ) - + with focal.FocalContext() as focal_context: - if payload.suggestion_type == "character": character = self.scene.get_character(payload.name) - + if not character: log.error("Character not found", name=payload.name) return - + self.websocket_handler.queue_put( { "type": "world_state_manager", @@ -1090,20 +1088,29 @@ class WorldStateManagerPlugin( "name": payload.name, } ) - + if not payload.auto_apply: focal_context.hooks_before_call.append(send_suggestion) focal_context.hooks_after_call.append(send_suggestion) - - @set_loading("Analyzing character development", cancellable=True, set_success=True, set_error=True) + + @set_loading( + "Analyzing character development", + cancellable=True, + set_success=True, + set_error=True, + ) async def task_wrapper(): await world_state.determine_character_development( - character, + character, generation_options=payload.generation_options, instructions=payload.instructions, ) - + task = asyncio.create_task(task_wrapper()) - - task.add_done_callback(lambda _: asyncio.create_task(self.handle_request_suggestions({}))) - task.add_done_callback(lambda _: asyncio.create_task(self.signal_operation_done())) + + task.add_done_callback( + lambda _: asyncio.create_task(self.handle_request_suggestions({})) + ) + task.add_done_callback( + lambda _: asyncio.create_task(self.signal_operation_done()) + ) diff --git a/src/talemate/server/world_state_manager/history.py b/src/talemate/server/world_state_manager/history.py index f2ec1599..2d4823a8 100644 --- a/src/talemate/server/world_state_manager/history.py +++ b/src/talemate/server/world_state_manager/history.py @@ -4,26 +4,27 @@ import structlog from talemate.instance import get_agent from talemate.history import ( - history_with_relative_time, - rebuild_history, + history_with_relative_time, + rebuild_history, HistoryEntry, update_history_entry, regenerate_history_entry, collect_source_entries, add_history_entry, - delete_history_entry + delete_history_entry, ) from talemate.server.world_state_manager import world_state_templates from talemate.util.time import amount_unit_to_iso8601_duration log = structlog.get_logger("talemate.server.world_state_manager.history") + class RegenerateHistoryPayload(pydantic.BaseModel): generation_options: world_state_templates.GenerationOptions | None = None class HistoryEntryPayload(pydantic.BaseModel): - entry: HistoryEntry + entry: HistoryEntry class AddHistoryEntryPayload(pydantic.BaseModel): @@ -31,24 +32,25 @@ class AddHistoryEntryPayload(pydantic.BaseModel): amount: int unit: str + class HistoryMixin: - """ Handles history-related operations for the world state manager. """ - + async def handle_request_scene_history(self, data): - """ Request the entire history for the scene. """ - - history = history_with_relative_time(self.scene.archived_history, self.scene.ts, layer=0) - + + history = history_with_relative_time( + self.scene.archived_history, self.scene.ts, layer=0 + ) + layered_history = [] - + summarizer = get_agent("summarizer") - + if summarizer.layered_history_enabled: for index, layer in enumerate(self.scene.layered_history): layered_history.append( @@ -56,27 +58,35 @@ class HistoryMixin: ) self.websocket_handler.queue_put( - {"type": "world_state_manager", "action": "scene_history", "data": { - "history": history, - "layered_history": layered_history, - }} + { + "type": "world_state_manager", + "action": "scene_history", + "data": { + "history": history, + "layered_history": layered_history, + }, + } ) async def handle_regenerate_history(self, data): """ Regenerate the history for the scene. """ - + payload = RegenerateHistoryPayload(**data) - + async def callback(): self.scene.emit_status() await self.handle_request_scene_history(data) - task = asyncio.create_task(rebuild_history( - self.scene, callback=callback, generation_options=payload.generation_options - )) - + task = asyncio.create_task( + rebuild_history( + self.scene, + callback=callback, + generation_options=payload.generation_options, + ) + ) + async def done(): self.websocket_handler.queue_put( { @@ -88,63 +98,62 @@ class HistoryMixin: await self.signal_operation_done() await self.handle_request_scene_history(data) - + # when task is done, queue a message to the client task.add_done_callback(lambda _: asyncio.create_task(done())) - + async def handle_update_history_entry(self, data): payload = HistoryEntryPayload(**data) - + entry = await update_history_entry(self.scene, payload.entry) - + self.websocket_handler.queue_put( { - "type": "world_state_manager", - "action": "history_entry_updated", - "data": entry.model_dump() + "type": "world_state_manager", + "action": "history_entry_updated", + "data": entry.model_dump(), } ) - + await self.signal_operation_done() - + async def handle_regenerate_history_entry(self, data): """ Regenerate a single history entry. """ - + payload = HistoryEntryPayload(**data) - + log.debug("regenerate_history_entry", payload=payload) - + try: entry = await regenerate_history_entry(self.scene, payload.entry) except Exception as e: log.error("regenerate_history_entry", error=e) await self.signal_operation_failed(str(e)) return - + log.debug("regenerate_history_entry (done)", entry=entry) - + self.websocket_handler.queue_put( { "type": "world_state_manager", "action": "history_entry_regenerated", - "data": entry.model_dump() + "data": entry.model_dump(), } ) - + await self.signal_operation_done() - - + async def handle_inspect_history_entry(self, data): """ Inspect a single history entry. """ - + payload = HistoryEntryPayload(**data) - + entries = collect_source_entries(self.scene, payload.entry) - + self.websocket_handler.queue_put( { "type": "world_state_manager", @@ -152,7 +161,7 @@ class HistoryMixin: "data": { "entries": [entry.model_dump() for entry in entries], "entry": payload.entry.model_dump(), - } + }, } ) @@ -164,7 +173,9 @@ class HistoryMixin: payload = AddHistoryEntryPayload(**data) try: - iso_offset = amount_unit_to_iso8601_duration(int(payload.amount), payload.unit) + iso_offset = amount_unit_to_iso8601_duration( + int(payload.amount), payload.unit + ) except ValueError as e: await self.signal_operation_failed(str(e)) return @@ -197,4 +208,4 @@ class HistoryMixin: # Send updated history to client await self.handle_request_scene_history({}) - await self.signal_operation_done() \ No newline at end of file + await self.signal_operation_done() diff --git a/src/talemate/server/world_state_manager/scene_intent.py b/src/talemate/server/world_state_manager/scene_intent.py index 40309652..fec0115b 100644 --- a/src/talemate/server/world_state_manager/scene_intent.py +++ b/src/talemate/server/world_state_manager/scene_intent.py @@ -1,41 +1,37 @@ -import pydantic import structlog from talemate.scene.schema import SceneIntent log = structlog.get_logger("talemate.server.world_state_manager.scene_intent") + class SceneIntentMixin: - - async def handle_get_scene_intent(self, data:dict): - - scene_intent:SceneIntent = self.scene.intent_state - + async def handle_get_scene_intent(self, data: dict): + scene_intent: SceneIntent = self.scene.intent_state + self.websocket_handler.queue_put( { "type": "world_state_manager", "action": "get_scene_intent", - "data": scene_intent.model_dump() + "data": scene_intent.model_dump(), } ) - - async def handle_set_scene_intent(self, data:dict): - - scene_intent:SceneIntent = SceneIntent(**data) - + + async def handle_set_scene_intent(self, data: dict): + scene_intent: SceneIntent = SceneIntent(**data) + self.scene.intent_state = scene_intent - + log.debug("Scene intent set", scene_intent=scene_intent) self.websocket_handler.queue_put( { "type": "world_state_manager", "action": "set_scene_intent", - "data": scene_intent.model_dump() + "data": scene_intent.model_dump(), } ) self.scene.emit_status() await self.signal_operation_done() - \ No newline at end of file diff --git a/src/talemate/status.py b/src/talemate/status.py index 4dbe2ca0..a88de165 100644 --- a/src/talemate/status.py +++ b/src/talemate/status.py @@ -1,6 +1,5 @@ import asyncio import structlog -import traceback from talemate.emit import emit from talemate.exceptions import GenerationCancelled @@ -16,8 +15,8 @@ log = structlog.get_logger("talemate.status") class set_loading: def __init__( - self, - message, + self, + message, set_busy: bool = True, set_success: bool = False, set_error: bool = False, @@ -58,8 +57,9 @@ class set_loading: # if as_async we want to wrap the function in a coroutine # that adds a task to the event loop and returns the task - + if self.as_async: + async def async_wrapper(*args, **kwargs): return asyncio.create_task(wrapper(*args, **kwargs)) @@ -76,12 +76,12 @@ class LoadingStatus: def __call__(self, message: str): self.current_step += 1 - + if self.max_steps is None: counter = "" else: counter = f" [{self.current_step}/{self.max_steps}]" - + emit( "status", message=f"{message}{counter}", @@ -90,11 +90,11 @@ class LoadingStatus: "cancellable": self.cancellable, }, ) - + def done(self, message: str = "", status: str = "idle"): if self.current_step == 0: return - + emit( "status", message=message, diff --git a/src/talemate/tale_mate.py b/src/talemate/tale_mate.py index aa9f682f..6eede35b 100644 --- a/src/talemate/tale_mate.py +++ b/src/talemate/tale_mate.py @@ -174,7 +174,6 @@ class Character: self.cover_image = asset_id def sheet_filtered(self, *exclude): - sheet = self.base_attributes or { "name": self.name, "gender": self.gender, @@ -188,59 +187,73 @@ class Character: sheet_list.append(f"{key}: {value}") return "\n".join(sheet_list) - - def random_dialogue_examples(self, scene:"Scene", num: int = 3, strip_name: bool = False, max_backlog: int = 250, max_length: int = 192) -> list[str]: + + def random_dialogue_examples( + self, + scene: "Scene", + num: int = 3, + strip_name: bool = False, + max_backlog: int = 250, + max_length: int = 192, + ) -> list[str]: """ Get multiple random example dialogue lines for this character. - + Will return up to `num` examples and not have any duplicates. """ - - history_examples = self._random_dialogue_examples_from_history(scene, num, max_backlog) - + + history_examples = self._random_dialogue_examples_from_history( + scene, num, max_backlog + ) + if len(history_examples) < num: - random_examples = self._random_dialogue_examples(num - len(history_examples), strip_name) - + random_examples = self._random_dialogue_examples( + num - len(history_examples), strip_name + ) + for example in random_examples: history_examples.append(example) - + # ensure sane example lengths - + history_examples = [ util.strip_partial_sentences(example[:max_length]) for example in history_examples ] - + log.debug("random_dialogue_examples", history_examples=history_examples) return history_examples - - def _random_dialogue_examples_from_history(self, scene:"Scene", num: int = 3, max_backlog: int = 250) -> list[str]: + + def _random_dialogue_examples_from_history( + self, scene: "Scene", num: int = 3, max_backlog: int = 250 + ) -> list[str]: """ Get multiple random example dialogue lines for this character from the scene's history. - + Will checks the last `max_backlog` messages in the scene's history and returns up to `num` examples. """ - + history = scene.history[-max_backlog:] - + examples = [] - + for message in history: if not isinstance(message, CharacterMessage): continue - + if message.character_name != self.name: continue - + examples.append(message.without_name.strip()) - + if not examples: return [] - - return random.sample(examples, min(num, len(examples))) - - def _random_dialogue_examples(self, num: int = 3, strip_name: bool = False) -> list[str]: + return random.sample(examples, min(num, len(examples))) + + def _random_dialogue_examples( + self, num: int = 3, strip_name: bool = False + ) -> list[str]: """ Get multiple random example dialogue lines for this character. @@ -259,7 +272,7 @@ class Character: random.shuffle(examples) # now pop examples until we have `num` examples or we run out of examples - + if strip_name: examples = [example.split(":", 1)[1].strip() for example in examples] @@ -387,9 +400,9 @@ class Character: }, } ) - + seen_attributes = set() - + for attr, value in self.base_attributes.items(): if attr.startswith("_"): continue @@ -412,11 +425,10 @@ class Character: ) for key, detail in self.details.items(): - # if colliding with attribute name, prefix with detail_ if key in seen_attributes: key = f"detail_{key}" - + items.append( { "text": f"{self.name} - {key}: {detail}", @@ -627,10 +639,12 @@ class Actor: def history(self): return self.scene.history + class Player(Actor): muted = 0 ai_controlled = 0 + class Scene(Emitter): """ A scene containing one ore more AI driven actors to interact with. @@ -695,13 +709,13 @@ class Scene(Emitter): self.Actor = Actor self.Player = Player self.Character = Character - + self.narrator_character_object = Character(name="__narrator__") self.active_pins = [] # Add an attribute to store the most recent AI Actor self.most_recent_ai_actor = None - + # if the user has requested to cancel the current action # or series of agent actions this will be true # @@ -781,15 +795,15 @@ class Scene(Emitter): """ if hasattr(self, "_save_files"): return self._save_files - + save_files = [] - + for file in os.listdir(self.save_dir): if file.endswith(".json"): save_files.append(file) - + self._save_files = sorted(save_files) - + return self._save_files @property @@ -828,15 +842,15 @@ class Scene(Emitter): @property def template_dir(self): return os.path.join(self.save_dir, "templates") - + @property def nodes_dir(self): return os.path.join(self.save_dir, "nodes") - + @property def info_dir(self): return os.path.join(self.save_dir, "info") - + @property def auto_save(self): return self.config.get("game", {}).get("general", {}).get("auto_save", True) @@ -855,19 +869,18 @@ class Scene(Emitter): @property def writing_style(self) -> world_state_templates.WritingStyle | None: - if not self.writing_style_template: return None - + try: group_uid, template_uid = self.writing_style_template.split("__", 1) return self._world_state_templates.find_template(group_uid, template_uid) except ValueError: return None - + @property def max_backscroll(self): - return self.config.get("game", {}).get("general", {}).get("max_backscroll", 512) + return self.config.get("game", {}).get("general", {}).get("max_backscroll", 512) @property def nodes_filename(self): @@ -876,7 +889,7 @@ class Scene(Emitter): @nodes_filename.setter def nodes_filename(self, value: str): self._nodes_filename = value or "" - + @property def nodes_filepath(self) -> str: return os.path.join(self.nodes_dir, self.nodes_filename) @@ -884,11 +897,11 @@ class Scene(Emitter): @property def creative_nodes_filename(self): return self._creative_nodes_filename or "creative-loop.json" - + @creative_nodes_filename.setter def creative_nodes_filename(self, value: str): self._creative_nodes_filename = value or "" - + @property def creative_nodes_filepath(self) -> str: return os.path.join(self.nodes_dir, self.creative_nodes_filename) @@ -898,18 +911,15 @@ class Scene(Emitter): phase = self.intent_state.phase if not phase: return {} - + return { "name": self.intent_state.current_scene_type.name, "intent": phase.intent, } - + @property def active_node_graph(self): return getattr(self, "node_graph", getattr(self, "creative_node_graph", None)) - - def set_description(self, description: str): - self.description = description def set_intro(self, intro: str): self.intro = intro @@ -920,9 +930,6 @@ class Scene(Emitter): def set_title(self, title: str): self.title = title - def set_content_context(self, content_context: str): - self.context = content_context - def connect(self): """ connect scenes to signals @@ -1034,7 +1041,7 @@ class Scene(Emitter): max_iterations: int = None, reverse: bool = False, meta_hash: int = None, - **filters + **filters, ): """ Removes the last message from the history that matches the given typ and source @@ -1050,31 +1057,31 @@ class Scene(Emitter): for idx in iter_range: message = self.history[idx] - + if message.typ != typ: iterations += 1 continue - + if source is not None and message.source != source: iterations += 1 continue - + if meta_hash is not None and message.meta_hash != meta_hash: iterations += 1 continue - + # Apply additional filters valid = True for filter_name, filter_value in filters.items(): if getattr(message, filter_name, None) != filter_value: valid = False break - + if valid: to_remove.append(message) if not all: break - + iterations += 1 if max_iterations and iterations >= max_iterations: break @@ -1089,22 +1096,20 @@ class Scene(Emitter): iterations = 0 for idx in range(len(self.history) - 1, -1, -1): message: SceneMessage = self.history[idx] - + iterations += 1 if iterations >= max_iterations: return None - + if message.typ != typ: continue - + for filter_name, filter_value in filters.items(): if getattr(message, filter_name, None) != filter_value: continue - + return self.history[idx] - - def message_index(self, message_id: int) -> int: """ Returns the index of the given message in the history @@ -1130,19 +1135,19 @@ class Scene(Emitter): if isinstance(self.history[idx], CharacterMessage): if self.history[idx].source == "player": return self.history[idx] - + def last_message_of_type( - self, - typ: str | list[str], + self, + typ: str | list[str], source: str = None, - max_iterations: int = None, + max_iterations: int = None, stop_on_time_passage: bool = False, on_iterate: Callable = None, - **filters + **filters, ) -> SceneMessage | None: """ Returns the last message of the given type and source - + Arguments: - typ: str | list[str] - the type of message to find - source: str - the source of the message @@ -1152,44 +1157,43 @@ class Scene(Emitter): Keyword Arguments: Any additional keyword arguments will be used to filter the messages against their attributes """ - + if not isinstance(typ, list): typ = [typ] - + num_iterations = 0 - + for idx in range(len(self.history) - 1, -1, -1): - if max_iterations is not None and num_iterations >= max_iterations: return None - + message = self.history[idx] - + if on_iterate: on_iterate(message) - + if isinstance(message, TimePassageMessage) and stop_on_time_passage: return None - + num_iterations += 1 - + if message.typ not in typ or (source and message.source != source): continue - + valid = True - + for filter_name, filter_value in filters.items(): message_value = getattr(message, filter_name, None) if message_value != filter_value: valid = False break - + if valid: return message def collect_messages( - self, - typ: str | list[str] = None, + self, + typ: str | list[str] = None, source: str = None, max_iterations: int = 100, max_messages: int | None = None, @@ -1206,10 +1210,10 @@ class Scene(Emitter): messages = [] iterations = 0 collected = 0 - + if start_idx is None: start_idx = len(self.history) - 1 - + for idx in range(start_idx, -1, -1): message = self.history[idx] if (not typ or message.typ in typ) and ( @@ -1229,9 +1233,9 @@ class Scene(Emitter): return messages def snapshot( - self, - lines: int = 3, - ignore: list[str | SceneMessage] = None, + self, + lines: int = 3, + ignore: list[str | SceneMessage] = None, start: int = None, as_format: str = "movie_script", ) -> str: @@ -1240,18 +1244,24 @@ class Scene(Emitter): """ if not ignore: - ignore = [ReinforcementMessage, DirectorMessage, ContextInvestigationMessage] + ignore = [ + ReinforcementMessage, + DirectorMessage, + ContextInvestigationMessage, + ] else: # ignore me also be a list of message type strings (e.g. 'director') # convert to class types _ignore = [] for item in ignore: if isinstance(item, str): - _ignore.append(MESSAGE_TYPES.get(item)) + _ignore.append(MESSAGE_TYPES.get(item)) elif isinstance(item, SceneMessage): _ignore.append(item) else: - raise ValueError("ignore must be a list of strings or SceneMessage types") + raise ValueError( + "ignore must be a list of strings or SceneMessage types" + ) ignore = _ignore collected = [] @@ -1335,26 +1345,26 @@ class Scene(Emitter): if memory_helper: await actor.character.commit_to_memory(memory_helper.agent) - - async def remove_character(self, character: Character, purge_from_memory: bool = True): + async def remove_character( + self, character: Character, purge_from_memory: bool = True + ): """ Remove a character from the scene - + Class remove_actor if the character is active otherwise remove from inactive_characters. """ - + for actor in self.actors: if actor.character == character: await self.remove_actor(actor) - + if character.name in self.inactive_characters: del self.inactive_characters[character.name] - + if purge_from_memory: await character.purge_from_memory() - async def remove_actor(self, actor: Actor): """ Remove an actor from the scene @@ -1389,10 +1399,10 @@ class Scene(Emitter): if not character_name: return - + if character_name == "__narrator__": return self.narrator_character_object - + if character_name in self.inactive_characters: return self.inactive_characters[character_name] @@ -1408,7 +1418,7 @@ class Scene(Emitter): for actor in self.actors: if isinstance(actor, Player): return actor.character - + # No active player found, return the first NPC for actor in self.actors: return actor.character @@ -1430,7 +1440,9 @@ class Scene(Emitter): if actor.character.name.lower() in line.lower(): return actor.character - def parse_characters_from_text(self, text: str, exclude_active:bool=False) -> list[Character]: + def parse_characters_from_text( + self, text: str, exclude_active: bool = False + ) -> list[Character]: """ Parse characters from a block of text """ @@ -1444,7 +1456,7 @@ class Scene(Emitter): # use regex with word boundaries to match whole words if re.search(rf"\b{actor.character.name.lower()}\b", text): characters.append(actor.character) - + # inactive characters for character in self.inactive_characters.values(): if re.search(rf"\b{character.name.lower()}\b", text): @@ -1468,14 +1480,14 @@ class Scene(Emitter): """ self.description = description - def get_intro(self, intro:str = None) -> str: + def get_intro(self, intro: str = None) -> str: """ Returns the intro text of the scene """ - + if not intro: intro = self.intro - + try: player_name = self.get_player_character().name intro = intro.replace("{{user}}", player_name).replace( @@ -1483,9 +1495,9 @@ class Scene(Emitter): ) except AttributeError: intro = self.intro - + editor = self.get_helper("editor").agent - + if editor.fix_exposition_enabled and editor.fix_exposition_narrator: if '"' not in intro and "*" not in intro: intro = f"*{intro}*" @@ -1520,35 +1532,37 @@ class Scene(Emitter): return count - def context_history( - self, budget: int = 8192, **kwargs - ): + def context_history(self, budget: int = 8192, **kwargs): parts_context = [] parts_dialogue = [] budget_context = int(0.5 * budget) budget_dialogue = int(0.5 * budget) - + keep_director = kwargs.get("keep_director", False) keep_context_investigation = kwargs.get("keep_context_investigation", True) show_hidden = kwargs.get("show_hidden", False) conversation_format = self.conversation_format actor_direction_mode = self.get_helper("director").agent.actor_direction_mode - layered_history_enabled = self.get_helper("summarizer").agent.layered_history_enabled + layered_history_enabled = self.get_helper( + "summarizer" + ).agent.layered_history_enabled include_reinforcements = kwargs.get("include_reinforcements", True) assured_dialogue_num = kwargs.get("assured_dialogue_num", 5) - + chapter_labels = kwargs.get("chapter_labels", False) chapter_numbers = [] history_len = len(self.history) - # CONTEXT # collect context, ignore where end > len(history) - count - if not self.layered_history or not layered_history_enabled or not self.layered_history[0]: - + if ( + not self.layered_history + or not layered_history_enabled + or not self.layered_history[0] + ): # no layered history available for i in range(len(self.archived_history) - 1, -1, -1): @@ -1564,82 +1578,95 @@ class Scene(Emitter): ) text = f"{time_message}: {archive_history_entry['text']}" except Exception as e: - log.error("context_history", error=e, traceback=traceback.format_exc()) + log.error( + "context_history", error=e, traceback=traceback.format_exc() + ) text = archive_history_entry["text"] if count_tokens(parts_context) + count_tokens(text) > budget_context: break - + text = condensed(text) - + parts_context.insert(0, text) - + else: - # layered history available # start with the last layer and work backwards - + next_layer_start = None num_layers = len(self.layered_history) - + for i in range(len(self.layered_history) - 1, -1, -1): - - log.debug("context_history - layered history", i=i, next_layer_start=next_layer_start) - + log.debug( + "context_history - layered history", + i=i, + next_layer_start=next_layer_start, + ) + if not self.layered_history[i]: continue - + k = next_layer_start if next_layer_start is not None else 0 - - for layered_history_entry in self.layered_history[i][next_layer_start if next_layer_start is not None else 0:]: - + + for layered_history_entry in self.layered_history[i][ + next_layer_start if next_layer_start is not None else 0 : + ]: time_message_start = util.iso8601_diff_to_human( layered_history_entry["ts_start"], self.ts ) time_message_end = util.iso8601_diff_to_human( layered_history_entry["ts_end"], self.ts ) - + if time_message_start == time_message_end: time_message = time_message_start else: - time_message = f"Start:{time_message_start}, End:{time_message_end}" if time_message_start != time_message_end else time_message_start + time_message = ( + f"Start:{time_message_start}, End:{time_message_end}" + if time_message_start != time_message_end + else time_message_start + ) text = f"{time_message} {layered_history_entry['text']}" - + # prepend chapter labels if chapter_labels: chapter_number = f"{num_layers - i}.{k + 1}" text = f"### Chapter {chapter_number}\n{text}" chapter_numbers.append(chapter_number) - + parts_context.append(text) - + k += 1 - + next_layer_start = layered_history_entry["end"] + 1 - + # collect archived history entries that have not yet been # summarized to the layered history - base_layer_start = self.layered_history[0][-1]["end"] + 1 if self.layered_history[0] else None - + base_layer_start = ( + self.layered_history[0][-1]["end"] + 1 + if self.layered_history[0] + else None + ) + if base_layer_start is not None: i = 0 - + # if chapter labels have been appanded, we need to # open a new section for the current scene - + if chapter_labels: parts_context.append("### Current\n") - + for archive_history_entry in self.archived_history[base_layer_start:]: time_message = util.iso8601_diff_to_human( archive_history_entry["ts"], self.ts ) - + text = f"{time_message}: {archive_history_entry['text']}" - + text = condensed(text) - + parts_context.append(text) i += 1 @@ -1652,34 +1679,43 @@ class Scene(Emitter): # DIALOGUE try: - summarized_to = self.archived_history[-1]["end"] if self.archived_history else 0 + summarized_to = ( + self.archived_history[-1]["end"] if self.archived_history else 0 + ) except KeyError: # only static archived history entries exist (pre-entered history # that doesnt have start and end timestamps) summarized_to = 0 - - + # if summarized_to somehow is bigger than the length of the history # since we have no way to determine where they sync up just put as much of # the dialogue as possible if summarized_to and summarized_to >= history_len: - log.warning("context_history", message="summarized_to is greater than history length - may want to regenerate history") + log.warning( + "context_history", + message="summarized_to is greater than history length - may want to regenerate history", + ) summarized_to = 0 - - log.debug("context_history", summarized_to=summarized_to, history_len=history_len) - - dialogue_messages_collected = 0 - - #for message in self.history[summarized_to if summarized_to is not None else 0:]: + + log.debug( + "context_history", summarized_to=summarized_to, history_len=history_len + ) + + dialogue_messages_collected = 0 + + # for message in self.history[summarized_to if summarized_to is not None else 0:]: for i in range(len(self.history) - 1, -1, -1): message = self.history[i] - - if i < summarized_to and dialogue_messages_collected >= assured_dialogue_num: + + if ( + i < summarized_to + and dialogue_messages_collected >= assured_dialogue_num + ): break if message.hidden and not show_hidden: continue - + if isinstance(message, ReinforcementMessage) and not include_reinforcements: continue @@ -1692,30 +1728,32 @@ class Scene(Emitter): # TODO: we may want to include these in the future continue - elif isinstance(keep_director, str) and message.character_name != keep_director: + elif ( + isinstance(keep_director, str) + and message.character_name != keep_director + ): continue - - elif isinstance(message, ContextInvestigationMessage) and not keep_context_investigation: - continue + elif ( + isinstance(message, ContextInvestigationMessage) + and not keep_context_investigation + ): + continue if count_tokens(parts_dialogue) + count_tokens(message) > budget_dialogue: break - + parts_dialogue.insert( - 0, - message.as_format(conversation_format, mode=actor_direction_mode) + 0, message.as_format(conversation_format, mode=actor_direction_mode) ) - + if isinstance(message, CharacterMessage): dialogue_messages_collected += 1 - - + if count_tokens(parts_context) < 128: intro = self.get_intro() if intro: parts_context.insert(0, intro) - active_agent_ctx = active_agent.get() if active_agent_ctx: @@ -1801,7 +1839,9 @@ class Scene(Emitter): "scene_status", scene=self.name, scene_time=self.ts, - human_ts=util.iso8601_duration_to_human(self.ts, suffix="") if self.ts else None, + human_ts=util.iso8601_duration_to_human(self.ts, suffix="") + if self.ts + else None, saved=self.saved, ) @@ -1866,41 +1906,41 @@ class Scene(Emitter): # TODO: need to adjust archived_history ts as well # but removal also probably means the history needs to be regenerated # anyway. - + def fix_time(self): """ New implementation of sync_time that will fix time across the board using the base history as the sole source of truth. - + This means first identifying the time jumps in the base history by looking for TimePassageMessages and then applying those time jumps - + to the archived history and the layered history based on their start and end indexes. """ try: ts = self.ts self._fix_time() - except Exception as e: + except Exception: log.error("fix_time", exc=traceback.format_exc()) self.ts = ts - + def _fix_time(self): starting_time = "PT0S" - + for archived_entry in self.archived_history: if "ts" in archived_entry and "end" not in archived_entry: starting_time = archived_entry["ts"] elif "end" in archived_entry: break - + # store time jumps by index time_jumps = [] - + for idx, message in enumerate(self.history): if isinstance(message, TimePassageMessage): time_jumps.append((idx, message.ts)) - + # now make the timejumps cumulative, meaning that each time jump # will be the sum of all time jumps up to that point cumulative_time_jumps = [] @@ -1908,33 +1948,32 @@ class Scene(Emitter): for idx, ts_jump in time_jumps: ts = util.iso8601_add(ts, ts_jump) cumulative_time_jumps.append((idx, ts)) - + try: ending_time = cumulative_time_jumps[-1][1] except IndexError: # no time jumps found ending_time = starting_time self.ts = ending_time - return - + return + # apply time jumps to the archived history ts = starting_time for _, entry in enumerate(self.archived_history): - if "end" not in entry: continue - + # we need to find best_ts by comparing entry["end"] # index to time_jumps (find the closest time jump that is # smaller than entry["end"]) - + best_ts = None for jump_idx, jump_ts in cumulative_time_jumps: if jump_idx < entry["end"]: best_ts = jump_ts else: break - + if best_ts: entry["ts"] = best_ts ts = entry["ts"] @@ -1973,27 +2012,27 @@ class Scene(Emitter): _active_pins = await self.world_state_manager.get_pins(active=True) self.active_pins = list(_active_pins.pins.values()) - + async def ensure_memory_db(self): memory = self.get_helper("memory").agent if not memory.db: await memory.set_db() async def emit_history(self): - emit("clear_screen", "") + emit("clear_screen", "") # this is mostly to support character cards # we introduce the main character to all such characters, replacing # the {{ user }} placeholder for npc in self.npcs: if npc.introduce_main_character: npc.introduce_main_character(self.main_character.character) - + # emit intro - intro:str = self.get_intro() + intro: str = self.get_intro() self.narrator_message(intro) - + # emit history - for message in self.history[-self.max_backscroll:]: + for message in self.history[-self.max_backscroll :]: if isinstance(message, CharacterMessage): character = self.get_character(message.character_name) else: @@ -2014,15 +2053,19 @@ class Scene(Emitter): while True: try: log.debug(f"Starting scene loop: {self.environment}") - + self.world_state.emit() - + if self.environment == "creative": - self.creative_node_graph, _ = load_graph(self.creative_nodes_filename, [self.save_dir]) + self.creative_node_graph, _ = load_graph( + self.creative_nodes_filename, [self.save_dir] + ) await initialize_packages(self, self.creative_node_graph) await self._run_creative_loop(init=first_loop) else: - self.node_graph, _ = load_graph(self.nodes_filename, [self.save_dir]) + self.node_graph, _ = load_graph( + self.nodes_filename, [self.save_dir] + ) await initialize_packages(self, self.node_graph) await self._run_game_loop(init=first_loop) except ExitScene: @@ -2037,21 +2080,20 @@ class Scene(Emitter): await asyncio.sleep(0.01) async def _game_startup(self): - self.commands = command = commands.Manager(self) - + self.commands = commands.Manager(self) + await self.signals["scene_init"].send( events.SceneStateEvent(scene=self, event_type="scene_init") ) - - - async def _run_game_loop(self, init: bool = True, node_graph = None): + + async def _run_game_loop(self, init: bool = True, node_graph=None): if init: await self._game_startup() await self.emit_history() - + self.nodegraph_state = state = GraphState() state.data["continue_scene"] = True - + while state.data["continue_scene"] and self.active: try: await self.node_graph.execute(state) @@ -2089,13 +2131,12 @@ class Scene(Emitter): ) emit("system", status="error", message=f"Unhandled Error: {e}") - async def _run_creative_loop(self, init: bool = True): await self.emit_history() - + self.nodegraph_state = state = GraphState() state.data["continue_scene"] = True - + while state.data["continue_scene"] and self.active: try: await self.creative_node_graph.execute(state) @@ -2144,7 +2185,6 @@ class Scene(Emitter): """ Saves the scene data, conversation history, archived history, and characters to a json file. """ - scene = self if self.immutable_save and not save_as and not force: save_as = True @@ -2188,7 +2228,7 @@ class Scene(Emitter): # Create a dictionary to store the scene data scene_data = self.serialize - + if not auto: emit("status", status="success", message="Saved scene") @@ -2205,14 +2245,14 @@ class Scene(Emitter): # add this scene to recent scenes in config await self.add_to_recent_scenes() - async def save_restore(self, filename:str): + async def save_restore(self, filename: str): """ Serializes the scene to a file. - + immutable_save will be set to True memory_sesion_id will be randomized """ - + serialized = self.serialize serialized["immutable_save"] = True serialized["memory_session_id"] = str(uuid.uuid4())[:10] @@ -2275,7 +2315,7 @@ class Scene(Emitter): self.set_new_memory_session_id() - async def restore(self, save_as:str | None=None): + async def restore(self, save_as: str | None = None): try: self.log.info("Restoring", source=self.restore_from) @@ -2296,22 +2336,22 @@ class Scene(Emitter): os.path.join(self.save_dir, self.restore_from), self.get_helper("conversation").agent.client, ) - + await self.reset_memory() - + if save_as: self.restore_from = restore_from await self.save(save_as=True, copy_name=save_as) else: self.filename = None self.emit_status(restored=True) - + interaction_state = interaction.get() - + if interaction_state: # Break and restart the game loop interaction_state.reset_requested = True - + except Exception as e: self.log.error("restore", error=e, traceback=traceback.format_exc()) @@ -2359,11 +2399,10 @@ class Scene(Emitter): def json(self): return json.dumps(self.serialize, indent=2, cls=save.SceneEncoder) - def interrupt(self): self.cancel_requested = True def continue_actions(self): if self.cancel_requested: self.cancel_requested = False - raise GenerationCancelled("action cancelled") \ No newline at end of file + raise GenerationCancelled("action cancelled") diff --git a/src/talemate/thematic_generators.py b/src/talemate/thematic_generators.py index 36242551..c7b29aa6 100644 --- a/src/talemate/thematic_generators.py +++ b/src/talemate/thematic_generators.py @@ -850,6 +850,7 @@ scifi_tropes = [ actor_name_colors = COLORS.copy() + class ThematicGenerator: def __init__(self, seed: int = None): self.seed = seed diff --git a/src/talemate/util/__init__.py b/src/talemate/util/__init__.py index e2080a90..2c46a9d3 100644 --- a/src/talemate/util/__init__.py +++ b/src/talemate/util/__init__.py @@ -4,19 +4,20 @@ import structlog import tiktoken from talemate.scene_message import SceneMessage -from talemate.util.dialogue import * -from talemate.util.prompt import * -from talemate.util.response import * -from talemate.util.image import * -from talemate.util.time import * -from talemate.util.dedupe import * -from talemate.util.data import * -from talemate.util.colors import * +from talemate.util.dialogue import * # noqa: F403, F401 +from talemate.util.prompt import * # noqa: F403, F401 +from talemate.util.response import * # noqa: F403, F401 +from talemate.util.image import * # noqa: F403, F401 +from talemate.util.time import * # noqa: F403, F401 +from talemate.util.dedupe import * # noqa: F403, F401 +from talemate.util.data import * # noqa: F403, F401 +from talemate.util.colors import * # noqa: F403, F401 log = structlog.get_logger("talemate.util") TIKTOKEN_ENCODING = tiktoken.encoding_for_model("gpt-4-turbo") + def count_tokens(source): if isinstance(source, list): t = 0 @@ -53,6 +54,3 @@ def clean_id(name: str) -> str: cleaned_name = re.sub(r"[^a-zA-Z0-9_\- ]", "", name) return cleaned_name - - - diff --git a/src/talemate/util/async_tools.py b/src/talemate/util/async_tools.py index f27cde98..6a81db9b 100644 --- a/src/talemate/util/async_tools.py +++ b/src/talemate/util/async_tools.py @@ -4,95 +4,101 @@ from functools import wraps from typing import Optional __all__ = [ - 'cleanup_pending_tasks', - 'debounce', - 'shared_debounce', + "cleanup_pending_tasks", + "debounce", + "shared_debounce", ] log = structlog.get_logger("talemate.util.async_tools") TASKS = {} + def throttle(delay: float): """ Ensures the decorated function is only called once every `delay` seconds. - + Unlike debounce which will delay the function until the last call, throttle will ensure the function is called at most once every `delay` seconds. """ - + def decorator(fn): last_called = 0 - + @wraps(fn) async def wrapper(*args, **kwargs): nonlocal last_called - + now = asyncio.get_event_loop().time() if now - last_called > delay: last_called = now return await fn(*args, **kwargs) - - return wrapper - return decorator + return wrapper + + return decorator def debounce(delay: float): """ Decorator to debounce a coroutine function. """ + def decorator(fn): task: Optional[asyncio.Task] = None - + @wraps(fn) async def wrapper(*args, **kwargs): nonlocal task - + # Cancel any existing task if task and not task.done(): task.cancel() - + # Create new delayed task async def delayed(): await asyncio.sleep(delay) return await fn(*args, **kwargs) - + asyncio.create_task(delayed()) - + return wrapper + return decorator -def shared_debounce(delay: float, task_key: str = "default", tasks: dict = None, immediate: bool = True): + +def shared_debounce( + delay: float, task_key: str = "default", tasks: dict = None, immediate: bool = True +): """ Decorator to debounce a coroutine function, but share the task across multiple calls. - + This allows you to debounce a function across multiple calls, so that only one task is running at a time. """ - + if not tasks: tasks = TASKS - + def decorator(fn): @wraps(fn) async def wrapper(*args, **kwargs): loop = asyncio.get_running_loop() - + is_first = True - + if task_key not in tasks: tasks[task_key] = None - + if tasks[task_key] and not tasks[task_key].done(): try: tasks[task_key].cancel() except RuntimeError as exc: log.error("shared_debounce: Error cancelling task", exc=exc) is_first = False - + if is_first and immediate: await fn(*args, **kwargs) - + async def delayed(): try: await asyncio.sleep(delay) @@ -100,26 +106,30 @@ def shared_debounce(delay: float, task_key: str = "default", tasks: dict = None, await fn(*args, **kwargs) except asyncio.CancelledError: pass - - + # Create and store the task, but attach it to the loop task = loop.create_task(delayed()) tasks[task_key] = task - + return task # Return the task but don't await it - + return wrapper + return decorator + async def cleanup_pending_tasks(): # Get all tasks from the current loop - pending = [task for task in asyncio.all_tasks() - if not task.done() and task is not asyncio.current_task()] - + pending = [ + task + for task in asyncio.all_tasks() + if not task.done() and task is not asyncio.current_task() + ] + # Cancel them for task in pending: task.cancel() - + # Wait for them to finish if pending: - await asyncio.gather(*pending, return_exceptions=True) \ No newline at end of file + await asyncio.gather(*pending, return_exceptions=True) diff --git a/src/talemate/util/colors.py b/src/talemate/util/colors.py index be70969a..c202fc3d 100644 --- a/src/talemate/util/colors.py +++ b/src/talemate/util/colors.py @@ -29,7 +29,6 @@ COLOR_MAP = { "brown": "#795548", "blue-grey": "#607D8B", "grey": "#9E9E9E", - # Lighten-3 colors "red-lighten-3": "#EF9A9A", "pink-lighten-3": "#F48FB1", @@ -50,7 +49,6 @@ COLOR_MAP = { "brown-lighten-3": "#BCAAA4", "blue-grey-lighten-3": "#B0BEC5", "grey-lighten-3": "#EEEEEE", - # Darken-3 colors "red-darken-3": "#C62828", "pink-darken-3": "#AD1457", @@ -77,5 +75,6 @@ COLOR_MAP = { COLOR_NAMES = sorted(list(COLOR_MAP.keys())) COLORS = sorted(list(COLOR_MAP.values())) + def random_color() -> str: - return random.choice(COLORS) \ No newline at end of file + return random.choice(COLORS) diff --git a/src/talemate/util/data.py b/src/talemate/util/data.py index 2d2c0541..42313114 100644 --- a/src/talemate/util/data.py +++ b/src/talemate/util/data.py @@ -1,28 +1,29 @@ import json import re -import json import structlog import yaml from datetime import date, datetime __all__ = [ "fix_faulty_json", - 'extract_data', + "extract_data", "extract_json", "extract_json_v2", "extract_yaml_v2", - 'JSONEncoder', - 'DataParsingError', + "JSONEncoder", + "DataParsingError", "fix_yaml_colon_in_strings", "fix_faulty_yaml", ] log = structlog.get_logger("talemate.util.dedupe") + class JSONEncoder(json.JSONEncoder): """ Default to str() on unknown types """ + def default(self, obj): try: if isinstance(obj, (date, datetime)): @@ -30,17 +31,19 @@ class JSONEncoder(json.JSONEncoder): return super().default(obj) except TypeError: return str(obj) - + class DataParsingError(Exception): """ Custom error class for data parsing errors (JSON, YAML, etc). """ + def __init__(self, message, data=None): self.message = message self.data = data super().__init__(self.message) + def fix_faulty_json(data: str) -> str: # Fix missing commas data = re.sub(r"}\s*{", "},{", data) @@ -123,48 +126,49 @@ def extract_json(s): json_object = json.loads(json_string) return json_string, json_object + def extract_json_v2(text): """ Extracts JSON structures from code blocks in a text string. - + Parameters: text (str): The input text containing code blocks with JSON. - + Returns: list: A list of unique parsed JSON objects. - + Raises: DataParsingError: If invalid JSON is encountered in code blocks. """ unique_jsons = [] seen = set() - + # Split by code block markers parts = text.split("```") - + # Process every code block (odd indices after split) for i in range(1, len(parts), 2): if i >= len(parts): break - + block = parts[i].strip() - + # Skip empty blocks if not block: continue - + # Remove language identifier if present if block.startswith("json"): block = block[4:].strip() - + # Try to parse the block as a single JSON object first try: fixed_block = fix_faulty_json(block) json_obj = json.loads(fixed_block) - + # Convert to string for deduplication check json_str = json.dumps(json_obj, sort_keys=True) - + # Only add if we haven't seen this object before if json_str not in seen: seen.add(json_str) @@ -174,21 +178,21 @@ def extract_json_v2(text): try: # Add commas between adjacent objects if needed fixed_block = fix_faulty_json(block) - + # Check for multiple JSON objects by looking for patterns like }{ or }[ # Replace with },{ or },[ fixed_block = re.sub(r"}\s*{", "},{", fixed_block) fixed_block = re.sub(r"]\s*{", "],[", fixed_block) - fixed_block = re.sub(r"}\s*\[", "},[" , fixed_block) - fixed_block = re.sub(r"]\s*\[", "],[" , fixed_block) - + fixed_block = re.sub(r"}\s*\[", "},[", fixed_block) + fixed_block = re.sub(r"]\s*\[", "],[", fixed_block) + # Wrap in array brackets if not already an array - if not (fixed_block.startswith('[') and fixed_block.endswith(']')): + if not (fixed_block.startswith("[") and fixed_block.endswith("]")): fixed_block = "[" + fixed_block + "]" - + # Parse as array json_array = json.loads(fixed_block) - + # Process each object in the array for json_obj in json_array: json_str = json.dumps(json_obj, sort_keys=True) @@ -197,43 +201,44 @@ def extract_json_v2(text): unique_jsons.append(json_obj) except json.JSONDecodeError as e: raise DataParsingError(f"Invalid JSON in code block: {str(e)}", block) - + return unique_jsons + def extract_yaml_v2(text): """ Extracts YAML structures from code blocks in a text string. - + Parameters: text (str): The input text containing code blocks with YAML. - + Returns: list: A list of unique parsed YAML objects. - + Raises: DataParsingError: If invalid YAML is encountered in code blocks. """ unique_yamls = [] seen = set() - + # Split by code block markers parts = text.split("```") - + # Process every code block (odd indices after split) for i in range(1, len(parts), 2): if i >= len(parts): break - + block = parts[i].strip() - + # Skip empty blocks if not block: continue - + # Remove language identifier if present if block.startswith("yaml") or block.startswith("yml"): - block = block[block.find("\n"):].strip() - + block = block[block.find("\n") :].strip() + # Parse YAML (supporting multiple documents with ---) try: # First try to parse the YAML as-is @@ -243,104 +248,114 @@ def extract_yaml_v2(text): try: # Apply fixes to YAML before parsing fixed_block = fix_faulty_yaml(block) - + # Use safe_load_all to get all YAML documents in the block yaml_docs = list(yaml.safe_load_all(fixed_block)) - except yaml.YAMLError as e2: + except yaml.YAMLError: # If it still fails, raise the original error raise DataParsingError(f"Invalid YAML in code block: {str(e)}", block) - + # If we only have one document and it's a dict, check if we should split it into multiple documents if len(yaml_docs) == 1 and isinstance(yaml_docs[0], dict) and yaml_docs[0]: # Check if the document has a nested structure where first level keys represent separate documents root_doc = yaml_docs[0] - + # If the first level keys all have dict values, treat them as separate documents if all(isinstance(root_doc[key], dict) for key in root_doc): # Replace yaml_docs with separate documents yaml_docs = [root_doc[key] for key in root_doc] - + # Process each YAML document for yaml_obj in yaml_docs: # Skip if None (empty YAML) if yaml_obj is None: continue - + # Convert to JSON string for deduplication check json_str = json.dumps(yaml_obj, sort_keys=True, cls=JSONEncoder) - + # Only add if we haven't seen this object before if json_str not in seen: seen.add(json_str) unique_yamls.append(yaml_obj) - + return unique_yamls def fix_yaml_colon_in_strings(yaml_text): """ Fixes YAML issues with unquoted strings containing colons. - + Parameters: yaml_text (str): The input YAML text to fix - + Returns: str: Fixed YAML text """ # Split the YAML text into lines - lines = yaml_text.split('\n') + lines = yaml_text.split("\n") result_lines = [] - + for line in lines: # Look for lines with key-value pairs where value has a colon - if ':' in line and line.count(':') > 1: + if ":" in line and line.count(":") > 1: # Check if this is a list item with a colon - list_item_match = re.match(r'^(\s*)-\s+(.+)$', line) + list_item_match = re.match(r"^(\s*)-\s+(.+)$", line) if list_item_match: indent, content = list_item_match.groups() - if ':' in content and not (content.startswith('"') or content.startswith("'") or - content.startswith('>') or content.startswith('|')): + if ":" in content and not ( + content.startswith('"') + or content.startswith("'") + or content.startswith(">") + or content.startswith("|") + ): # Convert to block scalar notation for list item result_lines.append(f"{indent}- |-") # Add the content indented on the next line result_lines.append(f"{indent} {content}") continue - + # Check if this looks like a key: value line with an unquoted value containing a colon - key_match = re.match(r'^(\s*)([^:]+):\s+(.+)$', line) + key_match = re.match(r"^(\s*)([^:]+):\s+(.+)$", line) if key_match: indent, key, value = key_match.groups() # If value has a colon and isn't already properly quoted/formatted - if ':' in value and not (value.startswith('"') or value.startswith("'") or - value.startswith('>') or value.startswith('|')): + if ":" in value and not ( + value.startswith('"') + or value.startswith("'") + or value.startswith(">") + or value.startswith("|") + ): # Convert to block scalar notation result_lines.append(f"{indent}{key}: |-") # Add the value indented on the next line result_lines.append(f"{indent} {value}") continue - + # If no processing needed, keep the original line result_lines.append(line) - - return '\n'.join(result_lines) + + return "\n".join(result_lines) + def fix_faulty_yaml(yaml_text): """ Fixes common YAML syntax issues by applying a series of fixers. - + Parameters: yaml_text (str): The input YAML text to fix - + Returns: str: Fixed YAML text """ # Apply specific fixers in sequence fixed_text = fix_yaml_colon_in_strings(yaml_text) - + # Add more fixers here as needed - + return fixed_text + def extract_data(text, schema_format: str = "json"): """ Extracts data from text based on the schema format. @@ -350,4 +365,4 @@ def extract_data(text, schema_format: str = "json"): elif schema_format == "yaml": return extract_yaml_v2(text) else: - raise ValueError(f"Unsupported schema format: {schema_format}") \ No newline at end of file + raise ValueError(f"Unsupported schema format: {schema_format}") diff --git a/src/talemate/util/dedupe.py b/src/talemate/util/dedupe.py index 1ab47f6a..c1e909a0 100644 --- a/src/talemate/util/dedupe.py +++ b/src/talemate/util/dedupe.py @@ -2,8 +2,9 @@ from nltk.tokenize import sent_tokenize from thefuzz import fuzz import structlog import pydantic -import re # Add import for regex +import re # Add import for regex from typing import Callable + __all__ = [ "similarity_score", "similarity_matches", @@ -15,7 +16,8 @@ __all__ = [ log = structlog.get_logger("talemate.util.dedupe") -SPECIAL_MARKERS = ['*', '"'] +SPECIAL_MARKERS = ["*", '"'] + class SimilarityMatch(pydantic.BaseModel): original: str @@ -23,21 +25,22 @@ class SimilarityMatch(pydantic.BaseModel): similarity: float left_neighbor: str | None = None right_neighbor: str | None = None - + def ln_startswith(self, marker: str) -> bool: return self.left_neighbor and self.left_neighbor.startswith(marker) - + def rn_startswith(self, marker: str) -> bool: return self.right_neighbor and self.right_neighbor.startswith(marker) - + def __hash__(self) -> int: return hash(self.original) - + def __eq__(self, other): if not isinstance(other, SimilarityMatch): return False return self.original == other.original + def similarity_score( line: str, lines: list[str], similarity_threshold: int = 95 ) -> tuple[bool, int, str]: @@ -70,19 +73,20 @@ def similarity_score( def compile_text_to_sentences(text: str) -> list[tuple[str, str]]: """ Compile text into sentences. - + Returns a list of tuples were the first element is the original sentence and the second element is the prepared sentence that will be used for similarity comparison. """ - sentences = sent_tokenize(text) - + sentences = sent_tokenize(text) + results = [] - + for sentence in sentences: results.append((sentence, sentence.strip("".join(SPECIAL_MARKERS)))) - + return results + def split_sentences_on_comma(sentences: list[str]) -> list[str]: """ Split sentences on commas. @@ -93,21 +97,22 @@ def split_sentences_on_comma(sentences: list[str]) -> list[str]: results.append(part.strip()) return results + def similarity_matches( - text_a: str, - text_b: str, - similarity_threshold: int = 95, + text_a: str, + text_b: str, + similarity_threshold: int = 95, min_length: int | None = None, - split_on_comma: bool = False + split_on_comma: bool = False, ) -> list[SimilarityMatch]: """ Returns a list of similarity matches between two texts. - + Arguments: text_a (str): The first text. text_b (str): The second text. similarity_threshold (int): The similarity threshold to use when comparing sentences. - min_length (int): The minimum length of a sentence to be considered for deduplication. + min_length (int): The minimum length of a sentence to be considered for deduplication. Shorter sentences are skipped. If None, all sentences are considered. split_on_comma (bool): Whether to split sentences on commas. When true if the whole sentence does NOT trigger a similarity match, the sentence will be split on commas and each comma will be checked for similarity. @@ -115,7 +120,7 @@ def similarity_matches( Returns: list: A list of similarity matches. """ - + sentences_a = compile_text_to_sentences(text_a) sentences_b = compile_text_to_sentences(text_b) @@ -123,8 +128,8 @@ def similarity_matches( left_neighbor = None right_neighbor = None for idx, (sentence_a, sentence_a_prepared) in enumerate(sentences_a): - left_neighbor = sentences_a[idx-1][0] if idx > 0 else None - right_neighbor = sentences_a[idx+1][0] if idx < len(sentences_a)-1 else None + left_neighbor = sentences_a[idx - 1][0] if idx > 0 else None + right_neighbor = sentences_a[idx + 1][0] if idx < len(sentences_a) - 1 else None if min_length and len(sentence_a) < min_length: continue for sentence_b, sentence_b_prepared in sentences_b: @@ -135,14 +140,14 @@ def similarity_matches( matches.append( SimilarityMatch( original=sentence_a, - matched=sentence_b, - similarity=similarity, + matched=sentence_b, + similarity=similarity, left_neighbor=left_neighbor, - right_neighbor=right_neighbor + right_neighbor=right_neighbor, ) ) break - + if split_on_comma: prev_comma_a = None parts_a = sentence_a.split(",") @@ -158,25 +163,27 @@ def similarity_matches( matches.append( SimilarityMatch( original=comma_a, - matched=comma_b, - similarity=similarity, + matched=comma_b, + similarity=similarity, left_neighbor=prev_comma_a, - right_neighbor=parts_a[idx_a+1] if idx_a < len(parts_a)-1 else None + right_neighbor=parts_a[idx_a + 1] + if idx_a < len(parts_a) - 1 + else None, ) ) break - + return matches + def dedupe_sentences( - text_a: str, text_b: str, similarity_threshold: int = 95, debug: bool = False, on_dedupe: Callable | None = None, min_length: int | None = None, - split_on_comma: bool = False + split_on_comma: bool = False, ) -> str: """ Will split both texts into sentences and then compare each sentence in text_a @@ -196,14 +203,23 @@ def dedupe_sentences( Returns: str: the cleaned text_a. """ - + # find similarity matches - matches = similarity_matches(text_a, text_b, similarity_threshold, min_length, split_on_comma) - - return dedupe_sentences_from_matches(text_a, matches, on_dedupe=on_dedupe, debug=debug) - - -def dedupe_sentences_from_matches(text_a: str, matches: list[SimilarityMatch], on_dedupe: Callable | None = None, debug: bool = False) -> str: + matches = similarity_matches( + text_a, text_b, similarity_threshold, min_length, split_on_comma + ) + + return dedupe_sentences_from_matches( + text_a, matches, on_dedupe=on_dedupe, debug=debug + ) + + +def dedupe_sentences_from_matches( + text_a: str, + matches: list[SimilarityMatch], + on_dedupe: Callable | None = None, + debug: bool = False, +) -> str: """ Dedupe sentences using fuzzy matching. """ @@ -213,30 +229,31 @@ def dedupe_sentences_from_matches(text_a: str, matches: list[SimilarityMatch], o for match in matches: replace = "" original = match.original - + # handle special markers (asterisks and quotes) for special_marker in SPECIAL_MARKERS: - # we are looking for markers at the end or beginning of the sentence # at an odd number of occurences # # those mean the sentence is part of a markup and the symbol # must be carried over to the replacement so the markup remains # complete - part_of_marker = original.startswith(special_marker) or original.endswith(special_marker) - + part_of_marker = original.startswith(special_marker) or original.endswith( + special_marker + ) + if not part_of_marker: continue - + # if not odd number of special markers, skip # an even number means the markup is fully contained within the sentence if original.count(special_marker) % 2 == 0: continue - + # if the sentence is part of a markup, we need to carry over the marker # to the replacement so the markup remains complete replace = special_marker - + # balancing logic - some edge cases to handle issues # with whitespace and special markers if original.startswith(special_marker): @@ -248,22 +265,32 @@ def dedupe_sentences_from_matches(text_a: str, matches: list[SimilarityMatch], o elif original.endswith(special_marker): if match.rn_startswith(special_marker): original = f"{original} " - + match_both = None match_left = None match_right = None - + # handle whitespace between neighbors if match.left_neighbor and match.right_neighbor: - pattern_both = re.escape(match.left_neighbor) + r'(\s+)' + re.escape(original) + r'(\s+)' + re.escape(match.right_neighbor) + pattern_both = ( + re.escape(match.left_neighbor) + + r"(\s+)" + + re.escape(original) + + r"(\s+)" + + re.escape(match.right_neighbor) + ) match_both = re.search(pattern_both, text_a) if match.left_neighbor: - pattern_left = re.escape(match.left_neighbor) + r'(\s+)' + re.escape(original) + pattern_left = ( + re.escape(match.left_neighbor) + r"(\s+)" + re.escape(original) + ) match_left = re.search(pattern_left, text_a) if match.right_neighbor: - pattern_right = re.escape(original) + r'(\s+)' + re.escape(match.right_neighbor) + pattern_right = ( + re.escape(original) + r"(\s+)" + re.escape(match.right_neighbor) + ) match_right = re.search(pattern_right, text_a) - + if match.left_neighbor and match.right_neighbor and match_both: whitespace = match_both.group(1) original = f"{whitespace}{original}" @@ -273,11 +300,11 @@ def dedupe_sentences_from_matches(text_a: str, matches: list[SimilarityMatch], o elif match.right_neighbor and match_right: whitespace = match_right.group(1) original = f"{original}{whitespace}" - + # Dedupe the original sentence by replacing it with the replacement # which is either an empty string or a special marker (* or ") text_a = text_a.replace(original, replace) - + # call the on_dedupe callback if it is provided if on_dedupe: on_dedupe(match) @@ -292,8 +319,7 @@ def dedupe_sentences_from_matches(text_a: str, matches: list[SimilarityMatch], o # final clean up for special_marker in SPECIAL_MARKERS: # idential markers with a space between can just be joined. - text_a = text_a.replace(f'{special_marker} {special_marker}', " ") - + text_a = text_a.replace(f"{special_marker} {special_marker}", " ") return text_a.strip() @@ -318,35 +344,35 @@ def dedupe_string( deduped = [] current_in_codeblock = False existing_in_codeblock = False - + for line in reversed(lines): stripped_line = line.strip() - + # Check for code block markers in current line if stripped_line.startswith("```"): current_in_codeblock = not current_in_codeblock deduped.append(line) continue - + # Skip deduping for lines in code blocks if current_in_codeblock: deduped.append(line) continue - + if len(stripped_line) > min_length: similar_found = False existing_in_codeblock = False - + for existing_line in deduped: # Track code block state for existing lines if existing_line.strip().startswith("```"): existing_in_codeblock = not existing_in_codeblock continue - - # Skip comparing if either line is in a code block + + # Skip comparing if either line is in a code block if existing_in_codeblock: continue - + similarity = fuzz.ratio(stripped_line, existing_line.strip()) if similarity >= similarity_threshold: similar_found = True @@ -364,4 +390,3 @@ def dedupe_string( deduped.append(line) return "\n".join(reversed(deduped)) - diff --git a/src/talemate/util/dialogue.py b/src/talemate/util/dialogue.py index 399c35e7..bed08833 100644 --- a/src/talemate/util/dialogue.py +++ b/src/talemate/util/dialogue.py @@ -34,32 +34,35 @@ def handle_endofline_special_delimiter(content: str) -> str: return content -def remove_trailing_markers(content: str, pair_markers:list[str] = None, enclosure_markers:list[str] = None) -> str: +def remove_trailing_markers( + content: str, pair_markers: list[str] = None, enclosure_markers: list[str] = None +) -> str: """ Will check for uneven balance in the specified markers and remove the trailing ones """ - + if not pair_markers: - pair_markers = ['"', '*'] - + pair_markers = ['"', "*"] + if not enclosure_markers: - enclosure_markers = ['(', '[', '{'] - + enclosure_markers = ["(", "[", "{"] + content = content.rstrip() - + for marker in pair_markers: if content.count(marker) % 2 == 1 and content.endswith(marker): content = content[:-1] content = content.rstrip() - + for marker in enclosure_markers: if content.endswith(marker): content = content[:-1] content = content.rstrip() - + return content.rstrip() + def parse_messages_from_str(string: str, names: list[str]) -> list[str]: """ Given a big string containing raw chat history, this function attempts to @@ -90,6 +93,7 @@ def parse_messages_from_str(string: str, names: list[str]) -> list[str]: return messages + def strip_partial_sentences(text: str) -> str: """ Removes any unfinished sentences from the end of the input text. @@ -112,6 +116,7 @@ def strip_partial_sentences(text: str) -> str: return text + def clean_message(message: str) -> str: message = message.strip() message = re.sub(r" +", " ", message) @@ -119,14 +124,12 @@ def clean_message(message: str) -> str: def clean_dialogue(dialogue: str, main_name: str) -> str: - cleaned = [] if not dialogue.startswith(main_name): dialogue = f"{main_name}: {dialogue}" for line in dialogue.split("\n"): - if not cleaned: cleaned.append(line) continue @@ -167,7 +170,9 @@ def replace_exposition_markers(s: str) -> str: return s -def ensure_dialog_format(line: str, talking_character: str = None, formatting:str = "md") -> str: +def ensure_dialog_format( + line: str, talking_character: str = None, formatting: str = "md" +) -> str: # if "*" not in line and '"' not in line: # if talking_character: # line = line[len(talking_character)+1:].lstrip() @@ -175,17 +180,16 @@ def ensure_dialog_format(line: str, talking_character: str = None, formatting:st # return f"\"{line}\"" # - if talking_character: line = line[len(talking_character) + 1 :].lstrip() eval_line = line.strip() - - if eval_line.startswith('*') and eval_line.endswith('*'): + + if eval_line.startswith("*") and eval_line.endswith("*"): if line.count("*") == 2 and not line.count('"'): return f"{talking_character}: {line}" if talking_character else line if eval_line.startswith('"') and eval_line.endswith('"'): - if line.count('"') == 2 and not line.count('*'): + if line.count('"') == 2 and not line.count("*"): return f"{talking_character}: {line}" if talking_character else line lines = [] @@ -395,32 +399,32 @@ def split_anchor_text(text: str, anchor_length: int = 10) -> tuple[str, str]: """ Splits input text into two parts: non-anchor and anchor. The anchor is the last `anchor_length` words of the text. - + Args: text (str): The input text to be split anchor_length (int): Number of words to use as anchor - + Returns: tuple[str, str]: A tuple containing (non_anchor, anchor) """ if not text: return "", "" - + # Split the input into words words = text.split() - + # If it's just one word, put it in the anchor if len(words) == 1: return "", text - + # Get the anchor (last anchor_length words) if len(words) > anchor_length: - anchor = ' '.join(words[-anchor_length:]) - non_anchor = ' '.join(words[:-anchor_length]) + anchor = " ".join(words[-anchor_length:]) + non_anchor = " ".join(words[:-anchor_length]) else: # For text with words <= anchor_length (but more than 1 word), split evenly mid_point = len(words) // 2 - non_anchor = ' '.join(words[:mid_point]) - anchor = ' '.join(words[mid_point:]) - - return non_anchor, anchor \ No newline at end of file + non_anchor = " ".join(words[:mid_point]) + anchor = " ".join(words[mid_point:]) + + return non_anchor, anchor diff --git a/src/talemate/util/diff.py b/src/talemate/util/diff.py index bcbb95c9..b0e041b7 100644 --- a/src/talemate/util/diff.py +++ b/src/talemate/util/diff.py @@ -2,22 +2,23 @@ from diff_match_patch import diff_match_patch __all__ = ["dmp_inline_diff"] -def dmp_inline_diff(text1:str, text2:str) -> str: + +def dmp_inline_diff(text1: str, text2: str) -> str: dmp = diff_match_patch() diffs = dmp.diff_main(text1, text2) dmp.diff_cleanupSemantic(diffs) - + delete_class = "diff-delete" insert_class = "diff-insert" - + html = [] for op, text in diffs: - text = text.replace('&', '&').replace('<', '<').replace('>', '>') + text = text.replace("&", "&").replace("<", "<").replace(">", ">") if op == 0: # Equal html.append(text) elif op == -1: # Delete html.append(f'{text}') elif op == 1: # Insert html.append(f'{text}') - - return ''.join(html) \ No newline at end of file + + return "".join(html) diff --git a/src/talemate/util/image.py b/src/talemate/util/image.py index 934631a9..6f617752 100644 --- a/src/talemate/util/image.py +++ b/src/talemate/util/image.py @@ -5,6 +5,7 @@ import struct import structlog from PIL import Image import json + log = structlog.get_logger("talemate.util.image") __all__ = [ @@ -27,7 +28,6 @@ def extract_metadata(img_path, img_format): return chara_read(img_path) - def read_metadata_from_png_text(image_path: str) -> dict: """ Reads the character metadata from the tEXt chunk of a PNG image. @@ -52,7 +52,6 @@ def read_metadata_from_png_text(image_path: str) -> dict: raise ValueError("No character metadata found.") - def chara_read(img_url, input_format=None): if input_format is None: if ".webp" in img_url: @@ -77,7 +76,7 @@ def chara_read(img_url, input_format=None): try: char_data = json.loads(description) - except: + except Exception: byte_arr = [int(x) for x in description.split(",")[1:]] uint8_array = bytearray(byte_arr) char_data_string = uint8_array.decode("utf-8") @@ -87,9 +86,8 @@ def chara_read(img_url, input_format=None): return False return char_data - except Exception as err: + except Exception: raise - return False elif format == "png": with Image.open(img_url) as img: @@ -121,4 +119,4 @@ def chara_read(img_url, input_format=None): ) return False else: - return None \ No newline at end of file + return None diff --git a/src/talemate/util/prompt.py b/src/talemate/util/prompt.py index 29331de3..0132e752 100644 --- a/src/talemate/util/prompt.py +++ b/src/talemate/util/prompt.py @@ -1,10 +1,6 @@ import re -__all__ = [ - "condensed", - "no_chapters", - "replace_special_tokens" -] +__all__ = ["condensed", "no_chapters", "replace_special_tokens"] def replace_special_tokens(prompt: str): @@ -27,25 +23,26 @@ def condensed(s): # also replace multiple spaces with a single space return re.sub(r"\s+", " ", r) + def no_chapters(text: str, replacement: str = "chapter") -> str: """ Takes a text that may contain mentions of 'Chapter X.Y' and replaces them with the provided replacement, maintaining the original casing pattern. - + Takes into account that the chapters may be in the format of: - + - Chapter X.Y -> Chapter - chapter X.Y -> chapter - CHAPTER X -> CHAPTER - ChapterX -> Chapter - + Args: text (str): The input text containing chapter references replacement (str): The text to replace chapter references with - + Returns: str: Text with chapter references replaced, maintaining casing - + Examples: >>> no_chapters("In Chapter 1.2 we see", "chapter") "In chapter we see" @@ -55,27 +52,27 @@ def no_chapters(text: str, replacement: str = "chapter") -> str: "chapter shows" """ import re - + def replace_with_case(match): original = match.group(0) - + # Check if the original is all uppercase if original.isupper(): return replacement.upper() - + # Check if the original starts with a capital letter if original[0].isupper(): return replacement.capitalize() - + # Default to lowercase return replacement.lower() - + # Pattern explanation: # (?i) - case insensitive flag # chapter\s* - matches "chapter" followed by optional whitespace # (?:\d+(?:\.\d+)?)? - optionally matches: # \d+ - one or more digits # (?:\.\d+)? - optionally followed by a decimal point and more digits - pattern = r'(?i)chapter\s*(?:\d+(?:\.\d+)?)?' - - return re.sub(pattern, replace_with_case, text) \ No newline at end of file + pattern = r"(?i)chapter\s*(?:\d+(?:\.\d+)?)?" + + return re.sub(pattern, replace_with_case, text) diff --git a/src/talemate/util/response.py b/src/talemate/util/response.py index 41338f54..7de6aa48 100644 --- a/src/talemate/util/response.py +++ b/src/talemate/util/response.py @@ -18,7 +18,7 @@ def extract_list(response: str) -> list: # Locate the beginning of the list lines = response.split("\n") - + # strip empty lines lines = [line for line in lines if line.strip() != ""] @@ -53,7 +53,6 @@ def extract_list(response: str) -> list: or re.match(r"^\* ", line) or re.match(r"^- ", line) ): - # strip the number or bullet line = re.sub(r"^(?:\d+\.|\*|-)", "", line).strip() diff --git a/src/talemate/util/time.py b/src/talemate/util/time.py index bfda20ff..8f241167 100644 --- a/src/talemate/util/time.py +++ b/src/talemate/util/time.py @@ -12,7 +12,7 @@ __all__ = [ "iso8601_diff_to_human", "iso8601_add", "iso8601_correct_duration", - "amount_unit_to_iso8601_duration" + "amount_unit_to_iso8601_duration", ] log = structlog.get_logger("talemate.util.time") @@ -33,6 +33,7 @@ UNIT_TO_ISO = { "years": "Y", } + def duration_to_timedelta(duration): """Convert an isodate.Duration object or a datetime.timedelta object to a datetime.timedelta object.""" # Check if the duration is already a timedelta object @@ -41,35 +42,37 @@ def duration_to_timedelta(duration): # If it's an isodate.Duration object with separate year, month, day, hour, minute, second attributes days = int(duration.years * 365 + duration.months * 30 + duration.days) - seconds = int(duration.tdelta.seconds if hasattr(duration, 'tdelta') else 0) + seconds = int(duration.tdelta.seconds if hasattr(duration, "tdelta") else 0) return datetime.timedelta(days=days, seconds=seconds) + def timedelta_to_duration(delta): """Convert a datetime.timedelta object to an isodate.Duration object.""" total_days = delta.days - + # Convert days back to years and months years = total_days // 365 remaining_days = total_days % 365 months = remaining_days // 30 days = remaining_days % 30 - + # Convert remaining seconds seconds = delta.seconds hours = seconds // 3600 seconds %= 3600 minutes = seconds // 60 seconds %= 60 - + return isodate.Duration( years=years, months=months, days=days, hours=hours, minutes=minutes, - seconds=seconds + seconds=seconds, ) + def parse_duration_to_isodate_duration(duration_str): """Parse ISO 8601 duration string and ensure the result is an isodate.Duration.""" parsed_duration = isodate.parse_duration(duration_str) @@ -96,49 +99,56 @@ def iso8601_diff(duration_str1, duration_str2): return difference -def flatten_duration_components(years: int, months: int, weeks: int, days: int, - hours: int, minutes: int, seconds: int): +def flatten_duration_components( + years: int, + months: int, + weeks: int, + days: int, + hours: int, + minutes: int, + seconds: int, +): """ Flatten duration components based on total duration following specific rules. Returns adjusted component values based on the total duration. """ - + total_days = years * 365 + months * 30 + weeks * 7 + days total_months = total_days // 30 - + # Less than 1 day - keep original granularity if total_days < 1: return years, months, weeks, days, hours, minutes, seconds - + # Less than 3 days - show only days and hours elif total_days < 3: if minutes >= 30: # Round up hours if 30+ minutes hours += 1 return 0, 0, 0, total_days, hours, 0, 0 - + # Less than a month - show only days elif total_days < 30: return 0, 0, 0, total_days, 0, 0, 0 - + # Less than 6 months - show months and days elif total_days < 180: new_months = total_days // 30 new_days = total_days % 30 return 0, new_months, 0, new_days, 0, 0, 0 - + # Less than 1 year - show only months elif total_months < 12: new_months = total_months if days > 15: # Round up months if 15+ days remain new_months += 1 return 0, new_months, 0, 0, 0, 0, 0 - + # Less than 3 years - show years and months elif total_months < 36: new_years = total_months // 12 new_months = total_months % 12 return new_years, new_months, 0, 0, 0, 0, 0 - + # More than 3 years - show only years else: # Derive the base number of years directly from total days to avoid cumulative @@ -163,8 +173,13 @@ def flatten_duration_components(years: int, months: int, weeks: int, days: int, return new_years, 0, 0, 0, 0, 0, 0 -def iso8601_duration_to_human(iso_duration, suffix: str = " ago", - zero_time_default: str = "Recently", flatten: bool = True): + +def iso8601_duration_to_human( + iso_duration, + suffix: str = " ago", + zero_time_default: str = "Recently", + flatten: bool = True, +): # Parse the ISO8601 duration string into an isodate duration object if not isinstance(iso_duration, isodate.Duration): duration = isodate.parse_duration(iso_duration) @@ -192,8 +207,10 @@ def iso8601_duration_to_human(iso_duration, suffix: str = " ago", # If flattening is requested, adjust the components if flatten: - years, months, weeks, days, hours, minutes, seconds = flatten_duration_components( - years, months, weeks, days, hours, minutes, seconds + years, months, weeks, days, hours, minutes, seconds = ( + flatten_duration_components( + years, months, weeks, days, hours, minutes, seconds + ) ) # Build the human-readable components @@ -295,6 +312,7 @@ def iso8601_correct_duration(duration: str) -> str: return corrected_duration + def amount_unit_to_iso8601_duration(amount: int, unit: str) -> str: """Converts numeric amount + textual unit into an ISO-8601 duration string. @@ -307,7 +325,9 @@ def amount_unit_to_iso8601_duration(amount: int, unit: str) -> str: unit_key = unit.lower().strip() if unit_key not in UNIT_TO_ISO: - raise ValueError(f"Invalid unit '{unit}'. Expected minutes, hours, days, weeks, months or years.") + raise ValueError( + f"Invalid unit '{unit}'. Expected minutes, hours, days, weeks, months or years." + ) code = UNIT_TO_ISO[unit_key] diff --git a/src/talemate/ux/schema.py b/src/talemate/ux/schema.py index d6684dca..c4cba71a 100644 --- a/src/talemate/ux/schema.py +++ b/src/talemate/ux/schema.py @@ -4,8 +4,9 @@ __all__ = [ "Note", ] + class Note(pydantic.BaseModel): text: str title: str = None color: str = "muted" - icon: str = "mdi-information-outline" \ No newline at end of file + icon: str = "mdi-information-outline" diff --git a/src/talemate/world_state/__init__.py b/src/talemate/world_state/__init__.py index b650cf64..20f4f26b 100644 --- a/src/talemate/world_state/__init__.py +++ b/src/talemate/world_state/__init__.py @@ -9,7 +9,7 @@ import talemate.instance as instance from talemate.emit import emit from talemate.prompts import Prompt from talemate.exceptions import GenerationCancelled -import talemate.game.focal.schema as focal_schema +import talemate.game.focal.schema as focal_schema ANY_CHARACTER = "__any_character__" @@ -75,11 +75,13 @@ class Suggestion(BaseModel): proposals: list[focal_schema.Call] = Field(default_factory=list) def remove_proposal(self, uid: str): - self.proposals = [proposal for proposal in self.proposals if proposal.uid != uid] + self.proposals = [ + proposal for proposal in self.proposals if proposal.uid != uid + ] - def merge(self, other:"Suggestion"): + def merge(self, other: "Suggestion"): assert self.id == other.id, "Suggestion ids must match" - + # loop through proposals, and override existing proposals if ids match # otherwise append the new proposal for proposal in other.proposals: @@ -89,7 +91,8 @@ class Suggestion(BaseModel): break else: self.proposals.append(proposal) - + + class WorldState(BaseModel): # characters in the scene by name characters: dict[str, CharacterState] = {} @@ -131,7 +134,7 @@ class WorldState(BaseModel): def add_character_name_mappings(self, *names): self.character_name_mappings.extend([name.lower() for name in names]) - + def normalize_name(self, name: str): """Normalizes item or character name away from variables style names @@ -222,12 +225,11 @@ class WorldState(BaseModel): return previous_characters = self.characters - previous_items = self.items scene = self.agent.scene character_names = scene.character_names self.characters = {} self.items = {} - + # if characters is not set or empty, make sure its at least a dict if not world_state.get("characters"): world_state["characters"] = {} @@ -328,7 +330,6 @@ class WorldState(BaseModel): """ memory = instance.get_agent("memory") - world_state = instance.get_agent("world_state") # first we check if any of the characters were refered # to with an alias @@ -512,18 +513,18 @@ class WorldState(BaseModel): # find all instances of the reinforcement in the scene history # and remove them - + reinforcement = self.reinforce[idx] - + self.agent.scene.pop_history( typ="reinforcement", character_name=reinforcement.character, question=reinforcement.question, all=True, ) - - #source = f"{self.reinforce[idx].question}:{self.reinforce[idx].character if self.reinforce[idx].character else ''}" - #self.agent.scene.pop_history(typ="reinforcement", source=source, all=True) + + # source = f"{self.reinforce[idx].question}:{self.reinforce[idx].character if self.reinforce[idx].character else ''}" + # self.agent.scene.pop_history(typ="reinforcement", source=source, all=True) self.reinforce.pop(idx) diff --git a/src/talemate/world_state/manager.py b/src/talemate/world_state/manager.py index 1947256a..b5f753a6 100644 --- a/src/talemate/world_state/manager.py +++ b/src/talemate/world_state/manager.py @@ -5,10 +5,14 @@ import structlog import talemate.world_state.templates as world_state_templates from talemate.character import activate_character, deactivate_character -from talemate.config import save_config from talemate.instance import get_agent from talemate.emit import emit -from talemate.world_state import ContextPin, InsertionMode, ManualContext, Reinforcement, Suggestion +from talemate.world_state import ( + ContextPin, + ManualContext, + Reinforcement, + Suggestion, +) if TYPE_CHECKING: from talemate.tale_mate import Character, Scene @@ -739,7 +743,6 @@ class WorldStateManager: run_immediately: bool = False, **kwargs, ) -> str: - if isinstance(template, str): template_uid = template template = self.template_collection.flat( @@ -760,7 +763,6 @@ class WorldStateManager: character_name: str, **kwargs, ) -> str: - if isinstance(template, str): template_uid = template template = self.template_collection.flat( @@ -838,7 +840,7 @@ class WorldStateManager: character="the character", ) tries -= 1 - + if not name: raise ValueError("Failed to generate a name for the character.") @@ -853,7 +855,7 @@ class WorldStateManager: ) # create character instance - character:"Character" = self.scene.Character( + character: "Character" = self.scene.Character( name=name, description=description, base_attributes={}, @@ -871,7 +873,7 @@ class WorldStateManager: actor = ActorCls(character, get_agent("conversation")) await self.scene.add_actor(actor) - + try: if generate_attributes: base_attributes = await world_state.extract_character_sheet( @@ -879,7 +881,6 @@ class WorldStateManager: ) character.update(base_attributes=base_attributes) - if not active: await deactivate_character(self.scene, name) except Exception as e: @@ -895,7 +896,6 @@ class WorldStateManager: intro: str | None = None, context: str | None = None, ) -> "Scene": - scene = self.scene scene.title = title scene.description = description @@ -911,43 +911,45 @@ class WorldStateManager: writing_style_template: str | None = None, restore_from: str | None = None, ) -> "Scene": - scene = self.scene scene.immutable_save = immutable_save scene.experimental = experimental scene.writing_style_template = writing_style_template - + if restore_from and restore_from not in scene.save_files: - raise ValueError(f"Restore file {restore_from} not found in scene save files.") - + raise ValueError( + f"Restore file {restore_from} not found in scene save files." + ) + scene.restore_from = restore_from return scene - # suggestions - + async def clear_suggestions(self): """ Clears all suggestions from the scene. """ self.scene.world_state.suggestions = [] self.scene.world_state.emit() - + async def add_suggestion(self, suggestion: Suggestion): """ Adds a suggestion to the scene. """ - - existing:Suggestion = await self.get_suggestion_by_id(suggestion.id) - - log.debug("WorldStateManager.add_suggestion", suggestion=suggestion, existing=existing) - + + existing: Suggestion = await self.get_suggestion_by_id(suggestion.id) + + log.debug( + "WorldStateManager.add_suggestion", suggestion=suggestion, existing=existing + ) + if existing: existing.merge(suggestion) else: self.scene.world_state.suggestions.append(suggestion) - + # changes will be emitted to the world editor as proposals for the character for proposal in suggestion.proposals: emit( @@ -959,52 +961,48 @@ class WorldStateManager: "suggestion_type": suggestion.type, "name": suggestion.name, "id": suggestion.id, - } + }, ) - + self.scene.world_state.emit() - - - async def get_suggestion_by_id(self, id:str) -> Suggestion: + + async def get_suggestion_by_id(self, id: str) -> Suggestion: """ Retrieves a suggestion from the scene by its id. """ - + for s in self.scene.world_state.suggestions: if s.id == id: return s - + self.scene.world_state.emit() - - - async def remove_suggestion(self, suggestion:str | Suggestion): + + async def remove_suggestion(self, suggestion: str | Suggestion): """ Removes a suggestion from the scene by its id. """ if isinstance(suggestion, str): suggestion = await self.get_suggestion_by_id(suggestion) - + if not suggestion: return - + self.scene.world_state.suggestions.remove(suggestion) self.scene.world_state.emit() - - - async def remove_suggestion_proposal(self, suggestion_id:str, proposal_uid:str): + + async def remove_suggestion_proposal(self, suggestion_id: str, proposal_uid: str): """ Removes a proposal from a suggestion by its uid. """ - - suggestion:Suggestion = await self.get_suggestion_by_id(suggestion_id) - + + suggestion: Suggestion = await self.get_suggestion_by_id(suggestion_id) + if not suggestion: return - + suggestion.remove_proposal(proposal_uid) - + # if suggestion is empty, remove it if not suggestion.proposals: await self.remove_suggestion(suggestion) self.scene.world_state.emit() - \ No newline at end of file diff --git a/src/talemate/world_state/templates/base.py b/src/talemate/world_state/templates/base.py index b78784f2..58b97964 100644 --- a/src/talemate/world_state/templates/base.py +++ b/src/talemate/world_state/templates/base.py @@ -1,7 +1,7 @@ import os import uuid from enum import IntEnum -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, TypeVar import pydantic import structlog @@ -36,7 +36,6 @@ TEMPLATE_PATH_TALEMATE = os.path.join(TEMPLATE_PATH, "talemate") class register: - def __init__(self, template_type: str): self.template_type = template_type @@ -127,69 +126,71 @@ class Group(pydantic.BaseModel): data = yaml.safe_load(f) data = cls.sanitize_data(data) return cls(path=path, **data) - + @classmethod def sanitize_data(cls, data: dict) -> dict: """ Sanitizes the data for the group. """ - + data.pop("path", None) - + # ensure uid is set if not data.get("uid"): data["uid"] = str(uuid.uuid4()) - + # if group name is null, set it to the group uid if not data.get("name"): uid = data.get("uid") log.warning("Group has no name", group_uid=uid) data["name"] = uid[:8] - + # if description or author are null, set them to blank strings - if data.get("description") is None: + if data.get("description") is None: data["description"] = "" if data.get("author") is None: data["author"] = "" - + # 1 remove null templates for template_id, template in list(data["templates"].items()): if not template: log.warning("Template is null", template_id=template_id) del data["templates"][template_id] - + # for templates with a null name, set it to the template_id for template_id, template in data["templates"].items(): if template.get("group") != data["uid"]: template["group"] = data["uid"] - + if not template.get("uid"): template["uid"] = template_id - + if not template.get("name"): log.warning("Template has no name", template_id=template_id) template["name"] = template_id[:8] - + # try to int priority, on failure set to 1 try: template["priority"] = int(template.get("priority", 1)) except (ValueError, TypeError): template["priority"] = 1 - # ensure template_type exists and drop any that are invalid for template_id, template in list(data["templates"].items()): template_type = template.get("template_type") if not template_type: log.warning("Template has no template_type", template_id=template_id) del data["templates"][template_id] - + if template_type not in MODELS: - log.warning("Template has invalid template_type", template_id=template_id, template_type=template_type) + log.warning( + "Template has invalid template_type", + template_id=template_id, + template_type=template_type, + ) del data["templates"][template_id] return data - @property def filename(self): @@ -197,7 +198,6 @@ class Group(pydantic.BaseModel): return f"{cleaned_name}.yaml" def save(self, path: str = TEMPLATE_PATH): - if not self.path: path = os.path.join(path, self.filename) else: @@ -223,7 +223,6 @@ class Group(pydantic.BaseModel): templates = {} for template_id, template in self.templates.items(): - # we need to ignore the value of `group` since that # is always going to be different. # @@ -249,7 +248,6 @@ class Group(pydantic.BaseModel): ) def insert_template(self, template: Template, save: bool = True): - if template.uid in self.templates: raise ValueError(f"Template with id {template.uid} already exists in group") @@ -259,7 +257,6 @@ class Group(pydantic.BaseModel): self.save() def update_template(self, template: Template, save: bool = True): - self.templates[template.uid] = template if save: @@ -339,7 +336,6 @@ class Collection(pydantic.BaseModel): config_templates = config.game.world_state.templates.model_dump() for template_type, templates in config_templates.items(): - name = f"legacy-{template_type.replace('_', '-')}s" if check_if_exists: @@ -387,7 +383,6 @@ class Collection(pydantic.BaseModel): for group in self.groups: for template_id, template in group.templates.items(): - if types and template.template_type not in types: continue @@ -395,7 +390,7 @@ class Collection(pydantic.BaseModel): templates[uid] = template return FlatCollection(templates=templates) - + def flat_by_template_uid_only(self) -> "FlatCollection": """ Returns a flat collection of templates by template uid only @@ -416,7 +411,6 @@ class Collection(pydantic.BaseModel): for group in self.groups: for template_id, template in group.templates.items(): - if types and template.template_type not in types: continue @@ -447,7 +441,7 @@ class Collection(pydantic.BaseModel): self.groups.remove(group) if save: group.delete() - + def collect_all(self, uids: list[str]) -> dict[str, AnnotatedTemplate]: """ Returns a dictionary of all templates in the collection @@ -469,7 +463,7 @@ class TypedCollection(pydantic.BaseModel): templates: dict[str, dict[str, AnnotatedTemplate]] = pydantic.Field( default_factory=dict ) - + def find_by_name(self, name: str) -> AnnotatedTemplate | None: for templates in self.templates.values(): for template in templates.values(): diff --git a/src/talemate/world_state/templates/character.py b/src/talemate/world_state/templates/character.py index da1ac52f..12856c4a 100644 --- a/src/talemate/world_state/templates/character.py +++ b/src/talemate/world_state/templates/character.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, TypeVar, Union +from typing import TYPE_CHECKING import pydantic @@ -6,8 +6,6 @@ from talemate.instance import get_agent from talemate.world_state.templates.base import Template, log, register from talemate.world_state.templates.content import ( GenerationOptions, - Spices, - WritingStyle, ) if TYPE_CHECKING: @@ -41,7 +39,6 @@ class Attribute(Template): generation_options: GenerationOptions | None = None, **kwargs, ) -> GeneratedAttribute: - creator = get_agent("creator") character = scene.get_character(character_name) @@ -106,7 +103,6 @@ class Detail(Template): generation_options: GenerationOptions | None = None, **kwargs, ) -> GeneratedDetail: - creator = get_agent("creator") character = scene.get_character(character_name) diff --git a/src/talemate/world_state/templates/content.py b/src/talemate/world_state/templates/content.py index 01ebda47..33750689 100644 --- a/src/talemate/world_state/templates/content.py +++ b/src/talemate/world_state/templates/content.py @@ -3,17 +3,12 @@ from typing import TYPE_CHECKING, Literal import pydantic -from talemate.world_state.templates.base import Template, log, register +from talemate.world_state.templates.base import Template, register if TYPE_CHECKING: from talemate.tale_mate import Scene -__all__ = [ - "GenerationOptions", - "Spices", - "WritingStyle", - "PhraseDetection" -] +__all__ = ["GenerationOptions", "Spices", "WritingStyle", "PhraseDetection"] @register("spices") @@ -26,6 +21,7 @@ class Spices(Template): return self.formatted("instructions", scene, character_name, spice=spice) + class PhraseDetection(pydantic.BaseModel): phrase: str instructions: str @@ -33,7 +29,8 @@ class PhraseDetection(pydantic.BaseModel): classification: Literal["unwanted"] = "unwanted" match_method: Literal["regex", "semantic_similarity"] = "regex" active: bool = True - + + @register("writing_style") class WritingStyle(Template): description: str | None = None @@ -42,7 +39,8 @@ class WritingStyle(Template): def render(self, scene: "Scene", character_name: str): return self.formatted("instructions", scene, character_name) + class GenerationOptions(pydantic.BaseModel): spices: Spices | None = None spice_level: float = 0.0 - writing_style: WritingStyle | None = None \ No newline at end of file + writing_style: WritingStyle | None = None diff --git a/src/talemate/world_state/templates/scene.py b/src/talemate/world_state/templates/scene.py index c9569a2d..1f4a2b46 100644 --- a/src/talemate/world_state/templates/scene.py +++ b/src/talemate/world_state/templates/scene.py @@ -1,8 +1,6 @@ from typing import TYPE_CHECKING -import pydantic - -from talemate.world_state.templates.base import Template, log, register +from talemate.world_state.templates.base import Template, register if TYPE_CHECKING: from talemate.tale_mate import Scene @@ -14,10 +12,11 @@ __all__ = ["SceneType"] class SceneType(Template): """ Template for scene types. - + This template simply provides a way to store scene type definitions that can be directly applied to a scene without AI generation. """ + name: str description: str instructions: str | None = None @@ -26,23 +25,23 @@ class SceneType(Template): def to_scene_type_dict(self): """Convert the template to a scene type dictionary format""" scene_type_id = self.name.lower().replace(" ", "_") - + return { "id": scene_type_id, "name": self.name, "description": self.description, - "instructions": self.instructions + "instructions": self.instructions, } def apply_to_scene(self, scene: "Scene") -> dict: """ Apply this template to create a scene type in the scene - + Returns the created scene type dict """ scene_type = self.to_scene_type_dict() - + if scene and hasattr(scene, "scene_intent") and scene.scene_intent: scene.scene_intent.scene_types[scene_type["id"]] = scene_type - + return scene_type diff --git a/src/talemate/world_state/templates/state_reinforcement.py b/src/talemate/world_state/templates/state_reinforcement.py index 3be62ff9..01c4600b 100644 --- a/src/talemate/world_state/templates/state_reinforcement.py +++ b/src/talemate/world_state/templates/state_reinforcement.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, TypeVar, Union +from typing import TYPE_CHECKING import pydantic diff --git a/tests/test_dedupe.py b/tests/test_dedupe.py index ec05ebab..62f0b61e 100644 --- a/tests/test_dedupe.py +++ b/tests/test_dedupe.py @@ -1,375 +1,717 @@ import pytest from talemate.util.dedupe import dedupe_sentences, dedupe_string, similarity_matches + # Test cases for dedupe_sentences -@pytest.mark.parametrize("text_a, text_b, similarity_threshold, expected", [ - # Basic deduplication - ("This is a test sentence. Another sentence.", "This is a test sentence.", 95, "Another sentence."), - ("Sentence one. Sentence two.", "Sentence three. Sentence two.", 95, "Sentence one."), - # No deduplication - ("Unique sentence one. Unique sentence two.", "Different sentence one. Different sentence two.", 95, "Unique sentence one. Unique sentence two."), - # Threshold testing - ("Almost the same sentence.", "Almost the same sentence?", 99, "Almost the same sentence."), # Fixed: function keeps sentence at 99% threshold - ("Almost the same sentence.", "Almost the same sentence?", 100, "Almost the same sentence."), # Perfect match required - ("Slightly different text.", "Slightly different words.", 80, ""), # Lower threshold - # Empty inputs - ("", "Some sentence.", 95, ""), - ("Some sentence.", "", 95, "Some sentence."), - ("", "", 95, ""), - # Edge case: single sentences - ("Single sentence A.", "Single sentence A.", 95, ""), - ("Single sentence A.", "Single sentence B.", 95, "Single sentence A."), - # --- Quote handling tests --- - # Expect removal based on core match, accepting token removal issues - ('Some text. "First quote sentence. Second quote sentence needs removing." More text.', 'Second quote sentence needs removing.', 95, 'Some text. "First quote sentence." More text.'), - ('"Remove this first. Keep this second." The text continues.', 'Remove this first.', 95, '"Keep this second." The text continues.'), - ('The text starts here. "Keep this first. Remove this second."', 'Remove this second.', 95, 'The text starts here. "Keep this first."'), - ('"Sentence one. Sentence two to remove. Sentence three."', 'Sentence two to remove.', 95, '"Sentence one. Sentence three."'), - # --- Asterisk handling tests --- - ('Some text. *First asterisk sentence. Second asterisk sentence needs removing.* More text.', 'Second asterisk sentence needs removing.', 95, 'Some text. *First asterisk sentence.* More text.'), - ('*Remove this first. Keep this second.* The text continues.', 'Remove this first.', 95, '*Keep this second.* The text continues.'), - ('The text starts here. *Keep this first. Remove this second.*', 'Remove this second.', 95, 'The text starts here. *Keep this first.*'), - ('*Sentence one. Sentence two to remove. Sentence three.*', 'Sentence two to remove.', 95, '*Sentence one. Sentence three.*'), - # --- Mixed delimiter tests --- - ('Some text. *Asterisk text.* "Quote text." More text.', 'Quote text.', 90, 'Some text. *Asterisk text.* More text.'), - ('Some text. *Asterisk text.* "Quote text." More text.', 'Asterisk text.', 95, 'Some text. "Quote text." More text.'), - ('"Some text." *Asterisk text.* "Quote text." More text.', 'Asterisk text.', 95, '"Some text. Quote text." More text.'), -]) +@pytest.mark.parametrize( + "text_a, text_b, similarity_threshold, expected", + [ + # Basic deduplication + ( + "This is a test sentence. Another sentence.", + "This is a test sentence.", + 95, + "Another sentence.", + ), + ( + "Sentence one. Sentence two.", + "Sentence three. Sentence two.", + 95, + "Sentence one.", + ), + # No deduplication + ( + "Unique sentence one. Unique sentence two.", + "Different sentence one. Different sentence two.", + 95, + "Unique sentence one. Unique sentence two.", + ), + # Threshold testing + ( + "Almost the same sentence.", + "Almost the same sentence?", + 99, + "Almost the same sentence.", + ), # Fixed: function keeps sentence at 99% threshold + ( + "Almost the same sentence.", + "Almost the same sentence?", + 100, + "Almost the same sentence.", + ), # Perfect match required + ( + "Slightly different text.", + "Slightly different words.", + 80, + "", + ), # Lower threshold + # Empty inputs + ("", "Some sentence.", 95, ""), + ("Some sentence.", "", 95, "Some sentence."), + ("", "", 95, ""), + # Edge case: single sentences + ("Single sentence A.", "Single sentence A.", 95, ""), + ("Single sentence A.", "Single sentence B.", 95, "Single sentence A."), + # --- Quote handling tests --- + # Expect removal based on core match, accepting token removal issues + ( + 'Some text. "First quote sentence. Second quote sentence needs removing." More text.', + "Second quote sentence needs removing.", + 95, + 'Some text. "First quote sentence." More text.', + ), + ( + '"Remove this first. Keep this second." The text continues.', + "Remove this first.", + 95, + '"Keep this second." The text continues.', + ), + ( + 'The text starts here. "Keep this first. Remove this second."', + "Remove this second.", + 95, + 'The text starts here. "Keep this first."', + ), + ( + '"Sentence one. Sentence two to remove. Sentence three."', + "Sentence two to remove.", + 95, + '"Sentence one. Sentence three."', + ), + # --- Asterisk handling tests --- + ( + "Some text. *First asterisk sentence. Second asterisk sentence needs removing.* More text.", + "Second asterisk sentence needs removing.", + 95, + "Some text. *First asterisk sentence.* More text.", + ), + ( + "*Remove this first. Keep this second.* The text continues.", + "Remove this first.", + 95, + "*Keep this second.* The text continues.", + ), + ( + "The text starts here. *Keep this first. Remove this second.*", + "Remove this second.", + 95, + "The text starts here. *Keep this first.*", + ), + ( + "*Sentence one. Sentence two to remove. Sentence three.*", + "Sentence two to remove.", + 95, + "*Sentence one. Sentence three.*", + ), + # --- Mixed delimiter tests --- + ( + 'Some text. *Asterisk text.* "Quote text." More text.', + "Quote text.", + 90, + "Some text. *Asterisk text.* More text.", + ), + ( + 'Some text. *Asterisk text.* "Quote text." More text.', + "Asterisk text.", + 95, + 'Some text. "Quote text." More text.', + ), + ( + '"Some text." *Asterisk text.* "Quote text." More text.', + "Asterisk text.", + 95, + '"Some text. Quote text." More text.', + ), + ], +) def test_dedupe_sentences(text_a, text_b, similarity_threshold, expected): - assert dedupe_sentences(text_a, text_b, similarity_threshold=similarity_threshold) == expected + assert ( + dedupe_sentences(text_a, text_b, similarity_threshold=similarity_threshold) + == expected + ) + # Test cases for min_length parameter in dedupe_sentences -@pytest.mark.parametrize("text_a, text_b, min_length, similarity_threshold, expected", [ - # Basic min_length tests - Note: min_length applies to text_a sentences, not text_b - ("Short. This is a longer sentence.", "Short.", 10, 95, "Short. This is a longer sentence."), # "Short." sentence is skipped due to length - ("Short. This is a longer sentence.", "Short.", 4, 95, "This is a longer sentence."), # Short sentence above min_length is deduped - ("First short. Second short. Longer sentence here.", "First short.", 12, 95, "Second short. Longer sentence here."), # Only dedupe sentences above min_length - - # Edge cases - ("A B C. Longer text here.", "A B C.", 5, 95, "A B C. Longer text here."), # min_length affects dedupe check behavior, short sentence skipped in text_a - ("A B C. Longer text here.", "A B C.", 6, 95, "A B C. Longer text here."), # Just below min_length - - # Multiple sentences with varying lengths - ("Short1. Short2. Long sentence one. Long sentence two.", "Short1. Long sentence one.", 10, 95, "Short1. Short2. Long sentence two."), # Short sentences below min_length, longs are checked - ("Short1. Short2. Long sentence one. Long sentence two.", "Short1. Long sentence one.", 6, 95, "Short2. Long sentence two."), - - # Special delimiters with min_length (quotes) - ('"Short quote. Long quoted sentence." Text after.', "Short quote.", 10, 95, '"Long quoted sentence." Text after.'), # Inner content is what's deduped - ('"Short quote. Long quoted sentence." Text after.', "Short quote.", 5, 95, '"Long quoted sentence." Text after.'), # Short above min_length is deduped - - # Special delimiters with min_length (asterisks) - ('*Short text. Long sentence in asterisks.* Text after.', "Short text.", 10, 95, '*Long sentence in asterisks.* Text after.'), # Inner content is what's deduped - ('*Short text. Long sentence in asterisks.* Text after.', "Short text.", 5, 95, '*Long sentence in asterisks.* Text after.'), - - # Combined test cases - ("Apple. Orange. The orange is round. The car is fast.", "Apple. The car is fast.", 3, 95, "Orange. The orange is round."), # Both shorts and longs above min_length - ("Apple. Orange. The orange is round. The car is fast.", "Apple. The car is fast.", 7, 95, "Apple. Orange. The orange is round."), # Shorts below min_length, longs above -]) -def test_dedupe_sentences_min_length(text_a, text_b, min_length, similarity_threshold, expected): - assert dedupe_sentences(text_a, text_b, similarity_threshold=similarity_threshold, min_length=min_length) == expected +@pytest.mark.parametrize( + "text_a, text_b, min_length, similarity_threshold, expected", + [ + # Basic min_length tests - Note: min_length applies to text_a sentences, not text_b + ( + "Short. This is a longer sentence.", + "Short.", + 10, + 95, + "Short. This is a longer sentence.", + ), # "Short." sentence is skipped due to length + ( + "Short. This is a longer sentence.", + "Short.", + 4, + 95, + "This is a longer sentence.", + ), # Short sentence above min_length is deduped + ( + "First short. Second short. Longer sentence here.", + "First short.", + 12, + 95, + "Second short. Longer sentence here.", + ), # Only dedupe sentences above min_length + # Edge cases + ( + "A B C. Longer text here.", + "A B C.", + 5, + 95, + "A B C. Longer text here.", + ), # min_length affects dedupe check behavior, short sentence skipped in text_a + ( + "A B C. Longer text here.", + "A B C.", + 6, + 95, + "A B C. Longer text here.", + ), # Just below min_length + # Multiple sentences with varying lengths + ( + "Short1. Short2. Long sentence one. Long sentence two.", + "Short1. Long sentence one.", + 10, + 95, + "Short1. Short2. Long sentence two.", + ), # Short sentences below min_length, longs are checked + ( + "Short1. Short2. Long sentence one. Long sentence two.", + "Short1. Long sentence one.", + 6, + 95, + "Short2. Long sentence two.", + ), + # Special delimiters with min_length (quotes) + ( + '"Short quote. Long quoted sentence." Text after.', + "Short quote.", + 10, + 95, + '"Long quoted sentence." Text after.', + ), # Inner content is what's deduped + ( + '"Short quote. Long quoted sentence." Text after.', + "Short quote.", + 5, + 95, + '"Long quoted sentence." Text after.', + ), # Short above min_length is deduped + # Special delimiters with min_length (asterisks) + ( + "*Short text. Long sentence in asterisks.* Text after.", + "Short text.", + 10, + 95, + "*Long sentence in asterisks.* Text after.", + ), # Inner content is what's deduped + ( + "*Short text. Long sentence in asterisks.* Text after.", + "Short text.", + 5, + 95, + "*Long sentence in asterisks.* Text after.", + ), + # Combined test cases + ( + "Apple. Orange. The orange is round. The car is fast.", + "Apple. The car is fast.", + 3, + 95, + "Orange. The orange is round.", + ), # Both shorts and longs above min_length + ( + "Apple. Orange. The orange is round. The car is fast.", + "Apple. The car is fast.", + 7, + 95, + "Apple. Orange. The orange is round.", + ), # Shorts below min_length, longs above + ], +) +def test_dedupe_sentences_min_length( + text_a, text_b, min_length, similarity_threshold, expected +): + assert ( + dedupe_sentences( + text_a, + text_b, + similarity_threshold=similarity_threshold, + min_length=min_length, + ) + == expected + ) + # Test cases for newline preservation in dedupe_sentences -@pytest.mark.parametrize("text_a, text_b, similarity_threshold, expected", [ - # Basic newline preservation - ("The orange is round.\nThe car is fast.\n\nI wonder what today will bring.", "This is a long sentence.\n\nI wonder what today will bring.", 95, "The orange is round.\nThe car is fast."), - - # Basic single-line removal - ("Line 1.\nLine 2.\nLine 3.", "Line 2.", 95, "Line 1.\nLine 3."), - - # Paragraph preservation - ("First paragraph.\n\nSecond paragraph.", "First paragraph.", 95, "Second paragraph."), - ("Multi-line.\nAnother line.\nDuplicate.", "Another line.", 95, "Multi-line.\nDuplicate."), - - # Special delimiters with newlines - ('"Line 1.\nLine 2."', "Line 2.", 95, '"Line 1."'), - ("*Line A.\nLine B.\nLine C.*", "Line B.", 95, "*Line A.\nLine C.*"), - - # Complex cases with mixed newlines and delimiters - ("Text starts.\n\n*Inner text.\nDuplicate text.*\n\nText ends.", "Duplicate text.", 95, "Text starts.\n\n*Inner text.*\n\nText ends."), - - # Multiple paragraphs with sentence deduplication - ("Paragraph one.\nDuplicate sentence.\n\nParagraph two.", "Duplicate sentence.", 95, "Paragraph one.\n\nParagraph two."), - - # Consecutive newlines - ("Text before.\n\n\nSentence to keep.\n\nSentence to remove.", "Sentence to remove.", 95, "Text before.\n\n\nSentence to keep."), - - # Quoted text with multiple lines - ('First line.\n"Second line.\nThird line to remove.\nFourth line."', "Third line to remove.", 95, 'First line.\n"Second line.\nFourth line."'), - - # Edge cases with newlines at beginning/end - ("\nFirst line.\nDuplicate line.", "Duplicate line.", 95, "First line."), - ("First line.\nDuplicate line.\n", "Duplicate line.", 95, "First line."), - ("\nDuplicate line.\n", "Duplicate line.", 95, ""), - - # Multi-paragraph deduplication - ("Para 1.\n\nDuplicate para.\n\nPara 3.", "Duplicate para.", 95, "Para 1.\n\nPara 3."), - - # Combining with min_length (test implicitly, not through parameter) - ("Short.\nLonger line to remove.\nAnother short.", "Longer line to remove.", 95, "Short.\nAnother short."), - - # Complex document-like structure (similarity needs to be lower because sentences will contain the header text) - ("# Header\n\nIntro paragraph.\n\n## Section\n\nDuplicate content.\n\n### Subsection", "Duplicate content.", 75, "# Header\n\nIntro paragraph.\n\n### Subsection"), -]) +@pytest.mark.parametrize( + "text_a, text_b, similarity_threshold, expected", + [ + # Basic newline preservation + ( + "The orange is round.\nThe car is fast.\n\nI wonder what today will bring.", + "This is a long sentence.\n\nI wonder what today will bring.", + 95, + "The orange is round.\nThe car is fast.", + ), + # Basic single-line removal + ("Line 1.\nLine 2.\nLine 3.", "Line 2.", 95, "Line 1.\nLine 3."), + # Paragraph preservation + ( + "First paragraph.\n\nSecond paragraph.", + "First paragraph.", + 95, + "Second paragraph.", + ), + ( + "Multi-line.\nAnother line.\nDuplicate.", + "Another line.", + 95, + "Multi-line.\nDuplicate.", + ), + # Special delimiters with newlines + ('"Line 1.\nLine 2."', "Line 2.", 95, '"Line 1."'), + ("*Line A.\nLine B.\nLine C.*", "Line B.", 95, "*Line A.\nLine C.*"), + # Complex cases with mixed newlines and delimiters + ( + "Text starts.\n\n*Inner text.\nDuplicate text.*\n\nText ends.", + "Duplicate text.", + 95, + "Text starts.\n\n*Inner text.*\n\nText ends.", + ), + # Multiple paragraphs with sentence deduplication + ( + "Paragraph one.\nDuplicate sentence.\n\nParagraph two.", + "Duplicate sentence.", + 95, + "Paragraph one.\n\nParagraph two.", + ), + # Consecutive newlines + ( + "Text before.\n\n\nSentence to keep.\n\nSentence to remove.", + "Sentence to remove.", + 95, + "Text before.\n\n\nSentence to keep.", + ), + # Quoted text with multiple lines + ( + 'First line.\n"Second line.\nThird line to remove.\nFourth line."', + "Third line to remove.", + 95, + 'First line.\n"Second line.\nFourth line."', + ), + # Edge cases with newlines at beginning/end + ("\nFirst line.\nDuplicate line.", "Duplicate line.", 95, "First line."), + ("First line.\nDuplicate line.\n", "Duplicate line.", 95, "First line."), + ("\nDuplicate line.\n", "Duplicate line.", 95, ""), + # Multi-paragraph deduplication + ( + "Para 1.\n\nDuplicate para.\n\nPara 3.", + "Duplicate para.", + 95, + "Para 1.\n\nPara 3.", + ), + # Combining with min_length (test implicitly, not through parameter) + ( + "Short.\nLonger line to remove.\nAnother short.", + "Longer line to remove.", + 95, + "Short.\nAnother short.", + ), + # Complex document-like structure (similarity needs to be lower because sentences will contain the header text) + ( + "# Header\n\nIntro paragraph.\n\n## Section\n\nDuplicate content.\n\n### Subsection", + "Duplicate content.", + 75, + "# Header\n\nIntro paragraph.\n\n### Subsection", + ), + ], +) def test_dedupe_sentences_newlines(text_a, text_b, similarity_threshold, expected): - assert dedupe_sentences(text_a, text_b, similarity_threshold=similarity_threshold) == expected + assert ( + dedupe_sentences(text_a, text_b, similarity_threshold=similarity_threshold) + == expected + ) + # Test cases for dedupe_string -@pytest.mark.parametrize("s, min_length, similarity_threshold, expected", [ - # Basic deduplication - Note: dedupe_string processes lines from bottom up - ("Line 1\nLine 2\nLine 1", 5, 95, "Line 2\nLine 1"), # Fixed: preserves last occurrence - ("Duplicate line.\nAnother line.\nDuplicate line.", 10, 95, "Another line.\nDuplicate line."), # Fixed: reverse order - # No deduplication (different lines) - ("Line one.\nLine two.\nLine three.", 5, 95, "Line one.\nLine two.\nLine three."), - # min_length testing - ("Short line\nAnother short line\nShort line", 15, 95, "Short line\nAnother short line\nShort line"), # Below min_length - ("This is a long line.\nThis is another long line.\nThis is a long line.", 10, 95, "This is another long line.\nThis is a long line."), # Fixed: reversed order - # similarity_threshold testing - ("Very similar line number one.\nVery similar line number two.", 10, 90, "Very similar line number two."), # Fixed: keeps second line at 90% threshold - ("Very similar line number one.\nVery similar line number two.", 10, 98, "Very similar line number one.\nVery similar line number two."), - # Code block handling - ("Regular line 1\n```\nCode line 1\nCode line 1\n```\nRegular line 1", 5, 95, "```\nCode line 1\nCode line 1\n```\nRegular line 1"), # Fixed: code block processing - # Fix for failing test - updated to match actual function output - ("Line A\n```\nInside code\n```\nLine B\nLine A\n```\nInside code\n```", 5, 95, "```\nInside code\n```\nLine B\nLine A\n```\nInside code\n```"), - # Mixed short and long lines - ("Short\nThis is a longer line.\nAnother long line that is similar.\nShort\nThis is a longer line.", 5, 90, "Short\nAnother long line that is similar.\nShort\nThis is a longer line."), # Fixed: order preservation - # Empty input - ("", 5, 95, ""), - # Only short lines - ("a\nb\nc\na", 5, 95, "a\nb\nc\na"), # Fixed: below min_length so no deduplication - # Lines with only whitespace - ("Line 1\n \nLine 1", 5, 95, " \nLine 1"), # Fixed: whitespace line not detected as duplicate - ("Line X\n \nLine X", 0, 95, " \nLine X"), # Fixed: min_length 0 behavior - # Test case where duplicate is kept because the first occurrence is inside a code block - ("```\nThis is a duplicate line\n```\nSome other line\nThis is a duplicate line", 10, 95, "```\nThis is a duplicate line\n```\nSome other line\nThis is a duplicate line"), - # Fix for failing test - actual behavior preserves original content with code blocks - ("This is a duplicate line\nSome other line\n```\nThis is a duplicate line\n```", 10, 95, "This is a duplicate line\nSome other line\n```\nThis is a duplicate line\n```"), - # Test case where duplicate check might span across code blocks - ("Line Alpha\n```\nCode Block Content\n```\nLine Alpha", 5, 95, "```\nCode Block Content\n```\nLine Alpha") # Fixed: preserves bottom occurrence -]) +@pytest.mark.parametrize( + "s, min_length, similarity_threshold, expected", + [ + # Basic deduplication - Note: dedupe_string processes lines from bottom up + ( + "Line 1\nLine 2\nLine 1", + 5, + 95, + "Line 2\nLine 1", + ), # Fixed: preserves last occurrence + ( + "Duplicate line.\nAnother line.\nDuplicate line.", + 10, + 95, + "Another line.\nDuplicate line.", + ), # Fixed: reverse order + # No deduplication (different lines) + ( + "Line one.\nLine two.\nLine three.", + 5, + 95, + "Line one.\nLine two.\nLine three.", + ), + # min_length testing + ( + "Short line\nAnother short line\nShort line", + 15, + 95, + "Short line\nAnother short line\nShort line", + ), # Below min_length + ( + "This is a long line.\nThis is another long line.\nThis is a long line.", + 10, + 95, + "This is another long line.\nThis is a long line.", + ), # Fixed: reversed order + # similarity_threshold testing + ( + "Very similar line number one.\nVery similar line number two.", + 10, + 90, + "Very similar line number two.", + ), # Fixed: keeps second line at 90% threshold + ( + "Very similar line number one.\nVery similar line number two.", + 10, + 98, + "Very similar line number one.\nVery similar line number two.", + ), + # Code block handling + ( + "Regular line 1\n```\nCode line 1\nCode line 1\n```\nRegular line 1", + 5, + 95, + "```\nCode line 1\nCode line 1\n```\nRegular line 1", + ), # Fixed: code block processing + # Fix for failing test - updated to match actual function output + ( + "Line A\n```\nInside code\n```\nLine B\nLine A\n```\nInside code\n```", + 5, + 95, + "```\nInside code\n```\nLine B\nLine A\n```\nInside code\n```", + ), + # Mixed short and long lines + ( + "Short\nThis is a longer line.\nAnother long line that is similar.\nShort\nThis is a longer line.", + 5, + 90, + "Short\nAnother long line that is similar.\nShort\nThis is a longer line.", + ), # Fixed: order preservation + # Empty input + ("", 5, 95, ""), + # Only short lines + ( + "a\nb\nc\na", + 5, + 95, + "a\nb\nc\na", + ), # Fixed: below min_length so no deduplication + # Lines with only whitespace + ( + "Line 1\n \nLine 1", + 5, + 95, + " \nLine 1", + ), # Fixed: whitespace line not detected as duplicate + ("Line X\n \nLine X", 0, 95, " \nLine X"), # Fixed: min_length 0 behavior + # Test case where duplicate is kept because the first occurrence is inside a code block + ( + "```\nThis is a duplicate line\n```\nSome other line\nThis is a duplicate line", + 10, + 95, + "```\nThis is a duplicate line\n```\nSome other line\nThis is a duplicate line", + ), + # Fix for failing test - actual behavior preserves original content with code blocks + ( + "This is a duplicate line\nSome other line\n```\nThis is a duplicate line\n```", + 10, + 95, + "This is a duplicate line\nSome other line\n```\nThis is a duplicate line\n```", + ), + # Test case where duplicate check might span across code blocks + ( + "Line Alpha\n```\nCode Block Content\n```\nLine Alpha", + 5, + 95, + "```\nCode Block Content\n```\nLine Alpha", + ), # Fixed: preserves bottom occurrence + ], +) def test_dedupe_string(s, min_length, similarity_threshold, expected): - assert dedupe_string(s, min_length=min_length, similarity_threshold=similarity_threshold) == expected + assert ( + dedupe_string( + s, min_length=min_length, similarity_threshold=similarity_threshold + ) + == expected + ) + # Test cases for similarity_matches function -@pytest.mark.parametrize("text_a, text_b, similarity_threshold, min_length, split_on_comma, expected_count, check_properties", [ - # Basic matching - ( - "This is a test sentence. Another test sentence.", - "This is a test sentence.", - 95, None, False, - 1, - lambda matches: matches[0].original == "This is a test sentence." and matches[0].similarity >= 95 - ), - - # Multiple matches - ( - "First sentence. Second sentence. Third sentence.", - "First sentence. Third sentence.", - 95, None, False, - 2, - lambda matches: matches[0].original == "First sentence." and matches[1].original == "Third sentence." - ), - - # Similarity threshold testing - ( - "Almost identical sentence.", - "Almost identical sentences.", - 90, None, False, - 1, - lambda matches: matches[0].similarity >= 90 - ), - ( - "Almost identical sentence.", - "Almost identical sentences.", - 99, None, False, - 0, - lambda matches: True # No matches expected - ), - - # min_length filtering - ( - "Short. This is a longer sentence.", - "Short. Different longer sentence.", - 95, 10, False, - 0, - lambda matches: True # Only "Short" would match but it's below min_length - ), - ( - "Short. This is a longer sentence.", - "Short. Different longer sentence.", - 95, 5, False, - 1, - lambda matches: matches[0].original == "Short." - ), - - # split_on_comma testing - ( - "Before comma, after comma.", - "Something else, after comma.", - 95, None, True, - 1, - lambda matches: "after comma" in matches[0].original - ), - ( - "Before comma, after comma.", - "Something else, after comma.", - 95, None, False, - 0, - lambda matches: True # Whole sentences don't match above threshold - ), - - # Special markers handling - note that the tokenizer splits sentences differently with special markers - ( - "*This has asterisks.* Regular text.", - "This has asterisks.", - 95, None, False, - 1, - lambda matches: matches[0].original == "*This has asterisks." - ), - ( - '"This has quotes." Regular text.', - "This has quotes.", - 95, None, False, - 1, - lambda matches: matches[0].original == '"This has quotes."' - ), - - # Neighbor detection - ( - "First neighbor. Middle sentence. Last neighbor.", - "Middle sentence.", - 95, None, False, - 1, - lambda matches: ( - matches[0].original == "Middle sentence." and - matches[0].left_neighbor == "First neighbor." and - matches[0].right_neighbor == "Last neighbor." - ) - ), - - # Edge cases - ( - "", - "Some text.", - 95, None, False, - 0, - lambda matches: True # Empty text_a should have no matches - ), - ( - "Some text.", - "", - 95, None, False, - 0, - lambda matches: True # Empty text_b should have no matches - ), - ( - "Single sentence.", - "Single sentence.", - 95, None, False, - 1, - lambda matches: matches[0].original == "Single sentence." and matches[0].similarity == 100 - ), -]) -def test_similarity_matches(text_a, text_b, similarity_threshold, min_length, split_on_comma, expected_count, check_properties): +@pytest.mark.parametrize( + "text_a, text_b, similarity_threshold, min_length, split_on_comma, expected_count, check_properties", + [ + # Basic matching + ( + "This is a test sentence. Another test sentence.", + "This is a test sentence.", + 95, + None, + False, + 1, + lambda matches: matches[0].original == "This is a test sentence." + and matches[0].similarity >= 95, + ), + # Multiple matches + ( + "First sentence. Second sentence. Third sentence.", + "First sentence. Third sentence.", + 95, + None, + False, + 2, + lambda matches: matches[0].original == "First sentence." + and matches[1].original == "Third sentence.", + ), + # Similarity threshold testing + ( + "Almost identical sentence.", + "Almost identical sentences.", + 90, + None, + False, + 1, + lambda matches: matches[0].similarity >= 90, + ), + ( + "Almost identical sentence.", + "Almost identical sentences.", + 99, + None, + False, + 0, + lambda matches: True, # No matches expected + ), + # min_length filtering + ( + "Short. This is a longer sentence.", + "Short. Different longer sentence.", + 95, + 10, + False, + 0, + lambda matches: True, # Only "Short" would match but it's below min_length + ), + ( + "Short. This is a longer sentence.", + "Short. Different longer sentence.", + 95, + 5, + False, + 1, + lambda matches: matches[0].original == "Short.", + ), + # split_on_comma testing + ( + "Before comma, after comma.", + "Something else, after comma.", + 95, + None, + True, + 1, + lambda matches: "after comma" in matches[0].original, + ), + ( + "Before comma, after comma.", + "Something else, after comma.", + 95, + None, + False, + 0, + lambda matches: True, # Whole sentences don't match above threshold + ), + # Special markers handling - note that the tokenizer splits sentences differently with special markers + ( + "*This has asterisks.* Regular text.", + "This has asterisks.", + 95, + None, + False, + 1, + lambda matches: matches[0].original == "*This has asterisks.", + ), + ( + '"This has quotes." Regular text.', + "This has quotes.", + 95, + None, + False, + 1, + lambda matches: matches[0].original == '"This has quotes."', + ), + # Neighbor detection + ( + "First neighbor. Middle sentence. Last neighbor.", + "Middle sentence.", + 95, + None, + False, + 1, + lambda matches: ( + matches[0].original == "Middle sentence." + and matches[0].left_neighbor == "First neighbor." + and matches[0].right_neighbor == "Last neighbor." + ), + ), + # Edge cases + ( + "", + "Some text.", + 95, + None, + False, + 0, + lambda matches: True, # Empty text_a should have no matches + ), + ( + "Some text.", + "", + 95, + None, + False, + 0, + lambda matches: True, # Empty text_b should have no matches + ), + ( + "Single sentence.", + "Single sentence.", + 95, + None, + False, + 1, + lambda matches: matches[0].original == "Single sentence." + and matches[0].similarity == 100, + ), + ], +) +def test_similarity_matches( + text_a, + text_b, + similarity_threshold, + min_length, + split_on_comma, + expected_count, + check_properties, +): matches = similarity_matches( - text_a, - text_b, + text_a, + text_b, similarity_threshold=similarity_threshold, min_length=min_length, - split_on_comma=split_on_comma + split_on_comma=split_on_comma, ) - + assert len(matches) == expected_count if expected_count > 0: assert check_properties(matches) + # Additional focused tests for specific behaviors def test_similarity_matches_with_min_length(): text_a = "Very short. This is a longer sentence that should be detected." text_b = "Very short. This is a longer sentence that should be matched." - + # With min_length that filters out the short sentence matches = similarity_matches(text_a, text_b, similarity_threshold=90, min_length=15) assert len(matches) == 1 assert "longer sentence" in matches[0].original - + # Without min_length, both sentences should match matches = similarity_matches(text_a, text_b, similarity_threshold=90) assert len(matches) == 2 assert "Very short" in matches[0].original assert "longer sentence" in matches[1].original + def test_similarity_matches_comma_splitting(): text_a = "First part, similar middle part, last part." text_b = "Different start, similar middle part, different end." - + # Without split_on_comma, no matches (whole sentences don't match enough) - matches = similarity_matches(text_a, text_b, similarity_threshold=95, split_on_comma=False) + matches = similarity_matches( + text_a, text_b, similarity_threshold=95, split_on_comma=False + ) assert len(matches) == 0 - + # With split_on_comma, the middle part should match - matches = similarity_matches(text_a, text_b, similarity_threshold=95, split_on_comma=True) + matches = similarity_matches( + text_a, text_b, similarity_threshold=95, split_on_comma=True + ) assert len(matches) == 1 assert "similar middle part" in matches[0].original + def test_similarity_matches_special_marker_handling(): # Test with both asterisks and quotes in the same text - text_a = "*Asterisk part.* Regular part. \"Quoted part.\"" + text_a = '*Asterisk part.* Regular part. "Quoted part."' text_b = "Asterisk part. Different text. Quoted part." - + matches = similarity_matches(text_a, text_b, similarity_threshold=90) assert len(matches) == 2 - + # Check that the special markers are preserved in the original but only at the beginning # due to how the tokenizer works asterisk_match = next((m for m in matches if "*" in m.original), None) - quote_match = next((m for m in matches if "\"" in m.original), None) - + quote_match = next((m for m in matches if '"' in m.original), None) + assert asterisk_match is not None assert quote_match is not None assert asterisk_match.original == "*Asterisk part." - assert quote_match.original == "\"Quoted part.\"" + assert quote_match.original == '"Quoted part."' + def test_similarity_matches_min_length_with_comma_splitting(): """Test that min_length is properly honored during split_on_comma operations.""" # Text with multiple comma-separated parts of varying lengths text_a = "Short, Medium length part, Very long and detailed part of the sentence." text_b = "Different, Medium length part, Another long and unrelated segment." - + # Should match "Medium length part" with split_on_comma=True and no min_length matches = similarity_matches( - text_a, text_b, - similarity_threshold=95, - split_on_comma=True + text_a, text_b, similarity_threshold=95, split_on_comma=True ) assert len(matches) == 1 assert "Medium length part" in matches[0].original - + # Should NOT match "Short" due to min_length=10, but still match "Medium length part" matches = similarity_matches( - text_a, text_b, - similarity_threshold=95, - min_length=10, - split_on_comma=True + text_a, text_b, similarity_threshold=95, min_length=10, split_on_comma=True ) assert len(matches) == 1 assert "Medium length part" in matches[0].original assert "Short" not in matches[0].original - + # With higher min_length, should still match the longer part matches = similarity_matches( - text_a, text_b, - similarity_threshold=95, - min_length=15, - split_on_comma=True + text_a, text_b, similarity_threshold=95, min_length=15, split_on_comma=True ) assert len(matches) == 1 assert "Medium length part" in matches[0].original - + # With very high min_length, should match nothing matches = similarity_matches( - text_a, text_b, - similarity_threshold=95, - min_length=30, - split_on_comma=True + text_a, text_b, similarity_threshold=95, min_length=30, split_on_comma=True ) assert len(matches) == 0 diff --git a/tests/test_dialogue_cleanup.py b/tests/test_dialogue_cleanup.py index 264f1365..b96ef63e 100644 --- a/tests/test_dialogue_cleanup.py +++ b/tests/test_dialogue_cleanup.py @@ -22,86 +22,177 @@ The second line. The third line.\" """ -@pytest.mark.parametrize("input, expected", [ - ('Hello how are you?', 'Hello how are you?'), - ('"Hello how are you?"', '"Hello how are you?"'), - ('"Hello how are you?" he asks "I am fine"', '"Hello how are you?" *he asks* "I am fine"'), - ('Hello how are you? *he asks* I am fine', '"Hello how are you?" *he asks* "I am fine"'), - - ('Hello how are you?" *he asks* I am fine', '"Hello how are you?" *he asks* "I am fine"'), - ('Hello how are you?" *he asks I am fine', '"Hello how are you?" *he asks I am fine*'), - ('Hello how are you?" *he asks* "I am fine" *', '"Hello how are you?" *he asks* "I am fine"'), - - ('"Hello how are you *he asks* I am fine"', '"Hello how are you" *he asks* "I am fine"'), - ('This is a string without any markers', 'This is a string without any markers'), - ('This is a string with an ending quote"', '"This is a string with an ending quote"'), - ('This is a string with an ending asterisk*', '*This is a string with an ending asterisk*'), - ('"Mixed markers*', '*Mixed markers*'), - ('*narrative.* dialogue" *more narrative.*', '*narrative.* "dialogue" *more narrative.*'), - ('"*messed up dialogue formatting.*" *some narration.*', '"messed up dialogue formatting." *some narration.*'), - ('*"messed up narration formatting."* "some dialogue."', '"messed up narration formatting." "some dialogue."'), - ('Some dialogue and two line-breaks right after, followed by narration.\n\n*Narration*', '"Some dialogue and two line-breaks right after, followed by narration."\n\n*Narration*'), - ('*Some narration with a "quoted" string in it.* Then some unquoted dialogue.\n\n*More narration.*', '*Some narration with a* "quoted" *string in it.* "Then some unquoted dialogue."\n\n*More narration.*'), - ('*Some narration* Some dialogue but not in quotes. *', '*Some narration* "Some dialogue but not in quotes."'), - ('*First line\nSecond line\nThird line*', '*First line\nSecond line\nThird line*'), - (MULTILINE_TEST_A_INPUT, MULTILINE_TEST_A_EXPECTED), -]) + +@pytest.mark.parametrize( + "input, expected", + [ + ("Hello how are you?", "Hello how are you?"), + ('"Hello how are you?"', '"Hello how are you?"'), + ( + '"Hello how are you?" he asks "I am fine"', + '"Hello how are you?" *he asks* "I am fine"', + ), + ( + "Hello how are you? *he asks* I am fine", + '"Hello how are you?" *he asks* "I am fine"', + ), + ( + 'Hello how are you?" *he asks* I am fine', + '"Hello how are you?" *he asks* "I am fine"', + ), + ( + 'Hello how are you?" *he asks I am fine', + '"Hello how are you?" *he asks I am fine*', + ), + ( + 'Hello how are you?" *he asks* "I am fine" *', + '"Hello how are you?" *he asks* "I am fine"', + ), + ( + '"Hello how are you *he asks* I am fine"', + '"Hello how are you" *he asks* "I am fine"', + ), + ( + "This is a string without any markers", + "This is a string without any markers", + ), + ( + 'This is a string with an ending quote"', + '"This is a string with an ending quote"', + ), + ( + "This is a string with an ending asterisk*", + "*This is a string with an ending asterisk*", + ), + ('"Mixed markers*', "*Mixed markers*"), + ( + '*narrative.* dialogue" *more narrative.*', + '*narrative.* "dialogue" *more narrative.*', + ), + ( + '"*messed up dialogue formatting.*" *some narration.*', + '"messed up dialogue formatting." *some narration.*', + ), + ( + '*"messed up narration formatting."* "some dialogue."', + '"messed up narration formatting." "some dialogue."', + ), + ( + "Some dialogue and two line-breaks right after, followed by narration.\n\n*Narration*", + '"Some dialogue and two line-breaks right after, followed by narration."\n\n*Narration*', + ), + ( + '*Some narration with a "quoted" string in it.* Then some unquoted dialogue.\n\n*More narration.*', + '*Some narration with a* "quoted" *string in it.* "Then some unquoted dialogue."\n\n*More narration.*', + ), + ( + "*Some narration* Some dialogue but not in quotes. *", + '*Some narration* "Some dialogue but not in quotes."', + ), + ( + "*First line\nSecond line\nThird line*", + "*First line\nSecond line\nThird line*", + ), + (MULTILINE_TEST_A_INPUT, MULTILINE_TEST_A_EXPECTED), + ], +) def test_dialogue_cleanup(input, expected): assert ensure_dialog_format(input) == expected - - -@pytest.mark.parametrize("input, expected, main_name", [ - ("bob: says a sentence", "bob: says a sentence", "bob"), - ("bob: says a sentence\nbob: says another sentence", "bob: says a sentence\nsays another sentence", "bob"), - ("bob: says a sentence with a colon: to explain something", "bob: says a sentence with a colon: to explain something", "bob"), - ("bob: i have a riddle for you, alice: the riddle", "bob: i have a riddle for you, alice: the riddle", "bob"), - ("bob: says something\nalice: says something else", "bob: says something", "bob"), - ("bob: says a sentence. then a", "bob: says a sentence.", "bob"), - ("bob: first paragraph\n\nsecond paragraph", "bob: first paragraph\n\nsecond paragraph", "bob"), - # movie script new speaker cutoff - ("bob: says a sentence\n\nALICE\nsays something else", "bob: says a sentence", "bob"), -]) + + +@pytest.mark.parametrize( + "input, expected, main_name", + [ + ("bob: says a sentence", "bob: says a sentence", "bob"), + ( + "bob: says a sentence\nbob: says another sentence", + "bob: says a sentence\nsays another sentence", + "bob", + ), + ( + "bob: says a sentence with a colon: to explain something", + "bob: says a sentence with a colon: to explain something", + "bob", + ), + ( + "bob: i have a riddle for you, alice: the riddle", + "bob: i have a riddle for you, alice: the riddle", + "bob", + ), + ( + "bob: says something\nalice: says something else", + "bob: says something", + "bob", + ), + ("bob: says a sentence. then a", "bob: says a sentence.", "bob"), + ( + "bob: first paragraph\n\nsecond paragraph", + "bob: first paragraph\n\nsecond paragraph", + "bob", + ), + # movie script new speaker cutoff + ( + "bob: says a sentence\n\nALICE\nsays something else", + "bob: says a sentence", + "bob", + ), + ], +) def test_clean_dialogue(input, expected, main_name): - others = ["alice", "charlie"] assert clean_dialogue(input, main_name) == expected - - -@pytest.mark.parametrize("input, expected", [ - ('Hello how are you? "', 'Hello how are you?'), - ('Hello how are you? *', 'Hello how are you?'), - ('Hello how are you? {', 'Hello how are you?'), - ('Hello how are you? [', 'Hello how are you?'), - ('Hello how are you? (', 'Hello how are you?'), - ('"Hello how are you?"', '"Hello how are you?"'), - ('"Hello how are you?" "', '"Hello how are you?"'), - ('"Hello how are you?" *', '"Hello how are you?"'), - ('"Hello how are you?" *"', '"Hello how are you?"'), - ('*He says* "Hello how are you?"', '*He says* "Hello how are you?"'), - ('*He says* "Hello how are you?" *', '*He says* "Hello how are you?"'), - ('*He says* "Hello how are you?" *"', '*He says* "Hello how are you?"'), - ('(Some thoughts)', '(Some thoughts)'), - ('(Some thoughts) ', '(Some thoughts)'), - ('(Some thoughts) (', '(Some thoughts)'), - ('(Some thoughts) [', '(Some thoughts)'), -]) + + +@pytest.mark.parametrize( + "input, expected", + [ + ('Hello how are you? "', "Hello how are you?"), + ("Hello how are you? *", "Hello how are you?"), + ("Hello how are you? {", "Hello how are you?"), + ("Hello how are you? [", "Hello how are you?"), + ("Hello how are you? (", "Hello how are you?"), + ('"Hello how are you?"', '"Hello how are you?"'), + ('"Hello how are you?" "', '"Hello how are you?"'), + ('"Hello how are you?" *', '"Hello how are you?"'), + ('"Hello how are you?" *"', '"Hello how are you?"'), + ('*He says* "Hello how are you?"', '*He says* "Hello how are you?"'), + ('*He says* "Hello how are you?" *', '*He says* "Hello how are you?"'), + ('*He says* "Hello how are you?" *"', '*He says* "Hello how are you?"'), + ("(Some thoughts)", "(Some thoughts)"), + ("(Some thoughts) ", "(Some thoughts)"), + ("(Some thoughts) (", "(Some thoughts)"), + ("(Some thoughts) [", "(Some thoughts)"), + ], +) def test_remove_trailing_markers(input, expected): assert remove_trailing_markers(input) == expected -@pytest.mark.parametrize("input, anchor_length, expected_non_anchor, expected_anchor", [ - ("", 10, "", ""), - ("Hello", 10, "", "Hello"), - ("This is a short example", 10, "This is", "a short example"), - ("One two three four", 4, "One two", "three four"), - ("This is a longer example with more than ten words to test the anchor functionality", 10, - "This is a longer example", "with more than ten words to test the anchor functionality"), - ("One two three four five six seven eight nine ten", 10, - "One two three four five", "six seven eight nine ten"), - ("Two words", 10, "Two", "words"), - ("One Two Three", 3, "One", "Two Three"), -]) +@pytest.mark.parametrize( + "input, anchor_length, expected_non_anchor, expected_anchor", + [ + ("", 10, "", ""), + ("Hello", 10, "", "Hello"), + ("This is a short example", 10, "This is", "a short example"), + ("One two three four", 4, "One two", "three four"), + ( + "This is a longer example with more than ten words to test the anchor functionality", + 10, + "This is a longer example", + "with more than ten words to test the anchor functionality", + ), + ( + "One two three four five six seven eight nine ten", + 10, + "One two three four five", + "six seven eight nine ten", + ), + ("Two words", 10, "Two", "words"), + ("One Two Three", 3, "One", "Two Three"), + ], +) def test_split_anchor_text(input, anchor_length, expected_non_anchor, expected_anchor): from talemate.util.dialogue import split_anchor_text + non_anchor, anchor = split_anchor_text(input, anchor_length) assert non_anchor == expected_non_anchor - assert anchor == expected_anchor \ No newline at end of file + assert anchor == expected_anchor diff --git a/tests/test_graphs.py b/tests/test_graphs.py index b80762ed..4641d753 100644 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -2,19 +2,17 @@ import os import json import pytest import contextvars -from unittest.mock import MagicMock -import talemate.game.engine.nodes.load_definitions -import talemate.agents.director -from talemate.context import active_scene, ActiveScene +import talemate.game.engine.nodes.load_definitions # noqa: F401 +import talemate.agents.director # noqa: F401 +from talemate.context import ActiveScene from talemate.tale_mate import Scene, Helper import talemate.instance as instance from talemate.game.engine.nodes.core import ( - Node, Graph, GraphState, GraphContext, - Socket, UNRESOLVED + Graph, + GraphState, ) from talemate.game.engine.nodes.layout import load_graph_from_file from talemate.game.engine.nodes.registry import import_talemate_node_definitions -from talemate.agents.director import DirectorAgent from talemate.client import ClientBase from collections import deque @@ -23,11 +21,13 @@ TEST_GRAPH_DIR = os.path.join(BASE_DIR, "data", "graphs") RESULTS_DIR = os.path.join(BASE_DIR, "data", "graphs", "results") UPDATE_RESULTS = False + # This runs once for the entire test session @pytest.fixture(scope="session", autouse=True) def load_node_definitions(): import_talemate_node_definitions() + def load_test_graph(name) -> Graph: path = os.path.join(TEST_GRAPH_DIR, f"{name}.json") graph, _ = load_graph_from_file(path) @@ -36,6 +36,7 @@ def load_test_graph(name) -> Graph: client_reponses = contextvars.ContextVar("client_reponses", default=deque()) + class MockClientContext: async def __aenter__(self): try: @@ -44,13 +45,14 @@ class MockClientContext: _client_reponses = deque() self.token = client_reponses.set(_client_reponses) self.client_reponses = _client_reponses - + return self.client_reponses - + async def __aexit__(self, exc_type, exc_value, traceback): if hasattr(self, "token"): client_reponses.reset(self.token) - + + class MockClient(ClientBase): def __init__(self, name: str): self.name = name @@ -58,28 +60,27 @@ class MockClient(ClientBase): self.model_name = "test-model" self.current_status = "idle" self.prompt_history = [] - - async def send_prompt(self, prompt, kind="conversation", finalize=lambda x: x, retries=2): + + async def send_prompt( + self, prompt, kind="conversation", finalize=lambda x: x, retries=2 + ): """Override send_prompt to return a pre-defined response instead of calling LLM. - + If no responses are configured, returns an empty string. Records the prompt in prompt_history for later inspection. """ - + response_stack = client_reponses.get() - - self.prompt_history.append({ - "prompt": prompt, - "kind": kind - }) - + + self.prompt_history.append({"prompt": prompt, "kind": kind}) + if not response_stack: return "" - + return response_stack.popleft() + class MockScene(Scene): - @property def auto_progress(self): """ @@ -87,12 +88,14 @@ class MockScene(Scene): """ return True + @pytest.fixture def mock_scene(): scene = MockScene() bootstrap_scene(scene) return scene + def bootstrap_scene(mock_scene): client = MockClient("test_client") director = instance.get_agent("director", client=client) @@ -105,9 +108,9 @@ def bootstrap_scene(mock_scene): mock_scene.add_helper(Helper(summarizer)) mock_scene.add_helper(Helper(editor)) mock_scene.add_helper(Helper(world_state)) - + mock_scene.mock_client = client - + return { "director": director, "conversation": conversation, @@ -116,26 +119,30 @@ def bootstrap_scene(mock_scene): "world_state": world_state, } -def make_assert_fn(name:str, write_results:bool=False): + +def make_assert_fn(name: str, write_results: bool = False): async def assert_fn(state: GraphState): - if write_results or not os.path.exists(os.path.join(RESULTS_DIR, f"{name}.json")): + if write_results or not os.path.exists( + os.path.join(RESULTS_DIR, f"{name}.json") + ): with open(os.path.join(RESULTS_DIR, f"{name}.json"), "w") as f: json.dump(state.shared, f, indent=4) else: with open(os.path.join(RESULTS_DIR, f"{name}.json"), "r") as f: expected = json.load(f) - + assert state.shared == expected - + return assert_fn -def make_graph_test(name:str, write_results:bool=False): + +def make_graph_test(name: str, write_results: bool = False): async def test_graph(scene): assert_fn = make_assert_fn(name, write_results) - + def error_handler(state, error: Exception): raise error - + with ActiveScene(scene): graph = load_test_graph(name) assert graph is not None @@ -150,32 +157,37 @@ def make_graph_test(name:str, write_results:bool=False): async def test_graph_core(mock_scene): fn = make_graph_test("test-harness-core", False) await fn(mock_scene) - + + @pytest.mark.asyncio async def test_graph_data(mock_scene): fn = make_graph_test("test-harness-data", False) await fn(mock_scene) + @pytest.mark.asyncio async def test_graph_scene(mock_scene): fn = make_graph_test("test-harness-scene", False) await fn(mock_scene) + @pytest.mark.asyncio async def test_graph_functions(mock_scene): fn = make_graph_test("test-harness-functions", False) await fn(mock_scene) + @pytest.mark.asyncio async def test_graph_agents(mock_scene): fn = make_graph_test("test-harness-agents", False) await fn(mock_scene) + @pytest.mark.asyncio async def test_graph_prompt(mock_scene): fn = make_graph_test("test-harness-prompt", False) - + async with MockClientContext() as client_reponses: client_reponses.append("The sum of 1 and 5 is 6.") client_reponses.append('```json\n{\n "result": 6\n}\n```') - await fn(mock_scene) \ No newline at end of file + await fn(mock_scene) diff --git a/tests/test_history.py b/tests/test_history.py index 3d4cd4c0..c96eea7a 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -2,7 +2,6 @@ import pytest import types from talemate.history import shift_scene_timeline -from talemate.util import iso8601_add # --------------------------------------------------------------------------- # Fixtures @@ -83,9 +82,7 @@ def dummy_scene(): ["P1002Y"], ), ], - ids=[ - "hour_plus", "hour_minus", "month_plus", "year_minus", "millennia_plus" - ], + ids=["hour_plus", "hour_minus", "month_plus", "year_minus", "millennia_plus"], ) def test_shift_scene_timeline_basic( dummy_scene, @@ -121,6 +118,7 @@ def test_shift_scene_timeline_noop(dummy_scene): ) import copy + pre_state = ( scene.ts, copy.deepcopy(scene.archived_history), @@ -131,4 +129,4 @@ def test_shift_scene_timeline_noop(dummy_scene): assert scene.ts == pre_state[0] assert scene.archived_history == pre_state[1] - assert scene.layered_history == pre_state[2] \ No newline at end of file + assert scene.layered_history == pre_state[2] diff --git a/tests/test_isodate.py b/tests/test_isodate.py index e60980f8..54c4db0c 100644 --- a/tests/test_isodate.py +++ b/tests/test_isodate.py @@ -1,24 +1,20 @@ import pytest -import isodate from talemate.util import ( iso8601_add, - iso8601_correct_duration, iso8601_diff, iso8601_diff_to_human, iso8601_duration_to_human, parse_duration_to_isodate_duration, - timedelta_to_duration, duration_to_timedelta, amount_unit_to_iso8601_duration, ) def test_isodate_utils(): - date1 = "P11MT15M" date2 = "PT1S" - duration1= parse_duration_to_isodate_duration(date1) + duration1 = parse_duration_to_isodate_duration(date1) assert duration1.months == 11 assert duration1.tdelta.seconds == 900 @@ -27,7 +23,7 @@ def test_isodate_utils(): timedelta1 = duration_to_timedelta(duration1) assert timedelta1.seconds == 900 - assert timedelta1.days == 11*30, timedelta1.days + assert timedelta1.days == 11 * 30, timedelta1.days timedelta2 = duration_to_timedelta(duration2) assert timedelta2.seconds == 1 @@ -35,121 +31,146 @@ def test_isodate_utils(): parsed = parse_duration_to_isodate_duration("P11MT14M59S") assert iso8601_diff(date1, date2) == parsed, parsed - assert iso8601_duration_to_human(date1, flatten=False) == "11 Months and 15 Minutes ago", iso8601_duration_to_human(date1, flatten=False) - assert iso8601_duration_to_human(date2, flatten=False) == "1 Second ago", iso8601_duration_to_human(date2, flatten=False) - assert iso8601_duration_to_human(iso8601_diff(date1, date2), flatten=False) == "11 Months, 14 Minutes and 59 Seconds ago", iso8601_duration_to_human(iso8601_diff(date1, date2), flatten=False) - -@pytest.mark.parametrize("dates, expected", [ - (["PT1S", "P3M", "P6M", "P8M"], "P17MT1S"), -]) -def test_adding_isodates(dates: list[str], expected: str): + assert ( + iso8601_duration_to_human(date1, flatten=False) + == "11 Months and 15 Minutes ago" + ), iso8601_duration_to_human(date1, flatten=False) + assert iso8601_duration_to_human(date2, flatten=False) == "1 Second ago", ( + iso8601_duration_to_human(date2, flatten=False) + ) + assert ( + iso8601_duration_to_human(iso8601_diff(date1, date2), flatten=False) + == "11 Months, 14 Minutes and 59 Seconds ago" + ), iso8601_duration_to_human(iso8601_diff(date1, date2), flatten=False) + +@pytest.mark.parametrize( + "dates, expected", + [ + (["PT1S", "P3M", "P6M", "P8M"], "P17MT1S"), + ], +) +def test_adding_isodates(dates: list[str], expected: str): date = dates[0] - + for i in range(1, len(dates)): date = iso8601_add(date, dates[i]) - - assert date == expected, date - - -@pytest.mark.parametrize("a, b, expected", [ - # Basic year/month cases - ("P1Y", "P11M", "1 Month and 5 Days ago"), - ("P12M", "P11M", "1 Month ago"), - ("P2Y", "P1Y", "1 Year ago"), - ("P25M", "P1Y", "1 Year, 2 Weeks and 6 Days ago"), - - # Mixed time components - ("P34DT2H30M", "PT0S", "1 Month, 4 Days, 2 Hours and 30 Minutes ago"), - ("P1YT24H", "P1Y", "1 Day ago"), - ("P1MT60S", "P1M", "1 Minute ago"), - ("P400D", "P1Y", "1 Month and 5 Days ago"), - - # Edge cases - ("PT1S", "PT0S", "1 Second ago"), - ("PT1M", "PT0S", "1 Minute ago"), - ("PT1H", "PT0S", "1 Hour ago"), - ("P1D", "PT0S", "1 Day ago"), - ("P1W", "PT0S", "1 Week ago"), - ("P1M", "PT0S", "1 Month ago"), - ("P1Y", "PT0S", "1 Year ago"), - - # Complex mixed durations - ("P1Y2M3DT4H5M6S", "PT0S", "1 Year, 2 Months, 3 Days, 4 Hours, 5 Minutes and 6 Seconds ago"), - ("P1Y1M1DT1H1M1S", "P1Y", "1 Month, 1 Day, 1 Hour, 1 Minute and 1 Second ago"), - ("P2Y15M", "P1Y", "2 Years, 2 Months, 3 Weeks and 4 Days ago"), - - # Time-only durations - ("PT24H", "PT0S", "1 Day ago"), - ("PT25H", "PT1H", "1 Day ago"), - ("PT90M", "PT30M", "1 Hour ago"), - ("PT3600S", "PT0S", "1 Hour ago"), - - # Inverse order (should give same absolute difference) - ("P1M", "P2M", "1 Month ago"), - ("PT0S", "P1Y", "1 Year ago"), - - # Zero difference - ("P1Y", "P1Y", "Recently"), - ("P1M", "P1M", "Recently"), - ("PT0S", "PT0S", "Recently"), - - # long durations - ("P0D", "P998Y23M30D", "999 Years, 11 Months, 3 Weeks and 4 Days ago"), - ("P0D", "P12M364640D", "1000 Years ago"), -]) -def test_iso8601_diff_to_human_unflattened(a, b, expected): - assert iso8601_diff_to_human(a, b, flatten=False) == expected, iso8601_diff_to_human(a, b, flatten=False) - -@pytest.mark.parametrize("a, b, expected", [ - # Basic duration flattening tests - ("P1Y2M3DT4H5M6S", "PT0S", "1 Year and 2 Months ago"), - ("P2Y7M", "PT0S", "2 Years and 7 Months ago"), - ("P18M", "PT0S", "1 Year and 6 Months ago"), - ("P6M15D", "PT0S", "6 Months ago"), - ("P45D", "PT0S", "1 Month and 15 Days ago"), - ("P25D", "PT0S", "25 Days ago"), - ("P2DT12H", "PT0S", "2 Days and 12 Hours ago"), - ("PT20H", "PT0S", "20 Hours ago"), - ("P1DT30M", "PT0S", "1 Day and 1 Hour ago"), - ("P2DT45M", "PT0S", "2 Days and 1 Hour ago"), - ("P15DT8H", "PT0S", "15 Days ago"), - ("P35DT12H30M", "PT0S", "1 Month and 5 Days ago"), - ("P12M364640D", "P0D", "1000 Years ago"), -]) -def test_iso8601_diff_to_human_flattened(a, b, expected): - assert iso8601_duration_to_human(iso8601_diff(a, b), flatten=True) == expected, \ - f"Failed for {a} vs {b}: Got {iso8601_duration_to_human(iso8601_diff(a, b), flatten=True)}" -@pytest.mark.parametrize("amount, unit, expected", [ - # Minutes - (5, "minutes", "PT5M"), - (1, "minute", "PT1M"), - # Hours - (3, "hours", "PT3H"), - # Days - (2, "days", "P2D"), - # Weeks - (2, "weeks", "P2W"), - # Months (handled specially in the date section) - (7, "months", "P7M"), - # Years - (4, "years", "P4Y"), - # Negative amount should be converted to positive duration - (-5, "hours", "PT5H"), - # 1000 years - (1000, "years", "P1000Y"), -]) + assert date == expected, date + + +@pytest.mark.parametrize( + "a, b, expected", + [ + # Basic year/month cases + ("P1Y", "P11M", "1 Month and 5 Days ago"), + ("P12M", "P11M", "1 Month ago"), + ("P2Y", "P1Y", "1 Year ago"), + ("P25M", "P1Y", "1 Year, 2 Weeks and 6 Days ago"), + # Mixed time components + ("P34DT2H30M", "PT0S", "1 Month, 4 Days, 2 Hours and 30 Minutes ago"), + ("P1YT24H", "P1Y", "1 Day ago"), + ("P1MT60S", "P1M", "1 Minute ago"), + ("P400D", "P1Y", "1 Month and 5 Days ago"), + # Edge cases + ("PT1S", "PT0S", "1 Second ago"), + ("PT1M", "PT0S", "1 Minute ago"), + ("PT1H", "PT0S", "1 Hour ago"), + ("P1D", "PT0S", "1 Day ago"), + ("P1W", "PT0S", "1 Week ago"), + ("P1M", "PT0S", "1 Month ago"), + ("P1Y", "PT0S", "1 Year ago"), + # Complex mixed durations + ( + "P1Y2M3DT4H5M6S", + "PT0S", + "1 Year, 2 Months, 3 Days, 4 Hours, 5 Minutes and 6 Seconds ago", + ), + ("P1Y1M1DT1H1M1S", "P1Y", "1 Month, 1 Day, 1 Hour, 1 Minute and 1 Second ago"), + ("P2Y15M", "P1Y", "2 Years, 2 Months, 3 Weeks and 4 Days ago"), + # Time-only durations + ("PT24H", "PT0S", "1 Day ago"), + ("PT25H", "PT1H", "1 Day ago"), + ("PT90M", "PT30M", "1 Hour ago"), + ("PT3600S", "PT0S", "1 Hour ago"), + # Inverse order (should give same absolute difference) + ("P1M", "P2M", "1 Month ago"), + ("PT0S", "P1Y", "1 Year ago"), + # Zero difference + ("P1Y", "P1Y", "Recently"), + ("P1M", "P1M", "Recently"), + ("PT0S", "PT0S", "Recently"), + # long durations + ("P0D", "P998Y23M30D", "999 Years, 11 Months, 3 Weeks and 4 Days ago"), + ("P0D", "P12M364640D", "1000 Years ago"), + ], +) +def test_iso8601_diff_to_human_unflattened(a, b, expected): + assert iso8601_diff_to_human(a, b, flatten=False) == expected, ( + iso8601_diff_to_human(a, b, flatten=False) + ) + + +@pytest.mark.parametrize( + "a, b, expected", + [ + # Basic duration flattening tests + ("P1Y2M3DT4H5M6S", "PT0S", "1 Year and 2 Months ago"), + ("P2Y7M", "PT0S", "2 Years and 7 Months ago"), + ("P18M", "PT0S", "1 Year and 6 Months ago"), + ("P6M15D", "PT0S", "6 Months ago"), + ("P45D", "PT0S", "1 Month and 15 Days ago"), + ("P25D", "PT0S", "25 Days ago"), + ("P2DT12H", "PT0S", "2 Days and 12 Hours ago"), + ("PT20H", "PT0S", "20 Hours ago"), + ("P1DT30M", "PT0S", "1 Day and 1 Hour ago"), + ("P2DT45M", "PT0S", "2 Days and 1 Hour ago"), + ("P15DT8H", "PT0S", "15 Days ago"), + ("P35DT12H30M", "PT0S", "1 Month and 5 Days ago"), + ("P12M364640D", "P0D", "1000 Years ago"), + ], +) +def test_iso8601_diff_to_human_flattened(a, b, expected): + assert iso8601_duration_to_human(iso8601_diff(a, b), flatten=True) == expected, ( + f"Failed for {a} vs {b}: Got {iso8601_duration_to_human(iso8601_diff(a, b), flatten=True)}" + ) + + +@pytest.mark.parametrize( + "amount, unit, expected", + [ + # Minutes + (5, "minutes", "PT5M"), + (1, "minute", "PT1M"), + # Hours + (3, "hours", "PT3H"), + # Days + (2, "days", "P2D"), + # Weeks + (2, "weeks", "P2W"), + # Months (handled specially in the date section) + (7, "months", "P7M"), + # Years + (4, "years", "P4Y"), + # Negative amount should be converted to positive duration + (-5, "hours", "PT5H"), + # 1000 years + (1000, "years", "P1000Y"), + ], +) def test_amount_unit_to_iso8601_duration_valid(amount: int, unit: str, expected: str): """Ensure valid (amount, unit) pairs are converted to the correct ISO-8601 duration.""" assert amount_unit_to_iso8601_duration(amount, unit) == expected -@pytest.mark.parametrize("amount, unit", [ - (1, "invalid"), - (0, "centuries"), -]) +@pytest.mark.parametrize( + "amount, unit", + [ + (1, "invalid"), + (0, "centuries"), + ], +) def test_amount_unit_to_iso8601_duration_invalid(amount: int, unit: str): """Ensure invalid units raise ValueError.""" with pytest.raises(ValueError): - amount_unit_to_iso8601_duration(amount, unit) \ No newline at end of file + amount_unit_to_iso8601_duration(amount, unit) diff --git a/tests/test_nodes.py b/tests/test_nodes.py index d589be46..5815ea22 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1,4 +1,12 @@ -from talemate.game.engine.nodes.core import Node, Graph, Socket, GraphState, Loop, Entry, Router, GraphContext +from talemate.game.engine.nodes.core import ( + Node, + Graph, + GraphState, + Loop, + Entry, + Router, + GraphContext, +) import networkx as nx import structlog import pytest @@ -6,22 +14,22 @@ from talemate.util.async_tools import cleanup_pending_tasks log = structlog.get_logger() + class Counter(Node): def __init__(self, title="Counter", **kwargs): super().__init__(title=title, **kwargs) - + def setup(self): self.add_input("state") self.add_output("value") self.set_property("counter", 0) - + async def run(self, state: GraphState): counter = self.get_property("counter") - self.set_output_values({ - "value": counter - }) + self.set_output_values({"value": counter}) self.set_property("counter", counter + 1, state) + @pytest.mark.asyncio async def test_simple_graph(): # Create nodes @@ -29,52 +37,57 @@ async def test_simple_graph(): node_b = Node(title="B") node_c = Node(title="C") node_d = Node(title="D") - + # Add sockets to nodes out_a1 = node_a.add_output("out1") out_a2 = node_a.add_output("out2") - + in_b = node_b.add_input("in") out_b = node_b.add_output("out") - + in_c = node_c.add_input("in") out_c = node_c.add_output("out") - + in_d1 = node_d.add_input("in1") in_d2 = node_d.add_input("in2") - + # Create graph graph = Graph() graph.add_node(node_a) graph.add_node(node_b) graph.add_node(node_c) graph.add_node(node_d) - + # Connect nodes via sockets - graph.connect(out_a1, in_b) # A -> B - graph.connect(out_a2, in_c) # A -> C - graph.connect(out_b, in_d1) # B -> D - graph.connect(out_c, in_d2) # C -> D - + graph.connect(out_a1, in_b) # A -> B + graph.connect(out_a2, in_c) # A -> C + graph.connect(out_b, in_d1) # B -> D + graph.connect(out_c, in_d2) # C -> D + nxgraph = graph.build() - + # Print paths - print([graph.node(n).title for n in nx.shortest_path(nxgraph, node_a.id, node_d.id)]) + print( + [graph.node(n).title for n in nx.shortest_path(nxgraph, node_a.id, node_d.id)] + ) print([graph.node(n).title for n in nx.topological_sort(nxgraph)]) - + # Add assertions for expected behavior - shortest_path = [graph.node(n).title for n in nx.shortest_path(nxgraph, node_a.id, node_d.id)] + shortest_path = [ + graph.node(n).title for n in nx.shortest_path(nxgraph, node_a.id, node_d.id) + ] topo_sort = [graph.node(n).title for n in nx.topological_sort(nxgraph)] - + assert len(shortest_path) == 3, "Shortest path should have 3 nodes" assert shortest_path[0] == "A", "Path should start with A" assert shortest_path[-1] == "D", "Path should end with D" assert len(topo_sort) == 4, "Should have all 4 nodes in topological sort" assert topo_sort[0] == "A", "Topological sort should start with A" assert topo_sort[-1] == "D", "Topological sort should end with D" - + await cleanup_pending_tasks() + @pytest.mark.asyncio async def test_data_flow(): # Create nodes with specific behaviors @@ -83,48 +96,41 @@ async def test_data_flow(): super().__init__(title="A") self.add_output("out1") self.add_output("out2") - + async def run(self, state: GraphState): # Output constant values for testing - self.set_output_values({ - "out1": 5, - "out2": 10 - }) + self.set_output_values({"out1": 5, "out2": 10}) class NodeB(Node): def __init__(self): super().__init__(title="B") self.add_input("in") self.add_output("out") - + async def run(self, state: GraphState): inputs = self.get_input_values() # Double the input value - self.set_output_values({ - "out": inputs["in"] * 2 - }) + self.set_output_values({"out": inputs["in"] * 2}) class NodeC(Node): def __init__(self): super().__init__(title="C") self.add_input("in") self.add_output("out") - + async def run(self, state: GraphState): inputs = self.get_input_values() # Add 1 to the input value - self.set_output_values({ - "out": inputs["in"] + 1 - }) + self.set_output_values({"out": inputs["in"] + 1}) class NodeD(Node): result: int = 0 - + def __init__(self): super().__init__(title="D") self.add_input("in1") self.add_input("in2") - + async def run(self, state: GraphState): inputs = self.get_input_values() # Store sum for testing @@ -135,36 +141,36 @@ async def test_data_flow(): node_b = NodeB() node_c = NodeC() node_d = NodeD() - + # Create graph graph = Graph() graph.add_node(node_a) graph.add_node(node_b) graph.add_node(node_c) graph.add_node(node_d) - + # Connect nodes via sockets - graph.connect(node_a.outputs[0], node_b.inputs[0]) # A.out1 -> B.in - graph.connect(node_a.outputs[1], node_c.inputs[0]) # A.out2 -> C.in - graph.connect(node_b.outputs[0], node_d.inputs[0]) # B.out -> D.in1 - graph.connect(node_c.outputs[0], node_d.inputs[1]) # C.out -> D.in2 - + graph.connect(node_a.outputs[0], node_b.inputs[0]) # A.out1 -> B.in + graph.connect(node_a.outputs[1], node_c.inputs[0]) # A.out2 -> C.in + graph.connect(node_b.outputs[0], node_d.inputs[0]) # B.out -> D.in1 + graph.connect(node_c.outputs[0], node_d.inputs[1]) # C.out -> D.in2 + async def assert_state(state: GraphState): print(state.data) # Test data flow # NodeA outputs: out1=5, out2=10 assert node_a.outputs[0].value == 5, "NodeA out1 should be 5" assert node_a.outputs[1].value == 10, "NodeA out2 should be 10" - + # NodeB doubles input: 5 * 2 = 10 assert node_b.outputs[0].value == 10, "NodeB should double input value" - + # NodeC adds 1: 10 + 1 = 11 assert node_c.outputs[0].value == 11, "NodeC should add 1 to input value" - + # NodeD sums inputs: 10 + 11 = 21 assert node_d.result == 21, "NodeD should sum its inputs" - + # Execute graph graph.callbacks.append(assert_state) await graph.execute() @@ -180,12 +186,10 @@ async def test_property_flow(): self.add_output("value") # Set default property self.set_property("value", 5) - + async def run(self, state: GraphState): # Output property value - self.set_output_values({ - "value": self.get_property("value") - }) + self.set_output_values({"value": self.get_property("value")}) class Multiplier(Node): def __init__(self): @@ -194,15 +198,15 @@ async def test_property_flow(): self.add_output("result") # Set default multiplier self.set_property("multiplier", 2) - + async def run(self, state: GraphState): inputs = self.get_input_values() - multiplier = self.get_input_value("multiplier") # Will fall back to property - + multiplier = self.get_input_value( + "multiplier" + ) # Will fall back to property + print("Multiplier input:", inputs["value"], "Multiplier:", multiplier) - self.set_output_values({ - "result": (inputs["value"] or 0) * multiplier - }) + self.set_output_values({"result": (inputs["value"] or 0) * multiplier}) class Adder(Node): def __init__(self): @@ -211,17 +215,15 @@ async def test_property_flow(): self.add_output("result") # Set default addend self.set_property("addend", 1) - + async def run(self, state: GraphState): inputs = self.get_input_values() addend = self.get_input_value("addend") # Will fall back to property - self.set_output_values({ - "result": inputs["value"] + addend - }) + self.set_output_values({"result": inputs["value"] + addend}) class Collector(Node): result: float = 0 - + def __init__(self): super().__init__(title="Collector") self.add_input("value1") @@ -229,7 +231,7 @@ async def test_property_flow(): # Set default values self.set_property("value1", 0) self.set_property("value2", 0) - + async def run(self, state: GraphState): inputs = self.get_input_values() self.result = inputs["value1"] + inputs["value2"] @@ -239,28 +241,27 @@ async def test_property_flow(): mult = Multiplier() add = Adder() collect = Collector() - + # Create graph graph = Graph() graph.add_node(source) graph.add_node(mult) graph.add_node(add) graph.add_node(collect) - + # Connect nodes - graph.connect(source.outputs[0], mult.inputs[0]) # Source -> Multiplier - graph.connect(source.outputs[0], add.inputs[0]) # Source -> Adder - graph.connect(mult.outputs[0], collect.inputs[0]) # Multiplier -> Collector.value1 - graph.connect(add.outputs[0], collect.inputs[1]) # Adder -> Collector.value2 - - - async def assert_state(state:GraphState): + graph.connect(source.outputs[0], mult.inputs[0]) # Source -> Multiplier + graph.connect(source.outputs[0], add.inputs[0]) # Source -> Adder + graph.connect(mult.outputs[0], collect.inputs[0]) # Multiplier -> Collector.value1 + graph.connect(add.outputs[0], collect.inputs[1]) # Adder -> Collector.value2 + + async def assert_state(state: GraphState): # Run assertions... assert source.outputs[0].value == 5, "Source should output property value" assert mult.outputs[0].value == 10, "Multiplier should use property multiplier" assert add.outputs[0].value == 6, "Adder should use property addend" assert collect.result == 16, "Collector should sum multiplier and adder outputs" - + # Test property defaults graph.callbacks.append(assert_state) await graph.execute() @@ -271,69 +272,76 @@ async def test_property_flow(): async def test_simple_loop(): entry_loop = Entry() counter = Counter() - + loop = Loop(exit_condition=lambda state: counter.get_property("counter") > 10) loop.add_node(entry_loop) loop.add_node(counter) loop.connect(entry_loop.outputs[0], counter.inputs[0]) - + entry = Entry() graph = Graph() - + graph.add_node(entry) graph.add_node(loop) - - graph.connect(entry.outputs[0], loop.inputs[0]) - + + graph.connect(entry.outputs[0], loop.inputs[0]) + async def assert_state(state: GraphState): assert counter.outputs[0].value == 10, "Counter should count to 10" - + loop.callbacks.append(assert_state) await graph.execute() - @pytest.mark.asyncio async def test_simple_fork(): entry = Entry(title="Entry") entry_loop = Entry(title="Entry Loop") - + counter_main = Counter("CNT Main") counter_a = Counter("CNT A") counter_b = Counter("CNT B") - router = Router(2, selector=lambda state: 0 if counter_main.get_property("counter") % 2 == 0 else 1) - - loop = Loop(title="Loop", exit_condition=lambda state: counter_main.get_property("counter") > 10) - + router = Router( + 2, + selector=lambda state: 0 + if counter_main.get_property("counter") % 2 == 0 + else 1, + ) + + loop = Loop( + title="Loop", + exit_condition=lambda state: counter_main.get_property("counter") > 10, + ) + loop.add_node(entry_loop) loop.add_node(counter_main) loop.add_node(counter_a) loop.add_node(counter_b) loop.add_node(router) - + loop.connect(entry_loop.outputs[0], counter_main.inputs[0]) loop.connect(counter_main.outputs[0], router.inputs[0]) loop.connect(router.outputs[0], counter_a.inputs[0]) loop.connect(router.outputs[1], counter_b.inputs[0]) - + graph = Graph() graph.add_node(entry) graph.add_node(loop) - + graph.connect(entry.outputs[0], loop.inputs[0]) - - async def assert_state_loop(state: GraphState): - assert counter_main.get_property("counter") == 11, "Main counter should count to 11" + + async def assert_state_loop(state: GraphState): + assert counter_main.get_property("counter") == 11, ( + "Main counter should count to 11" + ) assert counter_a.get_property("counter") == 5, "Counter A should count to 5" assert counter_b.get_property("counter") == 5, "Counter B should count to 5" - + loop.callbacks.append(assert_state_loop) await graph.execute() await cleanup_pending_tasks() - - @pytest.mark.asyncio async def test_visited_paths(): @@ -342,19 +350,19 @@ async def test_visited_paths(): # A -> B # A -> C -> B # Only one path gets deactivated, other should still work - + graph = Graph() - + # Create nodes node_a = Node(title="Node A") node_b = Node(title="Node B") node_c = Node(title="Node C") - + # Add nodes to graph graph.add_node(node_a) graph.add_node(node_b) graph.add_node(node_c) - + # Create sockets a_out1 = node_a.add_output("out1") a_out2 = node_a.add_output("out2") @@ -362,26 +370,29 @@ async def test_visited_paths(): b_in2 = node_b.add_input("in2") c_in = node_c.add_input("in") c_out = node_c.add_output("out") - + # Connect nodes # A -> B (direct path) graph.connect(a_out1, b_in1) # A -> C -> B (indirect path) graph.connect(a_out2, c_in) graph.connect(c_out, b_in2) - + with GraphContext() as state: # Deactivate the direct path a_out1.deactivated = True - + # Node A should still be available because the path through C is still active - assert node_a.check_is_available(state), "Node A should be available through path via C" - + assert node_a.check_is_available(state), ( + "Node A should be available through path via C" + ) + # Now deactivate the indirect path too a_out2.deactivated = True - + # Now Node A should be unavailable as all paths are deactivated - assert not node_a.check_is_available(state), "Node A should be unavailable when all paths are deactivated" - + assert not node_a.check_is_available(state), ( + "Node A should be unavailable when all paths are deactivated" + ) + await cleanup_pending_tasks() - \ No newline at end of file diff --git a/tests/test_strip_partial_sentences.py b/tests/test_strip_partial_sentences.py index 267e8c20..0439be2f 100644 --- a/tests/test_strip_partial_sentences.py +++ b/tests/test_strip_partial_sentences.py @@ -2,16 +2,21 @@ import pytest from talemate.util import strip_partial_sentences -@pytest.mark.parametrize("input, expected", [ - ("This is a test{delim} This is a test{delim}", "This is a test{delim} This is a test{delim}"), - ("This is a test{delim} This is a test", "This is a test{delim}"), - ("This is a test{delim}\nThis is a test", "This is a test{delim}"), -]) +@pytest.mark.parametrize( + "input, expected", + [ + ( + "This is a test{delim} This is a test{delim}", + "This is a test{delim} This is a test{delim}", + ), + ("This is a test{delim} This is a test", "This is a test{delim}"), + ("This is a test{delim}\nThis is a test", "This is a test{delim}"), + ], +) def test_strip_partial_sentences(input, expected): - delimiters = [".", "!", "?", '"', "*"] - + for delim in delimiters: input = input.format(delim=delim) expected = expected.format(delim=delim) - assert strip_partial_sentences(input) == expected \ No newline at end of file + assert strip_partial_sentences(input) == expected diff --git a/tests/test_system_messages.py b/tests/test_system_messages.py index 0aed9a82..10f86396 100644 --- a/tests/test_system_messages.py +++ b/tests/test_system_messages.py @@ -1,6 +1,7 @@ import pytest from talemate.client.base import ClientBase + @pytest.mark.parametrize( "kind", [ @@ -18,16 +19,14 @@ from talemate.client.base import ClientBase ], ) def test_system_message(kind): - client = ClientBase() - + assert client.get_system_message(kind) is not None - + assert "explicit" in client.get_system_message(kind) - + client.decensor_enabled = False - + assert client.get_system_message(kind) is not None - + assert "explicit" not in client.get_system_message(kind) - \ No newline at end of file diff --git a/tests/test_utils_data.py b/tests/test_utils_data.py index 9adf84ef..9e1243ba 100644 --- a/tests/test_utils_data.py +++ b/tests/test_utils_data.py @@ -10,42 +10,46 @@ from talemate.util.data import ( JSONEncoder, DataParsingError, fix_yaml_colon_in_strings, - fix_faulty_yaml + fix_faulty_yaml, ) + # Helper function to get test data paths def get_test_data_path(filename): base_dir = os.path.dirname(os.path.abspath(__file__)) - return os.path.join(base_dir, 'data', 'util', 'data', filename) + return os.path.join(base_dir, "data", "util", "data", filename) + def test_json_encoder(): """Test JSONEncoder handles unknown types by converting to string.""" + class CustomObject: def __str__(self): return "CustomObject" - + # Create an object of a custom class custom_obj = CustomObject() - + # Encode it using JSONEncoder encoded = json.dumps({"obj": custom_obj}, cls=JSONEncoder) - + # Check if the object was converted to a string assert encoded == '{"obj": "CustomObject"}' + def test_fix_faulty_json(): """Test fix_faulty_json function with various faulty JSON strings.""" - + # Test adjacent objects - need to wrap in list brackets to make it valid JSON fixed = fix_faulty_json('{"a": 1}{"b": 2}') assert fixed == '{"a": 1},{"b": 2}' # We need to manually wrap it in brackets for the test - assert json.loads('[' + fixed + ']') == [{"a": 1}, {"b": 2}] - + assert json.loads("[" + fixed + "]") == [{"a": 1}, {"b": 2}] + # Test trailing commas assert json.loads(fix_faulty_json('{"a": 1, "b": 2,}')) == {"a": 1, "b": 2} assert json.loads(fix_faulty_json('{"a": [1, 2, 3,]}')) == {"a": [1, 2, 3]} - + def test_extract_json(): """Test extract_json function to extract JSON from the beginning of a string.""" @@ -53,67 +57,64 @@ def test_extract_json(): json_str, obj = extract_json('{"name": "test", "value": 42} and some text') assert json_str == '{"name": "test", "value": 42}' assert obj == {"name": "test", "value": 42} - + # Test with array - json_str, obj = extract_json('[1, 2, 3] and some text') - assert json_str == '[1, 2, 3]' + json_str, obj = extract_json("[1, 2, 3] and some text") + assert json_str == "[1, 2, 3]" assert obj == [1, 2, 3] - + # Test with whitespace json_str, obj = extract_json(' {"name": "test"} and some text') assert json_str == '{"name": "test"}' assert obj == {"name": "test"} - + # Test with invalid JSON with pytest.raises(ValueError): - extract_json('This is not JSON') + extract_json("This is not JSON") + def test_extract_json_v2_valid(): """Test extract_json_v2 with valid JSON in code blocks.""" # Load test data - with open(get_test_data_path('valid_json.txt'), 'r') as f: + with open(get_test_data_path("valid_json.txt"), "r") as f: text = f.read() - + # Extract JSON result = extract_json_v2(text) - + # Check if we got two unique JSON objects (third is a duplicate) assert len(result) == 2 - + # Check if the objects are correct expected_first = { "name": "Test Object", - "properties": { - "id": 1, - "active": True - }, - "tags": ["test", "json", "parsing"] + "properties": {"id": 1, "active": True}, + "tags": ["test", "json", "parsing"], } - - expected_second = { - "name": "Simple Object", - "value": 42 - } - + + expected_second = {"name": "Simple Object", "value": 42} + assert expected_first in result assert expected_second in result + def test_extract_json_v2_invalid(): """Test extract_json_v2 raises DataParsingError for invalid JSON.""" # Load test data - with open(get_test_data_path('invalid_json.txt'), 'r') as f: + with open(get_test_data_path("invalid_json.txt"), "r") as f: text = f.read() - + # Try to extract JSON, should raise DataParsingError with pytest.raises(DataParsingError): extract_json_v2(text) + def test_extract_json_v2_faulty(): """Test extract_json_v2 with faulty but fixable JSON.""" # Load test data - with open(get_test_data_path('faulty_json.txt'), 'r') as f: + with open(get_test_data_path("faulty_json.txt"), "r") as f: text = f.read() - + # Try to extract JSON, should successfully fix and extract some objects # but might fail on the severely malformed ones try: @@ -124,258 +125,258 @@ def test_extract_json_v2_faulty(): # This is also acceptable if some JSON is too broken to fix pass + def test_data_parsing_error(): """Test the DataParsingError class.""" # Create a DataParsingError with a message and data test_data = '{"broken": "json"' error = DataParsingError("Test error message", test_data) - + # Check properties assert error.message == "Test error message" assert error.data == test_data assert str(error) == "Test error message" + def test_extract_json_v2_multiple(): """Test extract_json_v2 with multiple JSON objects including duplicates.""" # Load test data - with open(get_test_data_path('multiple_json.txt'), 'r') as f: + with open(get_test_data_path("multiple_json.txt"), "r") as f: text = f.read() - + # Extract JSON result = extract_json_v2(text) - + # Check if we got the correct number of unique objects (3 unique out of 5 total) assert len(result) == 3 - + # Define expected objects expected_objects = [ - { - "id": 1, - "name": "First Object", - "tags": ["one", "first", "primary"] - }, - { - "id": 2, - "name": "Second Object", - "tags": ["two", "second"] - }, + {"id": 1, "name": "First Object", "tags": ["one", "first", "primary"]}, + {"id": 2, "name": "Second Object", "tags": ["two", "second"]}, { "id": 3, "name": "Third Object", - "metadata": { - "created": "2023-01-01", - "version": 1.0 - }, - "active": True - } + "metadata": {"created": "2023-01-01", "version": 1.0}, + "active": True, + }, ] - + # Check if all expected objects are in the result for expected in expected_objects: assert expected in result - + # Verify that each object appears exactly once (no duplicates) id_counts = {} for obj in result: id_counts[obj["id"]] = id_counts.get(obj["id"], 0) + 1 - + # Each ID should appear exactly once for id_val, count in id_counts.items(): - assert count == 1, f"Object with ID {id_val} appears {count} times (should be 1)" + assert count == 1, ( + f"Object with ID {id_val} appears {count} times (should be 1)" + ) + def test_extract_yaml_v2_valid(): """Test extract_yaml_v2 with valid YAML in code blocks.""" # Load test data - with open(get_test_data_path('valid_yaml.txt'), 'r') as f: + with open(get_test_data_path("valid_yaml.txt"), "r") as f: text = f.read() - + # Extract YAML result = extract_yaml_v2(text) - + # Check if we got two unique YAML objects (third is a duplicate) assert len(result) == 2 - + # Check if the objects are correct expected_first = { "name": "Test Object", - "properties": { - "id": 1, - "active": True - }, - "tags": ["test", "yaml", "parsing"] + "properties": {"id": 1, "active": True}, + "tags": ["test", "yaml", "parsing"], } - - expected_second = { - "simple_name": "Simple Object", - "value": 42 - } - + + expected_second = {"simple_name": "Simple Object", "value": 42} + assert expected_first in result assert expected_second in result + def test_extract_yaml_v2_invalid(): """Test extract_yaml_v2 raises DataParsingError for invalid YAML.""" # Load test data - with open(get_test_data_path('invalid_yaml.txt'), 'r') as f: + with open(get_test_data_path("invalid_yaml.txt"), "r") as f: text = f.read() - + # Try to extract YAML, should raise DataParsingError with pytest.raises(DataParsingError): extract_yaml_v2(text) + def test_extract_yaml_v2_multiple(): """Test extract_yaml_v2 with multiple YAML objects including duplicates.""" # Load test data - with open(get_test_data_path('multiple_yaml.txt'), 'r') as f: + with open(get_test_data_path("multiple_yaml.txt"), "r") as f: text = f.read() - + # Extract YAML result = extract_yaml_v2(text) - + # Check if we got the correct number of unique objects (3 unique out of 5 total) assert len(result) == 3 - + # Get the objects by ID for easier assertions objects_by_id = {obj["id"]: obj for obj in result} - + # Check for object 1 assert objects_by_id[1]["name"] == "First Object" assert objects_by_id[1]["tags"] == ["one", "first", "primary"] - + # Check for object 2 assert objects_by_id[2]["name"] == "Second Object" assert objects_by_id[2]["tags"] == ["two", "second"] - + # Check for object 3 - note that the date is parsed as a date object by YAML assert objects_by_id[3]["name"] == "Third Object" assert objects_by_id[3]["active"] is True assert "created" in objects_by_id[3]["metadata"] - + # Verify that each object ID appears exactly once (no duplicates) id_counts = {} for obj in result: id_counts[obj["id"]] = id_counts.get(obj["id"], 0) + 1 - + # Each ID should appear exactly once for id_val, count in id_counts.items(): - assert count == 1, f"Object with ID {id_val} appears {count} times (should be 1)" + assert count == 1, ( + f"Object with ID {id_val} appears {count} times (should be 1)" + ) + def test_extract_yaml_v2_multiple_documents(): """Test extract_yaml_v2 with multiple YAML documents in a single code block.""" # Load test data from file - with open(get_test_data_path('multiple_yaml_documents.txt'), 'r') as f: + with open(get_test_data_path("multiple_yaml_documents.txt"), "r") as f: test_data = f.read() - + # Extract YAML result = extract_yaml_v2(test_data) - + # Check if we got all three documents assert len(result) == 3 - + # Check if the objects are correct objects_by_id = {obj["id"]: obj for obj in result} - + assert objects_by_id[1]["name"] == "First Document" assert "first" in objects_by_id[1]["tags"] - + assert objects_by_id[2]["name"] == "Second Document" assert "secondary" in objects_by_id[2]["tags"] - + assert objects_by_id[3]["name"] == "Third Document" assert objects_by_id[3]["active"] is True + def test_extract_yaml_v2_without_separators(): """Test extract_yaml_v2 with multiple YAML documents without --- separators.""" # Load test data from file - with open(get_test_data_path('multiple_yaml_without_separators.txt'), 'r') as f: + with open(get_test_data_path("multiple_yaml_without_separators.txt"), "r") as f: test_data = f.read() - + # Extract YAML result = extract_yaml_v2(test_data) - + # Check if we got all three nested documents assert len(result) == 3 - + # Create a dictionary of documents by name for easy testing docs_by_name = {doc["name"]: doc for doc in result} - + # Verify that all three documents are correctly parsed assert "First Document" in docs_by_name assert docs_by_name["First Document"]["id"] == 1 assert "first" in docs_by_name["First Document"]["tags"] - + assert "Second Document" in docs_by_name assert docs_by_name["Second Document"]["id"] == 2 assert "secondary" in docs_by_name["Second Document"]["tags"] - + assert "Third Document" in docs_by_name assert docs_by_name["Third Document"]["id"] == 3 assert docs_by_name["Third Document"]["active"] is True + def test_extract_json_v2_multiple_objects(): """Test extract_json_v2 with multiple JSON objects in a single code block.""" # Load test data from file - with open(get_test_data_path('multiple_json_objects.txt'), 'r') as f: + with open(get_test_data_path("multiple_json_objects.txt"), "r") as f: test_data = f.read() - + # Extract JSON result = extract_json_v2(test_data) - + # Check if we got all three objects assert len(result) == 3 - + # Check if the objects are correct objects_by_id = {obj["id"]: obj for obj in result} - + assert objects_by_id[1]["name"] == "First Object" assert objects_by_id[1]["type"] == "test" - + assert objects_by_id[2]["name"] == "Second Object" assert objects_by_id[2]["values"] == [1, 2, 3] - + assert objects_by_id[3]["name"] == "Third Object" assert objects_by_id[3]["active"] is True assert objects_by_id[3]["metadata"]["created"] == "2023-05-15" + def test_fix_yaml_colon_in_strings(): """Test fix_yaml_colon_in_strings with problematic YAML containing unquoted colons.""" # Load test data from file - with open(get_test_data_path('yaml_with_colons.txt'), 'r') as f: + with open(get_test_data_path("yaml_with_colons.txt"), "r") as f: problematic_yaml = f.read() - + # Extract YAML from the code block problematic_yaml = problematic_yaml.split("```")[1] if problematic_yaml.startswith("yaml"): problematic_yaml = problematic_yaml[4:].strip() - + # Fix the YAML fixed_yaml = fix_yaml_colon_in_strings(problematic_yaml) - + # Parse the fixed YAML to check it works parsed = yaml.safe_load(fixed_yaml) - + # Check the structure and content is preserved assert parsed["calls"][0]["name"] == "act" assert parsed["calls"][0]["arguments"]["name"] == "Kaira" - assert "I can see you're scared, Elmer" in parsed["calls"][0]["arguments"]["instructions"] + assert ( + "I can see you're scared, Elmer" + in parsed["calls"][0]["arguments"]["instructions"] + ) + def test_fix_faulty_yaml(): """Test fix_faulty_yaml with various problematic YAML constructs.""" # Load test data from file - with open(get_test_data_path('yaml_list_with_colons.txt'), 'r') as f: + with open(get_test_data_path("yaml_list_with_colons.txt"), "r") as f: problematic_yaml = f.read() - + # Extract YAML from the code block problematic_yaml = problematic_yaml.split("```")[1] if problematic_yaml.startswith("yaml"): problematic_yaml = problematic_yaml[4:].strip() - + # Fix the YAML fixed_yaml = fix_faulty_yaml(problematic_yaml) - + # Parse the fixed YAML to check it works parsed = yaml.safe_load(fixed_yaml) - + # Check the structure and content is preserved assert len(parsed["instructions_list"]) == 2 # The content will be the full string with colons in it now @@ -384,43 +385,44 @@ def test_fix_faulty_yaml(): assert "Look around" in parsed["instructions_list"][1] assert "Is there another way out?" in parsed["instructions_list"][1] + def test_extract_yaml_v2_with_colons(): """Test extract_yaml_v2 correctly processes YAML with problematic colons in strings.""" # Load test data containing YAML code blocks with problematic colons - with open(get_test_data_path('yaml_block_with_colons.txt'), 'r') as f: + with open(get_test_data_path("yaml_block_with_colons.txt"), "r") as f: text = f.read() - + # Extract YAML result = extract_yaml_v2(text) - + # Check if we got the two YAML objects assert len(result) == 2 - + # Find the objects by their structure calls_obj = None instructions_obj = None for obj in result: - if 'calls' in obj: + if "calls" in obj: calls_obj = obj - elif 'instructions_list' in obj: + elif "instructions_list" in obj: instructions_obj = obj - + # Verify both objects were found assert calls_obj is not None, "Could not find the 'calls' object" assert instructions_obj is not None, "Could not find the 'instructions_list' object" - + # Check the structure and content of the first object (calls) assert calls_obj["calls"][0]["name"] == "act" assert calls_obj["calls"][0]["arguments"]["name"] == "Kaira" - + # Check that the problematic part with the colon is preserved instructions = calls_obj["calls"][0]["arguments"]["instructions"] assert "Speak in a calm, soothing tone and say:" in instructions assert "I can see you're scared, Elmer" in instructions - + # Check the second object (instructions_list) assert len(instructions_obj["instructions_list"]) == 2 assert "Run to the door" in instructions_obj["instructions_list"][0] assert "Wait for me!" in instructions_obj["instructions_list"][0] assert "Look around" in instructions_obj["instructions_list"][1] - assert "Is there another way out?" in instructions_obj["instructions_list"][1] \ No newline at end of file + assert "Is there another way out?" in instructions_obj["instructions_list"][1] diff --git a/uv.lock b/uv.lock index 2fa86b96..4bbf3d11 100644 --- a/uv.lock +++ b/uv.lock @@ -592,6 +592,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009, upload-time = "2024-09-04T20:44:45.309Z" }, ] +[[package]] +name = "cfgv" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114, upload-time = "2023-08-12T20:38:17.776Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249, upload-time = "2023-08-12T20:38:16.269Z" }, +] + [[package]] name = "charset-normalizer" version = "3.4.2" @@ -851,6 +860,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252, upload-time = "2024-01-27T23:42:14.239Z" }, ] +[[package]] +name = "distlib" +version = "0.3.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/dd/1bec4c5ddb504ca60fc29472f3d27e8d4da1257a854e1d96742f15c1d02d/distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403", size = 613923, upload-time = "2024-10-09T18:35:47.551Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973, upload-time = "2024-10-09T18:35:44.272Z" }, +] + [[package]] name = "distro" version = "1.9.0" @@ -1374,6 +1392,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477", size = 86794, upload-time = "2021-09-17T21:40:39.897Z" }, ] +[[package]] +name = "identify" +version = "2.6.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/88/d193a27416618628a5eea64e3223acd800b40749a96ffb322a9b55a49ed1/identify-2.6.12.tar.gz", hash = "sha256:d8de45749f1efb108badef65ee8386f0f7bb19a7f26185f74de6367bffbaf0e6", size = 99254, upload-time = "2025-05-23T20:37:53.3Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/cd/18f8da995b658420625f7ef13f037be53ae04ec5ad33f9b718240dcfd48c/identify-2.6.12-py2.py3-none-any.whl", hash = "sha256:ad9672d5a72e0d2ff7c5c8809b62dfa60458626352fb0eb7b55e69bdc45334a2", size = 99145, upload-time = "2025-05-23T20:37:51.495Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -2158,6 +2185,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442, upload-time = "2024-08-18T19:48:21.909Z" }, ] +[[package]] +name = "nodeenv" +version = "1.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437, upload-time = "2024-06-04T18:44:11.171Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, +] + [[package]] name = "numpy" version = "2.2.6" @@ -2879,6 +2915,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/98/e480cab9a08d1c09b1c59a93dade92c1bb7544826684ff2acbfd10fcfbd4/posthog-5.4.0-py3-none-any.whl", hash = "sha256:284dfa302f64353484420b52d4ad81ff5c2c2d1d607c4e2db602ac72761831bd", size = 105364, upload-time = "2025-06-20T23:19:22.001Z" }, ] +[[package]] +name = "pre-commit" +version = "4.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/39/679ca9b26c7bb2999ff122d50faa301e49af82ca9c066ec061cfbc0c6784/pre_commit-4.2.0.tar.gz", hash = "sha256:601283b9757afd87d40c4c4a9b2b5de9637a8ea02eaff7adc2d0fb4e04841146", size = 193424, upload-time = "2025-03-18T21:35:20.987Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/74/a88bf1b1efeae488a0c0b7bdf71429c313722d1fc0f377537fbe554e6180/pre_commit-4.2.0-py2.py3-none-any.whl", hash = "sha256:a009ca7205f1eb497d10b845e52c838a98b6cdd2102a6c8e4540e94ee75c58bd", size = 220707, upload-time = "2025-03-18T21:35:19.343Z" }, +] + [[package]] name = "prettytable" version = "3.16.0" @@ -4346,6 +4398,7 @@ dev = [ { name = "mkdocs-glightbox" }, { name = "mkdocs-material" }, { name = "mypy" }, + { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-asyncio" }, ] @@ -4384,6 +4437,7 @@ requires-dist = [ { name = "openai", specifier = ">=1" }, { name = "piexif", specifier = ">=1.1" }, { name = "pillow", specifier = ">=9.5" }, + { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=2.13" }, { name = "pydantic", specifier = "<3" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=6.2" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.25.3" }, @@ -4878,6 +4932,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/63/9a/0962b05b308494e3202d3f794a6e85abe471fe3cafdbcf95c2e8c713aabd/uvloop-0.21.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a5c39f217ab3c663dc699c04cbd50c13813e31d917642d459fdcec07555cc553", size = 4660018, upload-time = "2024-10-14T23:38:10.888Z" }, ] +[[package]] +name = "virtualenv" +version = "20.31.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/56/2c/444f465fb2c65f40c3a104fd0c495184c4f2336d65baf398e3c75d72ea94/virtualenv-20.31.2.tar.gz", hash = "sha256:e10c0a9d02835e592521be48b332b6caee6887f332c111aa79a09b9e79efc2af", size = 6076316, upload-time = "2025-05-08T17:58:23.811Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/40/b1c265d4b2b62b58576588510fc4d1fe60a86319c8de99fd8e9fec617d2c/virtualenv-20.31.2-py3-none-any.whl", hash = "sha256:36efd0d9650ee985f0cad72065001e66d49a6f24eb44d98980f630686243cf11", size = 6057982, upload-time = "2025-05-08T17:58:21.15Z" }, +] + [[package]] name = "watchdog" version = "6.0.0"