mirror of
https://github.com/vegu-ai/talemate.git
synced 2026-05-18 05:05:39 +02:00
0.37.0 - **Director Planning** — Multi-step todo lists in director chat plus a Generate long progress action for multi-beat scene arcs. - **Auto Narration** — Unified auto-narration replacing the old Narrate after Dialogue toggle, with a chance slider and weighted action mix. - **LLM Prompt Templates Manager** — Dedicated UI tab for viewing, creating, editing, and deleting prompt templates. - **Character Folders** — Collapsible folders in the World Editor character list, synced across linked scenes. - **OpenAI Compatible TTS** — Connect any number of OpenAI-compatible TTS servers in parallel. - **KoboldCpp TTS Auto-Setup** — KoboldCpp clients with a TTS model loaded register themselves as a TTS backend. - **Model Testing Harness** — Bundled scene that runs basic capability tests against any connected LLM. Plus 27 improvements and 28 bug fixes
780 lines
24 KiB
Python
780 lines
24 KiB
Python
"""
|
|
Unit tests for src/talemate/game/engine/nodes/run.py.
|
|
|
|
Covers function/graph plumbing nodes:
|
|
- title_to_function_name
|
|
- BreakpointEvent (dataclass shape)
|
|
- FunctionArgument._convert_value (many type conversions)
|
|
- FunctionArgument.run (resolving argument value from state.data)
|
|
- FunctionReturn.run (sets __fn_result, raises StopGraphExecution)
|
|
- DefineFunction (isolated, never_run) and GetFunction lookup
|
|
- FunctionWrapper.__call__ via Define+Get+Return-style minimal graphs
|
|
- CallFunction with a FunctionWrapper
|
|
- CallForEach with list/dict items
|
|
- RunModule: success, infinite-loop guard, error wrapping
|
|
- Breakpoint: skipped in non-creative environment, skipped when active=False
|
|
- UnpackException with ExceptionWrapper
|
|
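- ErrorHandler: isolated/never_run flags and catch() behaviour
- Function (graph-level): run() outputs and static inputs/outputs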
|
|
|
|
Skipped paths:
|
|
- Function.run (graph-level) creates an `ai_callback` via `wrapped.ai_callback`,
  which scans for a Metadata node and calls `normalized_input_value`. Only the
  resulting callback's type and name are checked here, on a Function graph
  that has no Metadata node, so the Metadata branch itself is not exercised.
|
|
- Breakpoint's wait-for-release loop: only the short-circuit paths are tested
  here. The loop itself polls `state.shared['__breakpoint']`; it and its
  signals are exercised in higher-level integration tests.
|
|
"""
|
|
|
|
import pytest
|
|
|
|
from talemate.game.engine.nodes.core import (
|
|
Entry,
|
|
Graph,
|
|
GraphState,
|
|
Node,
|
|
StopGraphExecution,
|
|
UNRESOLVED,
|
|
)
|
|
from talemate.game.engine.nodes.core.exception import ExceptionWrapper
|
|
from talemate.game.engine.nodes.run import (
|
|
Breakpoint,
|
|
BreakpointEvent,
|
|
CallForEach,
|
|
CallFunction,
|
|
DefineFunction,
|
|
ErrorHandler,
|
|
Function,
|
|
FunctionArgument,
|
|
FunctionReturn,
|
|
FunctionWrapper,
|
|
GetFunction,
|
|
RunModule,
|
|
UnpackException,
|
|
title_to_function_name,
|
|
)
|
|
from talemate.util.async_tools import cleanup_pending_tasks
|
|
|
|
from conftest import MockScene
|
|
from _node_test_helpers import run_node
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# title_to_function_name
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_title_to_function_name_replaces_special_chars():
    """Spaces and punctuation such as `!` are each replaced by an underscore."""
    result = title_to_function_name("My Function!")
    assert result == "My_Function_"
|
|
|
|
|
|
def test_title_to_function_name_keeps_alphanumeric_and_underscores():
    """A title made only of letters, digits and underscores is unchanged."""
    title = "foo_bar123"
    assert title_to_function_name(title) == title
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# BreakpointEvent
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_breakpoint_event_holds_node_and_state():
    """BreakpointEvent keeps the exact node and state objects it was built
    with (identity, not copies) plus the given module path."""
    node = Breakpoint()
    graph_state = GraphState()
    event = BreakpointEvent(node=node, state=graph_state, module_path="some/path")

    assert event.node is node
    assert event.state is graph_state
    assert event.module_path == "some/path"
|
|
|
|
|
|
def test_breakpoint_event_module_path_defaults_to_none():
    """`module_path` is optional and falls back to None when omitted."""
    event = BreakpointEvent(node=Breakpoint(), state=GraphState())
    assert event.module_path is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# FunctionArgument._convert_value
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
# (typ, raw value, expected converted value) cases for
# FunctionArgument._convert_value.
_CONVERT_VALUE_CASES = [
    # "any" hands values back untouched
    ("any", "anything", "anything"),
    ("any", 42, 42),
    # "str" stringifies scalars
    ("str", 5, "5"),
    ("str", True, "True"),
    # numeric types parse strings and keep None as None
    ("int", "7", 7),
    ("int", None, None),
    ("float", "3.14", 3.14),
    ("float", None, None),
    # bool: recognised truthy / falsy words
    ("bool", "true", True),
    ("bool", "yes", True),
    ("bool", "1", True),
    ("bool", "false", False),
    ("bool", "no", False),
    ("bool", "0", False),
    ("bool", "anything else", True),  # non-empty string -> True
    ("bool", 1, True),
    ("bool", 0, False),
    # list: list values come back equal, strings are split on newlines
    ("list", ["a", "b"], ["a", "b"]),
    ("list", "a\nb\nc", ["a", "b", "c"]),
]


@pytest.mark.parametrize("typ,value,expected", _CONVERT_VALUE_CASES)
def test_function_argument_convert_value(typ, value, expected):
    """Each declared `typ` coerces the raw value as listed in the cases."""
    argument = FunctionArgument()
    argument.set_property("typ", typ)
    converted = argument._convert_value(value)
    assert converted == expected
|
|
|
|
|
|
def test_function_argument_str_passes_through_non_scalar():
    """With `typ=str`, a non-scalar such as a dict is returned as the very
    same object instead of being converted to its string form."""
    argument = FunctionArgument()
    argument.set_property("typ", "str")
    payload = {"a": 1}
    result = argument._convert_value(payload)
    assert result is payload
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# FunctionArgument.run
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_function_argument_resolves_value_from_state_data():
    """run() reads `state.data[f'{node.id}__fn_arg_value']` and coerces it
    via the configured `typ` ("12" -> 12 for int)."""
    argument = FunctionArgument()
    for prop, val in (("name", "x"), ("typ", "int")):
        argument.set_property(prop, val)

    def seed_state(state):
        # Value a caller would bind to this argument, keyed by the node id.
        state.data[f"{argument.id}__fn_arg_value"] = "12"

    outputs = await run_node(argument, state_setup=seed_state)
    assert outputs["value"] == 12
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_function_argument_unresolved_when_no_state_value():
    """With no bound value in state, run() emits UNRESOLVED; `typ=any`
    leaves that sentinel unconverted."""
    argument = FunctionArgument()
    argument.set_property("name", "x")
    argument.set_property("typ", "any")

    outputs = await run_node(argument)
    assert outputs["value"] is UNRESOLVED
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# FunctionReturn
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_function_return_sets_fn_result_and_raises_stop():
    """A resolved `value` is stored as `__fn_result` on the run state, and
    StopGraphExecution is raised.

    Previously only the raise was asserted, although the test name (and the
    module docstring) promise that `__fn_result` is set.
    """
    ret = FunctionReturn()
    seen_states: list[GraphState] = []

    def capture_state(state):
        # Keep a handle on the state run_node hands to the node so the stored
        # result can be inspected after the exception escapes.
        seen_states.append(state)

    with pytest.raises(StopGraphExecution):
        await run_node(ret, inputs={"value": "ok"}, state_setup=capture_state)

    assert seen_states, "run_node did not invoke state_setup"
    # NOTE(review): assumes FunctionReturn writes to `state.data["__fn_result"]`,
    # the same `state.data` channel FunctionArgument reads its value from.
    assert seen_states[0].data.get("__fn_result") == "ok"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_function_return_unresolved_input_does_not_raise():
    """If `value` is UNRESOLVED, FunctionReturn returns silently: no
    StopGraphExecution is raised and no `__fn_result` is stored."""
    ret = FunctionReturn()
    ret.properties["value"] = UNRESOLVED
    seen_states: list[GraphState] = []

    def capture_state(state):
        # Keep the run state so the absence of `__fn_result` can be checked.
        seen_states.append(state)

    out = await run_node(ret, state_setup=capture_state)
    assert out["value"] is UNRESOLVED

    # The "no __fn_result" half of the contract was previously unchecked.
    assert seen_states, "run_node did not invoke state_setup"
    assert "__fn_result" not in seen_states[0].data
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DefineFunction / GetFunction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_define_function_run_is_noop_and_marks_isolated():
    """DefineFunction carries the `_isolated` and `never_run` flags, and
    calling run() directly does nothing and returns None."""
    node = DefineFunction()

    # Class-level flag that keeps the node out of normal scheduling.
    assert DefineFunction._isolated is True
    # Instance-level flag: the dispatcher is not expected to reach run().
    assert node.never_run is True

    # Even so, invoking run() by hand is harmless.
    returned = await node.run(GraphState())
    assert returned is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_function_returns_wrapper_when_define_function_exists():
    """`GetFunction.get_function` finds the DefineFunction node whose `name`
    matches its own and returns a FunctionWrapper for it."""
    define = DefineFunction()
    define.set_property("name", "my_fn")

    # DefineFunction needs a `nodes` source — wire any node's output into it
    src = Entry(title="src")
    graph = Graph()
    graph.add_node(src)
    graph.add_node(define)
    graph.connect(src.outputs[0], define.get_input_socket("nodes"))

    # GetFunction itself is not added to the graph: the graph to search is
    # handed to `get_function` explicitly below.
    get_fn = GetFunction()
    get_fn.set_property("name", "my_fn")

    state = GraphState()
    state.graph = graph

    wrapper = await get_fn.get_function(graph, state)
    assert isinstance(wrapper, FunctionWrapper)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_function_raises_when_not_found():
    """run() raises ValueError when `state.graph` holds no DefineFunction
    with the requested name."""
    state = GraphState()
    state.graph = Graph()  # empty graph: nothing to find

    lookup = GetFunction()
    lookup.set_property("name", "nope")

    with pytest.raises(ValueError):
        await lookup.run(state)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# FunctionWrapper.__call__ end-to-end via DefineFunction with subgraph endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _build_function_graph(ret_value: int):
    """Build an inline function named ``the_fn`` that immediately returns
    ``ret_value``.

    The graph is ``Constant(ret_value) -> FunctionReturn``, with the
    FunctionReturn's ``value`` output wired into DefineFunction's ``nodes``
    input.

    Returns:
        ``(graph, define_node)``.
    """

    class _Constant(Node):
        """Emits the captured `ret_value` on its single `value` output."""

        def __init__(self, **kwargs):
            super().__init__(title="Constant", **kwargs)

        def setup(self):
            self.add_output("value")

        async def run(self, state):
            self.set_output_values({"value": ret_value})

    constant, fn_return, define = _Constant(), FunctionReturn(), DefineFunction()
    define.set_property("name", "the_fn")

    graph = Graph()
    for node in (constant, fn_return, define):
        graph.add_node(node)

    graph.connect(constant.outputs[0], fn_return.get_input_socket("value"))
    graph.connect(
        fn_return.get_output_socket("value"), define.get_input_socket("nodes")
    )
    return graph, define
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_function_wrapper_call_returns_function_return_value():
    """Calling the wrapper runs the function body and yields the value that
    reached FunctionReturn (here the constant 42)."""
    graph, define = _build_function_graph(42)
    state = GraphState()
    state.graph = graph

    fn = await define.get_function(state)
    assert await fn() == 42

    await cleanup_pending_tasks()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_function_wrapper_get_argument_nodes_filters_by_type():
    """get_argument_nodes() returns only FunctionArgument nodes. Of the
    Constant, FunctionReturn, DefineFunction and FunctionArgument nodes in
    the graph, only the argument comes back."""
    graph, define = _build_function_graph(1)

    argument = FunctionArgument()
    argument.set_property("name", "arg1")
    graph.add_node(argument)

    fn_return = next(
        node for node in graph.nodes.values() if isinstance(node, FunctionReturn)
    )
    # FunctionReturn's `value` input is already fed by the constant; this
    # connects the argument to the same input so it becomes an ancestor of
    # the endpoint. NOTE(review): whether Graph.connect replaces the existing
    # link is not checked here. The wrapper is never called in this test.
    graph.connect(
        argument.get_output_socket("value"), fn_return.get_input_socket("value")
    )

    state = GraphState()
    state.graph = graph
    wrapper = await define.get_function(state)

    found = await wrapper.get_argument_nodes()
    assert len(found) == 1
    assert found[0] is argument
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CallFunction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_function_invokes_wrapper_and_returns_result():
    """CallFunction calls the supplied wrapper (with no args) and exposes
    its return value on the `result` output."""
    graph, define = _build_function_graph(7)
    state = GraphState()
    state.graph = graph
    fn = await define.get_function(state)

    outputs = await run_node(CallFunction(), inputs={"fn": fn, "args": {}})
    assert outputs["result"] == 7

    await cleanup_pending_tasks()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_function_rejects_non_function_input():
    """A plain string on the `fn` input is rejected with ValueError."""
    call = CallFunction()
    bad_inputs = {"fn": "not a function", "args": {}}
    with pytest.raises(ValueError):
        await run_node(call, inputs=bad_inputs)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CallForEach
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _RecordingFunctionWrapper(FunctionWrapper):
    """Minimal stand-in for `FunctionWrapper` that records its calls.

    `CallForEach` doesn't care about the wrapper's implementation — it just
    checks `isinstance(fn, FunctionWrapper)`. Subclassing FunctionWrapper and
    overriding `__init__`/`__call__` is therefore enough, and no backing
    graph has to be built.

    (Previously this class held only a docstring and did not actually
    subclass FunctionWrapper, so it could not serve as a stand-in.)

    Args:
        result: Value returned from every call (default None).

    Attributes:
        calls: Keyword arguments of each call, in call order.
    """

    def __init__(self, result=None):
        # Deliberately skip FunctionWrapper.__init__: there is no real graph.
        self.calls: list[dict] = []
        self.result = result

    async def __call__(self, **kwargs):
        self.calls.append(kwargs)
        return self.result
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_for_each_iterates_list_items():
    """CallForEach invokes `fn` once per list element, passing it under
    `argument_name`, and collects the return values in order on `results`.
    The `state` input is passed through unchanged."""
    calls = []

    class _RecordingWrapper(FunctionWrapper):
        def __init__(self):
            # Bypass FunctionWrapper.__init__; no backing graph is needed.
            pass

        async def __call__(self, **kwargs):
            calls.append(kwargs)
            return f"ok-{kwargs.get('item')}"

    wrapper = _RecordingWrapper()
    outputs = await run_node(
        CallForEach(),
        inputs=dict(
            state="STATE",
            fn=wrapper,
            items=["a", "b", "c"],
            argument_name="item",
        ),
    )

    assert outputs["state"] == "STATE"
    assert outputs["results"] == [f"ok-{letter}" for letter in "abc"]
    assert calls == [{"item": letter} for letter in "abc"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_for_each_iterates_dict_values():
    """For a dict, CallForEach iterates its values (not its keys)."""
    seen_values = []

    class _ValueRecorder(FunctionWrapper):
        def __init__(self):
            pass

        async def __call__(self, **kwargs):
            seen_values.append(kwargs.get("v"))
            return None

    await run_node(
        CallForEach(),
        inputs={
            "state": "x",
            "fn": _ValueRecorder(),
            "items": {"k1": "v1", "k2": "v2"},
            "argument_name": "v",
            # NOTE(review): the effect of copy_items is not asserted here.
            "copy_items": True,
        },
    )

    # Compared sorted so the check does not depend on iteration order.
    assert sorted(seen_values) == ["v1", "v2"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_for_each_rejects_non_function_fn():
    """A `fn` input that is not a FunctionWrapper raises InputValueError."""
    from talemate.game.engine.nodes.core import InputValueError

    node = CallForEach()
    bad_inputs = {
        "state": "x",
        "fn": "not a function",
        "items": ["a"],
        "argument_name": "item",
    }
    with pytest.raises(InputValueError):
        await run_node(node, inputs=bad_inputs)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_for_each_rejects_non_collection_items():
    """`items` that is neither a list nor a dict raises InputValueError."""
    from talemate.game.engine.nodes.core import InputValueError

    class _NoopWrapper(FunctionWrapper):
        def __init__(self):
            pass

        async def __call__(self, **kwargs):
            return None

    node = CallForEach()
    with pytest.raises(InputValueError):
        await run_node(
            node,
            inputs=dict(
                state="x",
                fn=_NoopWrapper(),
                items="not a list or dict",
                argument_name="item",
            ),
        )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_for_each_rejects_blank_argument_name():
    """An empty `argument_name` raises InputValueError."""
    from talemate.game.engine.nodes.core import InputValueError

    class _NoopWrapper(FunctionWrapper):
        def __init__(self):
            pass

        async def __call__(self, **kwargs):
            return None

    node = CallForEach()
    with pytest.raises(InputValueError):
        await run_node(
            node,
            inputs=dict(
                state="x",
                fn=_NoopWrapper(),
                items=["a"],
                argument_name="",
            ),
        )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RunModule
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_module_executes_inner_graph():
    """Running a module graph executes its nodes and sets `done` to True."""
    from talemate.game.engine.nodes.core import GraphContext

    executions: list[bool] = []

    class _InnerSink(Node):
        """Node with a `state` input so it can be wired to the entry node.

        Graph.execute only visits the weakly-connected components of the
        edge graph, so an unwired node would be skipped. That is why the
        entry node's output is connected to this node's input below.
        """

        def __init__(self, **kw):
            super().__init__(title="InnerSink", **kw)

        def setup(self):
            self.add_input("state")

        async def run(self, state):
            executions.append(True)

    module = Graph()
    entry = Entry(title="entry")
    sink = _InnerSink()
    module.add_node(entry)
    module.add_node(sink)
    module.connect(entry.outputs[0], sink.get_input_socket("state"))

    runner = RunModule()
    runner.set_property("module", module)

    with GraphContext(outer_state=GraphState()) as state:
        await runner.run(state)
        outputs = {socket.name: socket.value for socket in runner.outputs}

    assert outputs["done"] is True
    assert executions == [True]

    await cleanup_pending_tasks()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_module_rejects_non_graph_module():
    """A `module` property that is not a Graph raises ValueError."""
    from talemate.game.engine.nodes.core import GraphContext

    runner = RunModule()
    runner.set_property("module", "not a graph")

    with GraphContext(outer_state=GraphState()) as state, pytest.raises(ValueError):
        await runner.run(state)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_module_detects_infinite_loop():
    """RunModule raises ValueError when the outer state's `_in_run_module`
    already refers to the very module it was asked to run."""
    from talemate.game.engine.nodes.core import GraphContext

    module = Graph()

    outer_state = GraphState()
    # Pretend this module is already running further up the call stack.
    outer_state.data["_in_run_module"] = module

    runner = RunModule()
    runner.set_property("module", module)

    with GraphContext(outer_state=outer_state) as state:
        with pytest.raises(ValueError):
            await runner.run(state)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_module_wraps_inner_exception_as_module_error():
    """When a node inside the module raises, RunModule surfaces it as a
    ModuleError. The exception text is recorded on its `failed` output."""
    from talemate.game.engine.nodes.core import GraphContext, ModuleError

    class _Boom(Node):
        """Inner node that always fails when run."""

        def __init__(self, **kw):
            super().__init__(title="Boom", **kw)

        def setup(self):
            self.add_input("state")

        async def run(self, state):
            raise RuntimeError("inner boom")

    module = Graph()
    entry = Entry(title="entry")
    failing = _Boom()
    module.add_node(entry)
    module.add_node(failing)
    module.connect(entry.outputs[0], failing.get_input_socket("state"))

    runner = RunModule()
    runner.set_property("module", module)

    with GraphContext(outer_state=GraphState()) as state:
        with pytest.raises(ModuleError):
            await runner.run(state)
        failure_text = runner.get_output_socket("failed").value

    assert "inner boom" in failure_text

    await cleanup_pending_tasks()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Breakpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_breakpoint_skipped_in_non_creative_environment():
    """Outside the "creative" environment an active Breakpoint does not
    block; it just passes `state` through."""
    scene = MockScene()
    scene.environment = "scene"  # anything other than "creative"

    outputs = await run_node(
        Breakpoint(), scene=scene, inputs={"state": "STATE", "active": True}
    )
    assert outputs["state"] == "STATE"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_breakpoint_skipped_when_inactive_in_creative():
    """Even in "creative", a Breakpoint with active=False is a passthrough."""
    scene = MockScene()
    scene.environment = "creative"

    breakpoint_node = Breakpoint()
    breakpoint_node.set_property("active", False)

    outputs = await run_node(breakpoint_node, scene=scene, inputs={"state": "STATE"})
    assert outputs["state"] == "STATE"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# UnpackException
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unpack_exception_emits_name_and_message():
    """An ExceptionWrapper is split into its `name` and `message` outputs."""
    wrapped = ExceptionWrapper(name="ValueError", message="something broke")
    outputs = await run_node(UnpackException(), inputs={"exc": wrapped})
    assert (outputs["name"], outputs["message"]) == ("ValueError", "something broke")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unpack_exception_returns_silently_for_wrong_type():
    """A dict that merely looks like an exception is not an ExceptionWrapper.
    run() returns without setting outputs, so both stay UNRESOLVED."""
    look_alike = {"name": "X", "message": "Y"}
    outputs = await run_node(UnpackException(), inputs={"exc": look_alike})
    for socket_name in ("name", "message"):
        assert outputs[socket_name] is UNRESOLVED
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ErrorHandler
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_error_handler_is_isolated_and_never_runs():
    """ErrorHandler is flagged `_isolated` (class level) and `never_run`
    (instance level). The dispatcher must not pick it up; only `catch()` is
    invoked, from the Graph error path."""
    assert ErrorHandler._isolated is True
    assert ErrorHandler().never_run is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_error_handler_catch_invokes_caught_function_with_exception_wrapper():
    """ErrorHandler.catch gets its function by running the node wired to its
    `fn` input. It calls that function with the exception wrapped as an
    ExceptionWrapper under `exc` and returns the call's result."""
    received: list[ExceptionWrapper] = []

    class _RecordingWrapper(FunctionWrapper):
        def __init__(self):
            pass

        async def __call__(self, **kwargs):
            received.append(kwargs.get("exc"))
            return "handled"

    handler = ErrorHandler()

    class _FnSource(Node):
        """The handler uses the return value of this node's run() (via
        `fn_socket.source.node.run(state)`) as the function to call."""

        def __init__(self, **kw):
            super().__init__(title="FnSource", **kw)

        def setup(self):
            self.add_output("fn", socket_type="function")

        async def run(self, state):
            return _RecordingWrapper()

    source = _FnSource()
    graph = Graph()
    graph.add_node(handler)
    graph.add_node(source)
    graph.connect(source.get_output_socket("fn"), handler.get_input_socket("fn"))

    state = GraphState()
    state.graph = graph

    outcome = await handler.catch(state, RuntimeError("boom"))
    assert outcome == "handled"

    assert len(received) == 1
    (wrapped,) = received
    assert isinstance(wrapped, ExceptionWrapper)
    assert wrapped.name == "RuntimeError"
    assert wrapped.message == "boom"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_error_handler_catch_returns_false_when_fn_not_a_wrapper():
    """When the node wired to `fn` yields something other than a
    FunctionWrapper, `catch` logs and returns False."""

    class _BadFnSource(Node):
        def __init__(self, **kw):
            super().__init__(title="BadFnSource", **kw)

        def setup(self):
            self.add_output("fn", socket_type="function")

        async def run(self, state):
            return "not a wrapper"

    handler, source = ErrorHandler(), _BadFnSource()

    graph = Graph()
    for node in (handler, source):
        graph.add_node(node)
    graph.connect(source.get_output_socket("fn"), handler.get_input_socket("fn"))

    state = GraphState()
    state.graph = graph

    assert await handler.catch(state, RuntimeError("oops")) is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Function (Graph-level) — outputs on its run()
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_function_graph_run_emits_fn_and_name_outputs():
    """Running a Function module graph populates its `fn`, `name`,
    `allow_multiple_calls` and `ai_callback` output sockets."""
    import talemate.game.focal as focal
    from talemate.game.engine.nodes.core import GraphContext

    function_graph = Function()
    function_graph.set_property("name", "my_function")
    function_graph.set_property("allow_multiple_calls", True)

    with GraphContext() as state:
        await function_graph.run(state)
        outputs = {socket.name: socket.value for socket in function_graph.outputs}

    assert outputs["name"] == "my_function"
    assert outputs["allow_multiple_calls"] is True
    assert isinstance(outputs["fn"], FunctionWrapper)

    # The AI-facing callback is a focal.Callback carrying the function name.
    callback = outputs["ai_callback"]
    assert isinstance(callback, focal.Callback)
    assert callback.name == "my_function"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_function_graph_inputs_outputs_are_static():
    """A Function exposes no graph-level inputs and exactly four outputs:
    ai_callback, allow_multiple_calls, fn and name."""
    function_graph = Function()
    assert function_graph.inputs == []

    expected = ["ai_callback", "allow_multiple_calls", "fn", "name"]
    assert sorted(socket.name for socket in function_graph.outputs) == expected
|