Files
talemate/tests/test_nodes_core.py
veguAI f5d41c04c8 0.37.0 (#267)
0.37.0

- **Director Planning** — Multi-step todo lists in director chat plus a Generate long progress action for multi-beat scene arcs.
- **Auto Narration** — Unified auto-narration replacing the old Narrate after Dialogue toggle, with a chance slider and weighted action mix.
- **LLM Prompt Templates Manager** — Dedicated UI tab for viewing, creating, editing, and deleting prompt templates.
- **Character Folders** — Collapsible folders in the World Editor character list, synced across linked scenes.
- **OpenAI Compatible TTS** — Connect any number of OpenAI-compatible TTS servers in parallel.
- **KoboldCpp TTS Auto-Setup** — KoboldCpp clients with a TTS model loaded register themselves as a TTS backend.
- **Model Testing Harness** — Bundled scene that runs basic capability tests against any connected LLM.

Plus 27 improvements and 28 bug fixes
2026-05-12 21:01:51 +03:00

2309 lines
70 KiB
Python

"""
Coverage-focused unit tests for talemate.game.engine.nodes.core.
These tests build small synthetic graphs and exercise public entry points
(Node.run, Graph.execute, Loop.execute, Listen.execute_from_event,
Trigger.run, etc.) so we cover:
- Socket/Property edge cases (missing graph_state, dunder methods,
generate_choices)
- Graph traversal helpers (set_node_references, set_socket_source_references,
ensure_connections, connect by string)
- Serialization (model_dump with/without SaveContext, _node_serialization_fields)
- Dispatch & error handling inside Graph._execute_inner / Loop.execute
- Listen and Trigger event hooks
- ModuleProperty.cast_value branches
- load_extended_components from a JSON file
- dynamic_node_import factory
No domain agents/scenes are required — every test instantiates real Node
subclasses and runs them through real Graph/Loop primitives.
"""
import json
from typing import Any, ClassVar
import pytest
import structlog
from talemate.game.engine.nodes.core import (
UNRESOLVED,
Comment,
Entry,
Graph,
GraphContext,
GraphState,
Group,
InputValueError,
Listen,
Loop,
LoopExit,
ModuleProperty,
ModuleStyle,
Node,
NodeBase,
NodeState,
NodeVerbosity,
Output,
Input,
PropertyField,
Route,
SaveContext,
Socket,
Stage,
StageExit,
StopGraphExecution,
Trigger,
Watch,
dynamic_node_import,
get_ancestors_with_forks,
get_type_class,
load_extended_components,
save_state,
validate_node,
)
from talemate.util.async_tools import cleanup_pending_tasks
# Module-level structlog logger for this test module.
# NOTE(review): `log` is not referenced in the visible part of this file;
# confirm it is used further down before relying on it.
log = structlog.get_logger(__name__)
# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------
class ValueNode(Node):
    """Test node that emits a fixed value.

    The constructor's ``value`` is stashed on the instance and written to the
    single ``value`` output socket every time the node runs.
    """

    def __init__(self, value: Any = None, title: str = "ValueNode", **kwargs):
        super().__init__(title=title, **kwargs)
        self._emit = value

    def setup(self):
        # One output socket, named "value".
        self.add_output("value")

    async def run(self, state: GraphState):
        payload = {"value": self._emit}
        self.set_output_values(payload)
class CaptureNode(Node):
    """Test sink that records every value seen on its ``value`` input.

    Each ``run`` appends the current input value to ``self.captured`` so tests
    can assert on what flowed through the graph.
    """

    captured: list = None

    def __init__(self, title: str = "Capture", **kwargs):
        super().__init__(title=title, **kwargs)
        # Give every instance its own list (the class-level default is None).
        self.captured = []

    def setup(self):
        self.add_input("value")

    async def run(self, state: GraphState):
        received = self.get_input_value("value")
        self.captured.append(received)
class RaisingNode(Node):
    """Test node whose ``run`` unconditionally raises ``self._exc``.

    ``exc`` is handed straight to ``raise``, so it may be an exception
    instance (as annotated) or an exception class.
    """

    def __init__(self, exc: Exception, title: str = "Raise", **kwargs):
        super().__init__(title=title, **kwargs)
        self._exc = exc

    def setup(self):
        # An optional `trigger` input and a `done` output let the node be
        # wired into a graph.
        self.add_input("trigger", optional=True)
        self.add_output("done")

    async def run(self, state: GraphState):
        error = self._exc
        raise error
def build_value_capture_graph(value):
    """Build a two-node graph wired ``ValueNode.value -> CaptureNode.value``.

    Returns:
        A ``(graph, source, sink)`` tuple. ``source`` emits *value* and
        ``sink`` records whatever arrives on its input.
    """
    graph = Graph()
    source = ValueNode(value=value)
    sink = CaptureNode()
    for node in (source, sink):
        graph.add_node(node)
    graph.connect(source.outputs[0], sink.inputs[0])
    return graph, source, sink
# ---------------------------------------------------------------------------
# Module-level helpers (get_type_class, get_ancestors_with_forks,
# load_extended_components, dynamic_node_import)
# ---------------------------------------------------------------------------
class TestGetTypeClass:
    """``get_type_class`` maps builtin type names to the builtin class."""

    def test_known_type_returns_class(self):
        expected = {"str": str, "int": int, "dict": dict}
        for type_name, cls in expected.items():
            assert get_type_class(type_name) is cls

    def test_unknown_type_raises(self):
        bogus = "not-a-type"
        with pytest.raises(ValueError, match="Could not find class"):
            get_type_class(bogus)
class TestGetAncestorsWithForks:
    """``get_ancestors_with_forks`` also returns side branches of ancestors."""

    def test_includes_forked_branch(self):
        """A branch from a shared ancestor that does not lead to the target
        is still included (this is the helper's whole point)."""
        import networkx as nx

        # Topology:  A -> B -> D (target)
        #            A -> C      (fork that never reaches D)
        dag = nx.DiGraph()
        dag.add_edges_from([("A", "B"), ("B", "D"), ("A", "C")])

        ancestors = get_ancestors_with_forks(dag, "D")

        # Direct ancestors A and B, plus the fork C that nx.ancestors() misses.
        for expected in ("A", "B", "C"):
            assert expected in ancestors
        # The target itself is excluded.
        assert "D" not in ancestors

    def test_no_ancestors_returns_empty(self):
        import networkx as nx

        dag = nx.DiGraph()
        dag.add_node("A")
        assert get_ancestors_with_forks(dag, "A") == set()
class TestLoadExtendedComponents:
    """``load_extended_components`` merges a base graph JSON into a target."""

    def test_merges_nodes_edges_groups_comments_marking_inherited(self, tmp_path):
        base_file = tmp_path / "base.json"
        base_file.write_text(
            json.dumps(
                {
                    "nodes": {"n1": {"id": "n1", "title": "Base"}},
                    "edges": {"n1.out": ["n2.in"]},
                    "groups": [{"title": "G1"}],
                    "comments": [{"text": "C1"}],
                }
            )
        )
        target = dict(
            nodes={"n2": {"id": "n2", "title": "Target"}},
            edges={},
            groups=[],
            comments=[],
        )

        load_extended_components(str(base_file), target)

        nodes = target["nodes"]
        # The base node is merged in and flagged as inherited ...
        assert "n1" in nodes
        assert nodes["n1"]["inherited"] is True
        # ... while the target's own nodes keep their data.
        assert nodes["n2"]["title"] == "Target"
        # Edges are copied across.
        assert target["edges"]["n1.out"] == ["n2.in"]
        # Groups and comments are stamped as inherited too.
        assert target["groups"][0]["inherited"] is True
        assert target["comments"][0]["inherited"] is True

    def test_chained_extends(self, tmp_path):
        """When a base file itself extends another, the chain is loaded."""
        empty_sections = {"edges": {}, "groups": [], "comments": []}

        deep_file = tmp_path / "deep.json"
        deep_file.write_text(
            json.dumps(
                {"nodes": {"deep_node": {"id": "deep_node"}}, **empty_sections}
            )
        )

        mid_file = tmp_path / "mid.json"
        mid_file.write_text(
            json.dumps(
                {
                    "extends": str(deep_file),
                    "nodes": {"mid_node": {"id": "mid_node"}},
                    **empty_sections,
                }
            )
        )

        target = {"nodes": {}, "edges": {}, "groups": [], "comments": []}
        load_extended_components(str(mid_file), target)

        merged = ("deep_node", "mid_node")
        for node_id in merged:
            assert node_id in target["nodes"]
        for node_id in merged:
            assert target["nodes"][node_id]["inherited"] is True
class TestDynamicNodeImport:
    """``dynamic_node_import`` builds node classes from a definition dict."""

    def test_unknown_base_type_raises(self):
        definition = {"base_type": "core/NotARealBaseType"}
        # NOTE: the match string presumably mirrors the (misspelled) message
        # raised by dynamic_node_import; keep the two in sync.
        with pytest.raises(ValueError, match="Cannont import"):
            dynamic_node_import(definition, "test/UnknownBase")

    def test_creates_dynamic_class_in_provided_container(self):
        registry = {}
        definition = {
            "base_type": "core/Graph",
            "title": "DynamicGraph",
            "nodes": {},
            "edges": {},
        }
        node_cls = dynamic_node_import(
            definition,
            "test/DynamicGraphNode",
            registry_container=registry,
        )

        # Registered into our isolated container, not the global NODES.
        assert "test/DynamicGraphNode" in registry
        assert getattr(node_cls, "__dynamic_imported__") is True
        assert node_cls._base_type == "core/Graph"
        # The class name is the trailing segment of the registry name.
        assert node_cls.__name__ == "DynamicGraphNode"
        # The generated class instantiates as a Graph subclass.
        assert isinstance(node_cls(), Graph)
# ---------------------------------------------------------------------------
# UNRESOLVED sentinel
# ---------------------------------------------------------------------------
def test_unresolved_is_falsy_and_repr_stable():
    """``UNRESOLVED()`` instances are falsy and render as ``<UNRESOLVED>``."""
    assert bool(UNRESOLVED()) is False
    for render in (str, repr):
        assert render(UNRESOLVED()) == "<UNRESOLVED>"
# ---------------------------------------------------------------------------
# Socket
# ---------------------------------------------------------------------------
class TestSocketBehaviour:
    """Socket value/activation accessors, identity and string rendering."""

    def test_value_outside_graph_state_returns_unresolved(self):
        node = Node(title="N")
        sock = node.add_input("foo")
        # No GraphContext active — there is no graph_state set
        assert sock.value is UNRESOLVED

    def test_setting_value_outside_graph_state_is_silent_noop(self):
        node = Node(title="N")
        sock = node.add_output("foo")
        sock.value = 123  # must not raise
        # The write is dropped, so the socket is still unresolved.
        assert sock.value is UNRESOLVED

    def test_deactivated_outside_graph_state_returns_true(self):
        node = Node(title="N")
        sock = node.add_input("foo")
        # Outside the graph context, sockets are considered deactivated as a
        # safety default (the runtime requires a state to track activation).
        assert sock.deactivated is True

    def test_setting_deactivated_outside_graph_state_is_silent_noop(self):
        node = Node(title="N")
        sock = node.add_input("foo")
        sock.deactivated = False  # must not raise
        # Like the value setter, the write must be dropped. The socket still
        # reports the deactivated safety default rather than the new value.
        assert sock.deactivated is True

    def test_value_within_graph_context(self):
        node = Node(title="N")
        sock = node.add_output("foo")
        with GraphContext():
            sock.value = "hello"
            assert sock.value == "hello"

    def test_value_through_source_lookup(self):
        a = Node(title="A")
        b = Node(title="B")
        out = a.add_output("out")
        inp = b.add_input("in")
        inp.source = out
        with GraphContext() as state:
            state.set_node_socket_value(a, "out", 42)
            # Reading from `inp` follows source -> A.out
            assert inp.value == 42

    def test_full_id_combines_node_and_socket(self):
        n = Node(title="N", id="node-id")
        s = n.add_input("sname")
        assert s.full_id == "node-id.sname"

    def test_str_with_node(self):
        n = Node(title="MyNode")
        s = n.add_input("foo")
        assert str(s) == "MyNode.foo"
        assert repr(s) == "MyNode.foo"

    def test_str_without_node(self):
        s = Socket(name="orphan")
        assert str(s) == "orphan"

    def test_eq_and_hash_use_id(self):
        n = Node(title="N")
        s1 = n.add_input("x")
        # Two sockets with the same id compare equal and hash the same
        s2 = Socket(name="x", id=s1.id)
        assert s1 == s2
        assert hash(s1) == hash(s2)
        # Different id
        s3 = n.add_input("y")
        assert s1 != s3

    def test_socket_as_bool_unresolved(self):
        assert Socket.as_bool(UNRESOLVED) is False
        assert Socket.as_bool(0) is False  # bool(0)
        assert Socket.as_bool("hello") is True
# ---------------------------------------------------------------------------
# PropertyField & RESERVED_PROPERTY_NAMES
# ---------------------------------------------------------------------------
class TestPropertyField:
    """PropertyField validation and ``model_dump`` choice handling."""

    def test_reserved_name_rejected(self):
        with pytest.raises(ValueError, match="reserved"):
            PropertyField(name="id", description="x", type="str")

    def test_model_dump_without_generate_choices(self):
        static_choices = ["a", "b"]
        field = PropertyField(
            name="x", description="x", type="str", choices=static_choices
        )
        dumped = field.model_dump()
        assert dumped["choices"] == ["a", "b"]

    def test_generate_choices_overrides_choices(self):
        """A ``generate_choices`` callable wins over static ``choices``."""

        def make_choices():
            return ["dynamic1", "dynamic2"]

        field = PropertyField(
            name="x",
            description="x",
            type="str",
            choices=["static"],
            generate_choices=make_choices,
        )
        assert field.model_dump()["choices"] == ["dynamic1", "dynamic2"]
# ---------------------------------------------------------------------------
# NodeBase
# ---------------------------------------------------------------------------
class TestNodeBaseHelpers:
    """Property and socket bookkeeping helpers defined on NodeBase."""

    def test_set_property_rejects_reserved_name(self):
        node = Node(title="N")
        with pytest.raises(ValueError, match="reserved"):
            node.set_property("id", "anything")

    def test_remove_input_and_output(self):
        node = Node(title="N")
        for name in ("a", "b"):
            node.add_input(name)
        for name in ("x", "y"):
            node.add_output(name)

        def input_names():
            return [socket.name for socket in node.inputs]

        node.remove_input("a")
        assert input_names() == ["b"]
        # Removing a socket that does not exist is a silent no-op.
        node.remove_input("does-not-exist")
        assert input_names() == ["b"]

        node.remove_output("y")
        assert [socket.name for socket in node.outputs] == ["x"]

    def test_field_definitions_includes_properties_and_meta_class(self):
        # Stage declares a `stage` PropertyField on its Fields meta-class.
        definitions = Stage().field_definitions
        assert "stage" in definitions
        assert definitions["stage"].type == "int"

    def test_field_definitions_for_unknown_property_falls_back(self):
        node = Node(title="N")
        node.set_property("count", 5)
        definitions = node.field_definitions
        # `count` is not declared in any Fields meta, so its type is inferred.
        assert "count" in definitions
        assert definitions["count"].type == "int"

    def test_handle_unresolved_properties_validator(self):
        # The before-validator maps None and the string 'UNRESOLVED' to the
        # UNRESOLVED sentinel. It is called directly because the public
        # __init__ pops `properties` before pydantic validation runs.
        raw = {"properties": {"foo": "UNRESOLVED", "bar": None, "baz": 5}}
        props = NodeBase.handle_unresolved_properties(raw)["properties"]
        assert props["foo"] is UNRESOLVED
        assert props["bar"] is UNRESOLVED
        assert props["baz"] == 5

    def test_handle_unresolved_properties_validator_passthrough_for_non_dict(self):
        # Input without a `properties` key comes back unchanged ...
        assert NodeBase.handle_unresolved_properties({}) == {}
        # ... and so does input that is not a dict at all.
        assert NodeBase.handle_unresolved_properties("anything") == "anything"

    def test_model_dump_replaces_unresolved_with_none(self):
        # Assign the sentinel directly to reach the dump-time normalization.
        node = Node(title="N")
        node.properties["foo"] = UNRESOLVED
        node.properties["bar"] = "ordinary"
        dumped_props = node.model_dump()["properties"]
        assert dumped_props["foo"] is None
        assert dumped_props["bar"] == "ordinary"

    def test_default_title_falls_back_to_class_name_with_spaces(self):
        # Without an explicit title, the CamelCase class name is split.
        class MyCustomThing(Node):
            pass

        assert MyCustomThing().title == "My Custom Thing"

    def test_eq_uses_id_and_returns_false_for_non_node(self):
        node = Node(title="N")
        assert (node == "not a node") is False
class TestRequireInput:
    """``Node.require_input`` raises for unset inputs and returns set ones."""

    @pytest.mark.asyncio
    async def test_require_input_raises_when_unresolved(self):
        class StrictNode(Node):
            def setup(self):
                self.add_input("must_be_set")

            async def run(self, state):
                self.require_input("must_be_set")

        strict = StrictNode()
        graph = Graph()
        graph.add_node(strict)
        # Nothing feeds `must_be_set`, so running the node must fail.
        with pytest.raises(InputValueError), GraphContext() as state:
            await strict.run(state)

    @pytest.mark.asyncio
    async def test_require_input_returns_value_when_set(self):
        class StrictNode(Node):
            received: Any = None

            def setup(self):
                self.add_input("v")

            async def run(self, state):
                self.received = self.require_input("v")

        strict = StrictNode()
        # A property named like the input supplies its value.
        strict.set_property("v", "hello")
        with GraphContext():
            await strict.run(None)
        assert strict.received == "hello"

    @pytest.mark.asyncio
    async def test_require_input_treats_none_as_unset_by_default(self):
        node = Node(title="N")
        node.add_input("v")
        node.set_property("v", None)
        with GraphContext(), pytest.raises(InputValueError):
            node.require_input("v")

    @pytest.mark.asyncio
    async def test_require_input_with_none_is_set_true(self):
        node = Node(title="N")
        node.add_input("v")
        node.set_property("v", None)
        with GraphContext():
            value = node.require_input("v", none_is_set=True)
        assert value is None
class TestNormalizedInputValue:
    """``normalized_input_value`` turns UNRESOLVED into None."""

    @staticmethod
    def _node_with_input():
        """Return a fresh node with a single input socket named ``v``."""
        node = Node(title="N")
        node.add_input("v")
        return node

    def test_unresolved_becomes_none(self):
        node = self._node_with_input()
        with GraphContext():
            assert node.normalized_input_value("v") is None

    def test_set_value_returned(self):
        node = self._node_with_input()
        node.set_property("v", "x")
        with GraphContext():
            assert node.normalized_input_value("v") == "x"
class TestRequireNumberInput:
    """``require_number_input`` coercion and validation."""

    def _make_node(self, value):
        """Return a node whose ``num`` input is backed by property *value*."""
        node = Node(title="N")
        node.add_input("num")
        node.set_property("num", value)
        return node

    def test_string_int_is_converted(self):
        node = self._make_node("42")
        with GraphContext():
            result = node.require_number_input("num", types=(int,))
        assert result == 42

    def test_string_float_is_converted(self):
        node = self._make_node("3.14")
        with GraphContext():
            result = node.require_number_input("num", types=(float, int))
        assert result == pytest.approx(3.14)

    def test_invalid_string_raises(self):
        node = self._make_node("not-a-number")
        with GraphContext(), pytest.raises(InputValueError, match="Invalid number"):
            node.require_number_input("num")

    def test_non_number_value_raises(self):
        node = self._make_node([1, 2])
        with GraphContext(), pytest.raises(InputValueError, match="must be a number"):
            node.require_number_input("num")

    def test_int_passes_through(self):
        node = self._make_node(7)
        with GraphContext():
            result = node.require_number_input("num")
        assert result == 7
# ---------------------------------------------------------------------------
# NodeState
# ---------------------------------------------------------------------------
class TestNodeState:
    """Equality, ordering, hashing and flattening of NodeState."""

    def _state_for(self, node):
        """Build a NodeState for *node* bound to a throwaway GraphState."""
        with GraphContext() as graph_state:
            return NodeState(node=node, state=graph_state)

    def test_eq_compares_node_id(self):
        # Different titles but the same id -> equal.
        first = self._state_for(Node(title="N", id="abc"))
        second = self._state_for(Node(title="N2", id="abc"))
        assert first == second

    def test_eq_against_non_nodestate_returns_false(self):
        node_state = self._state_for(Node(title="N"))
        # Comparing with a plain string hits the AttributeError branch.
        assert (node_state == "not a node state") is False

    def test_lt_gt_use_node_id(self):
        lower = self._state_for(Node(title="N", id="aaa"))
        higher = self._state_for(Node(title="N", id="bbb"))
        assert lower < higher
        assert higher > lower

    def test_lt_against_non_nodestate_returns_notimplemented(self):
        node_state = self._state_for(Node(title="N"))
        for compare in (node_state.__lt__, node_state.__gt__):
            assert compare("nope") is NotImplemented

    def test_hash_uses_node_id(self):
        node_state = self._state_for(Node(title="N", id="hash-id"))
        assert hash(node_state) == hash("hash-id")

    def test_str_and_repr(self):
        node_state = self._state_for(Node(title="N", id="ident"))
        assert "ident" in str(node_state)
        assert "ident" in repr(node_state)

    def test_flattened_truncates_long_values(self):
        node = Node(title="N", id="x")
        node.add_input("a")
        node.set_property("p", "y" * 1000)
        node_state = self._state_for(node)
        node_state.input_values = {"a": "x" * 5000}
        node_state.output_values = {}
        flat = node_state.flattened
        # reprlib truncation (maxstring=255) keeps long strings short.
        assert len(flat["input_values"]["a"]) <= 300
        assert flat["node_id"] == "x"
class TestGraphStateFlattened:
    """``GraphState.flattened`` serialization of the execution stack."""

    def test_flattened_returns_stack_dump(self):
        with GraphContext() as state:
            node = Node(title="N", id="abc")
            state.stack.append(NodeState(node=node, state=state))
            flat = state.flattened
            assert "stack" in flat
            stack_dump = flat["stack"]
            assert len(stack_dump) == 1
            assert stack_dump[0]["node_id"] == "abc"

    def test_flattened_handles_circular_references(self):
        """If flattening a stack entry blows up, the helper clears the stack
        and returns an empty list rather than crashing."""

        class BadNodeState:
            @property
            def flattened(self):
                raise RuntimeError("circular!")

        state = GraphState()
        state.stack = [BadNodeState()]
        assert state.flattened == {"stack": []}
        # The stack is cleared as part of the recovery.
        assert state.stack == []
# ---------------------------------------------------------------------------
# GraphState set/get helpers
# ---------------------------------------------------------------------------
class TestGraphStateHelpers:
    """Per-node property, socket-value and socket-state storage on GraphState."""

    def test_node_property_round_trip(self):
        node, state = Node(title="N"), GraphState()
        state.set_node_property(node, "k", "v")
        assert state.get_node_property(node, "k") == "v"

    def test_node_property_falls_back_to_property_then_unresolved(self):
        node = Node(title="N")
        node.set_property("local_only", "from-node")
        state = GraphState()
        # Not stored in the state -> node.properties is consulted.
        assert state.get_node_property(node, "local_only") == "from-node"
        # Not set anywhere -> the UNRESOLVED sentinel.
        assert state.get_node_property(node, "missing") is UNRESOLVED

    def test_node_socket_value_round_trip(self):
        node, state = Node(title="N"), GraphState()
        state.set_node_socket_value(node, "out", 99)
        assert state.get_node_socket_value(node, "out") == 99

    def test_node_socket_value_default_unresolved(self):
        node, state = Node(title="N"), GraphState()
        assert state.get_node_socket_value(node, "no-such-socket") is UNRESOLVED

    def test_node_socket_state_round_trip(self):
        node, state = Node(title="N"), GraphState()
        # Socket state defaults to False ...
        assert state.get_node_socket_state(node, "x") is False
        # ... and reflects whatever is written.
        state.set_node_socket_state(node, "x", True)
        assert state.get_node_socket_state(node, "x") is True
class TestSaveContext:
    """SaveContext sets the ``save_state`` context variable while active."""

    def test_save_state_within_context(self):
        def assert_unset():
            with pytest.raises(LookupError):
                save_state.get()

        # No value outside the context ...
        assert_unset()
        with SaveContext():
            assert save_state.get() is True
        # ... and it is reset again on exit.
        assert_unset()
# ---------------------------------------------------------------------------
# Graph: connect, set_node_references, ensure_connections, model_dump
# ---------------------------------------------------------------------------
class TestGraphWiring:
    """Graph.connect, reference rebuilding and socket lookup helpers."""

    def test_connect_with_socket_string_ids(self):
        graph = Graph()
        source = ValueNode(value=1, title="A")
        sink = CaptureNode(title="B")
        graph.add_node(source)
        graph.add_node(sink)
        # Passing socket ids (str) exercises the str -> Socket lookup branch.
        graph.connect(source.outputs[0].id, sink.inputs[0].id)
        out_socket = source.outputs[0]
        in_socket = sink.inputs[0]
        assert in_socket.source.id == out_socket.id
        edge_key = f"{source.id}.{out_socket.name}"
        assert (
            edge_key in graph.edges
            and f"{sink.id}.{in_socket.name}" in graph.edges[edge_key]
        )

    def test_connect_with_invalid_string_id_logs_and_returns(self):
        """An unknown socket id makes the lookup raise KeyError.

        NOTE(review): the test name says "logs_and_returns", but the asserted
        behavior is a raised KeyError.
        """
        graph = Graph()
        node = ValueNode(value=1)
        graph.add_node(node)
        with pytest.raises(KeyError):
            graph.connect("no-such-socket-id", node.outputs[0].id)

    def test_connect_dedup_avoids_duplicate_edge(self):
        graph, source, sink = build_value_capture_graph(1)
        # Reconnecting an existing edge must not duplicate the target entry.
        graph.connect(source.outputs[0], sink.inputs[0])
        targets = graph.edges[f"{source.id}.value"]
        assert targets.count(f"{sink.id}.value") == 1

    def test_set_node_references_populates_full_id_lookup(self):
        """set_node_references re-keys self.sockets by full_id and assigns
        socket.node back to its parent node."""
        graph = Graph()
        source, sink = ValueNode(value=1), CaptureNode()
        graph.add_node(source)
        graph.add_node(sink)
        # Drop back-references, as if the graph were freshly deserialized.
        all_sockets = source.inputs + source.outputs + sink.inputs + sink.outputs
        for socket in all_sockets:
            socket.node = None

        assert graph.set_node_references() is graph  # chainable
        # Sockets are keyed by full_id ...
        assert f"{source.id}.value" in graph.sockets
        assert f"{sink.id}.value" in graph.sockets
        # ... and each socket points back at its owner again.
        assert source.outputs[0].node is source
        assert sink.inputs[0].node is sink

    def test_reinitialize_restores_source_references_via_ensure_connections(self):
        """`reinitialize()` rebuilds source pointers, socket lookups, and
        ensures connections."""
        graph = Graph()
        source, sink = ValueNode(value=1), CaptureNode()
        graph.add_node(source)
        graph.add_node(sink)
        graph.connect(source.outputs[0], sink.inputs[0])
        # Forget the source pointer to mimic a re-loaded graph.
        sink.inputs[0].source = None

        graph.reinitialize()

        # ensure_connections rewires the input from the edges dict.
        restored = sink.inputs[0].source
        assert restored is not None
        assert restored.full_id == f"{source.id}.value"

    def test_ensure_connections_warns_for_missing_input_socket(self, caplog):
        graph = Graph()
        source, sink = ValueNode(value=1), CaptureNode()
        graph.add_node(source)
        graph.add_node(sink)
        # Point an edge at an input socket that `sink` does not have.
        graph.edges[f"{source.id}.value"] = [f"{sink.id}.does_not_exist"]
        graph.set_node_references()
        # Expected to warn (not asserted here) and carry on instead of raising.
        graph.ensure_connections()
        assert sink.inputs[0].source is None  # nothing was connected

    def test_node_lookup_helper(self):
        graph, source, sink = build_value_capture_graph(1)
        for node in (source, sink):
            assert graph.node(node.id) is node

    def test_get_input_socket_returns_none_when_missing(self):
        node = Node(title="N")
        node.add_input("a")
        assert node.get_input_socket("a") is not None
        assert node.get_input_socket("missing") is None

    def test_get_output_socket_returns_none_when_missing(self):
        node = Node(title="N")
        node.add_output("a")
        assert node.get_output_socket("a") is not None
        assert node.get_output_socket("missing") is None
# ---------------------------------------------------------------------------
# Graph: build, execute, callback wiring
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_graph_execute_propagates_value():
    """Executing ValueNode -> CaptureNode delivers the emitted value."""
    graph, _source, sink = build_value_capture_graph("hello")
    await graph.execute()
    await cleanup_pending_tasks()
    assert sink.captured == ["hello"]
@pytest.mark.asyncio
async def test_graph_execute_runs_user_callbacks_after_state_callbacks():
    """Callbacks attached to the graph fire before per-call callbacks."""
    graph, _, sink = build_value_capture_graph("x")
    calls = []

    async def graph_callback(state):
        calls.append("graph-callback")

    async def user_callback(state):
        calls.append("user-callback")

    graph.callbacks.append(graph_callback)
    await graph.execute(callbacks=[user_callback])
    await cleanup_pending_tasks()

    assert sink.captured == ["x"]
    assert calls == ["graph-callback", "user-callback"]
@pytest.mark.asyncio
async def test_graph_execute_with_state_values():
    """state_values must be merged into state.data so nodes can read them."""
    seen = {}

    class ReadDataNode(Node):
        def setup(self):
            self.add_input("trigger")

        async def run(self, state):
            seen["preset"] = state.data.get("preset")

    # Graph.execute() walks weakly connected components, so a lone node would
    # be skipped. Hang the reader off an Entry node instead.
    entry = Entry()
    reader = ReadDataNode()
    graph = Graph()
    graph.add_node(entry)
    graph.add_node(reader)
    graph.connect(entry.outputs[0], reader.inputs[0])

    await graph.execute(state_values={"preset": "hello"})
    await cleanup_pending_tasks()
    assert seen["preset"] == "hello"
@pytest.mark.asyncio
async def test_graph_cycle_raises():
    """Building a graph with a cycle should raise on execute()."""
    graph = Graph()
    node_a, node_b = Node(title="A"), Node(title="B")
    a_out, a_in = node_a.add_output("o"), node_a.add_input("i")
    b_out, b_in = node_b.add_output("o"), node_b.add_input("i")
    graph.add_node(node_a)
    graph.add_node(node_b)
    # A -> B plus B -> A closes a loop.
    graph.connect(a_out, b_in)
    graph.connect(b_out, a_in)
    with pytest.raises(ValueError, match="cycles"):
        await graph.execute()
@pytest.mark.asyncio
async def test_graph_execute_to_node_raises_when_node_missing():
    """Targeting a node that was never added to the graph is an error."""
    graph, *_ = build_value_capture_graph("x")
    stranger = Node(title="Detached")
    with pytest.raises(ValueError, match="not found in graph"):
        await graph.execute_to_node(stranger)
@pytest.mark.asyncio
async def test_graph_execute_to_node_runs_partial():
    """execute_to_node should only run the ancestors of the target."""

    class TrackNode(ValueNode):
        ran: bool = False

        async def run(self, state):
            self.ran = True
            await super().run(state)

    graph = Graph()
    upstream = TrackNode(value=1, title="A")
    target = CaptureNode(title="B")
    detached = TrackNode(value=2, title="C-extra")
    for node in (upstream, target, detached):
        graph.add_node(node)
    graph.connect(upstream.outputs[0], target.inputs[0])

    # `detached` has no path to `target`, so it must NOT run.
    await graph.execute_to_node(target)
    await cleanup_pending_tasks()
    assert upstream.ran is True
    assert detached.ran is False
@pytest.mark.asyncio
async def test_get_nodes_with_filter():
    """get_nodes returns only the nodes accepted by the predicate."""
    graph, source, _sink = build_value_capture_graph("x")

    def is_value_node(node):
        return isinstance(node, ValueNode)

    assert await graph.get_nodes(is_value_node) == [source]
@pytest.mark.asyncio
async def test_get_node_unique_raises_on_duplicates():
    """get_node insists on a single match and rejects ambiguity."""
    graph = Graph()
    for emitted, title in ((1, "A"), (2, "B")):
        graph.add_node(ValueNode(value=emitted, title=title))
    with pytest.raises(ValueError, match="Multiple nodes"):
        await graph.get_node(lambda node: isinstance(node, ValueNode))
@pytest.mark.asyncio
async def test_get_node_returns_none_when_no_match():
    """get_node yields None when the predicate rejects every node."""
    graph, _, _ = build_value_capture_graph("x")
    found = await graph.get_node(lambda node: False)
    assert found is None
# ---------------------------------------------------------------------------
# Graph: get_nodes_connected_to (uses get_ancestors_with_forks)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_nodes_connected_to_includes_forks():
    """Ancestors of D include the fork A -> C, not just the A -> B -> D path."""
    graph = Graph()
    node_a, node_b, node_c, node_d = (Node(title=letter) for letter in "ABCD")
    a_to_b = node_a.add_output("o1")
    a_to_c = node_a.add_output("o2")
    b_in = node_b.add_input("i")
    b_out = node_b.add_output("o")
    c_in = node_c.add_input("i")
    d_in = node_d.add_input("i")
    for node in (node_a, node_b, node_c, node_d):
        graph.add_node(node)
    graph.connect(a_to_b, b_in)
    graph.connect(b_out, d_in)
    graph.connect(a_to_c, c_in)

    ancestors = await graph.get_nodes_connected_to(node_d)
    # A and B are on the path to D; C forks off the shared ancestor A.
    assert sorted(node.title for node in ancestors) == ["A", "B", "C"]
# ---------------------------------------------------------------------------
# Graph serialization
# ---------------------------------------------------------------------------
class TestGraphSerialization:
    """Graph.model_dump output, SaveContext filtering and cloning."""

    def test_serialize_nodes_filters_by_node_serialization_fields(self):
        graph, src, sink = build_value_capture_graph(1)
        data = graph.model_dump()
        # Guard against a vacuous pass: both nodes must actually be dumped
        # (the dump keys nodes by node id).
        assert set(data["nodes"]) == {src.id, sink.id}
        allowed = graph._node_serialization_fields
        for node_data in data["nodes"].values():
            # Only fields in _node_serialization_fields should appear
            assert set(node_data.keys()).issubset(allowed)

    def test_save_state_drops_inherited_nodes_and_edges(self):
        # Build a graph where one node is marked inherited
        graph = Graph()
        keep = ValueNode(value=1, title="Keep")
        drop = ValueNode(value=2, title="Drop")
        sink = CaptureNode(title="Sink")
        drop.inherited = True
        graph.add_node(keep)
        graph.add_node(drop)
        graph.add_node(sink)
        graph.connect(keep.outputs[0], sink.inputs[0])
        # Edge from inherited node — must be dropped
        graph.connect(drop.outputs[0], sink.inputs[0])
        # One kept and one inherited group/comment each
        graph.groups.append(Group(title="kept group"))
        graph.groups.append(Group(title="inherited group", inherited=True))
        graph.comments.append(Comment(text="kept comment"))
        graph.comments.append(Comment(text="inherited comment", inherited=True))

        # Without SaveContext: full data is returned
        full = graph.model_dump()
        assert drop.id in full["nodes"]
        assert any(c["text"] == "inherited comment" for c in full["comments"])

        # With SaveContext: inherited stuff is filtered
        with SaveContext():
            saved = graph.model_dump()
        assert drop.id not in saved["nodes"]
        assert keep.id in saved["nodes"]
        # Edges referencing the dropped node are also gone
        assert all(drop.id not in edge_key for edge_key in saved["edges"])
        for input_ids in saved["edges"].values():
            for input_id in input_ids:
                assert drop.id not in input_id
        # Inherited groups and comments are filtered, but non-inherited ones
        # must survive. Checking only the negative case would pass vacuously
        # if the filter dropped everything.
        group_titles = [group["title"] for group in saved["groups"]]
        comment_texts = [comment["text"] for comment in saved["comments"]]
        assert "inherited group" not in group_titles
        assert "kept group" in group_titles
        assert "inherited comment" not in comment_texts
        assert "kept comment" in comment_texts

    @pytest.mark.asyncio
    async def test_clone_yields_independent_graph(self):
        # clone() reconstructs a graph from its JSON dump, so each node must
        # be a registered class that validate_node can locate (hence the
        # built-in Route nodes here).
        graph = Graph()
        first = Route()
        second = Route()
        graph.add_node(first)
        graph.add_node(second)
        graph.connect(first.outputs[0], second.inputs[0])

        clone = await graph.clone()

        # Same nodes
        assert set(clone.nodes.keys()) == set(graph.nodes.keys())
        # But independent: editing the clone doesn't touch the original
        clone.title = "renamed"
        assert graph.title != "renamed"
# ---------------------------------------------------------------------------
# Graph: input/output node mapping (Input / Output / ModuleProperty)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_graph_routes_input_node_value_from_outer_state():
    """Inner graph reads value of an Input node from the outer state."""

    class OuterEntry(Node):
        def setup(self):
            self.add_output("payload")

        async def run(self, state):
            self.set_output_values({"payload": "from-outside"})

    # Inner graph: Input("payload") -> Sink
    inner = Graph()
    input_node = Input()
    input_node.set_property("input_name", "payload")
    inner.add_node(input_node)
    sink = CaptureNode(title="Sink")
    inner.add_node(sink)
    inner.connect(input_node.outputs[0], sink.inputs[0])
    # Discard any cached `_inputs` so that `inner.inputs` is recomputed and
    # picks up the Input node added above.
    if hasattr(inner, "_inputs"):
        delattr(inner, "_inputs")

    # Outer graph: producer -> the inner graph's first input socket
    outer = Graph()
    producer = OuterEntry(title="Outer")
    outer.add_node(producer)
    outer.add_node(inner)
    outer.connect(producer.outputs[0], inner.inputs[0])

    await outer.execute()
    await cleanup_pending_tasks()
    assert sink.captured == ["from-outside"]
@pytest.mark.asyncio
async def test_graph_routes_output_node_value_to_outer_state():
    """An Output node inside an inner graph propagates its value out."""

    class OuterCapture(Node):
        captured_values: list = None

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.captured_values = []

        def setup(self):
            self.add_input("value")

        async def run(self, state):
            self.captured_values.append(self.get_input_value("value"))

    # Inner graph: src -> Output("result")
    inner = Graph()
    producer = ValueNode(value="inner-value", title="src")
    output_node = Output()
    output_node.set_property("output_name", "result")
    for node in (producer, output_node):
        inner.add_node(node)
    inner.connect(producer.outputs[0], output_node.inputs[0])
    # Discard any cached `_outputs` so that `inner.outputs` is recomputed and
    # picks up the Output node added above.
    if hasattr(inner, "_outputs"):
        delattr(inner, "_outputs")

    # Outer graph: the inner graph's first output -> collector
    outer = Graph()
    collector = OuterCapture()
    outer.add_node(inner)
    outer.add_node(collector)
    outer.connect(inner.outputs[0], collector.inputs[0])

    await outer.execute()
    await cleanup_pending_tasks()
    assert collector.captured_values == ["inner-value"]
# ---------------------------------------------------------------------------
# ModuleProperty — cast_value branches
# ---------------------------------------------------------------------------
class TestModulePropertyCastValue:
    """ModuleProperty.cast_value converts raw values according to the node's
    property_type and substitutes the default for UNRESOLVED."""

    def _make(self, prop_type: str, default: Any = UNRESOLVED) -> ModuleProperty:
        """Build a ModuleProperty configured with ``prop_type`` and ``default``."""
        prop = ModuleProperty()
        for key, val in (("property_type", prop_type), ("default", default)):
            prop.set_property(key, val)
        return prop

    def test_unresolved_uses_default(self):
        prop = self._make("str", default="fallback")
        assert prop.cast_value(UNRESOLVED) == "fallback"

    def test_str_cast(self):
        prop = self._make("str")
        assert prop.cast_value(123) == "123"

    def test_text_cast(self):
        # "text" converts the same way "str" does.
        prop = self._make("text")
        assert prop.cast_value(123) == "123"

    def test_bool_true_strings(self):
        prop = self._make("bool")
        for truthy in ("true", "yes", "1", "TRUE", "Yes"):
            assert prop.cast_value(truthy) is True

    def test_bool_false_strings(self):
        prop = self._make("bool")
        for falsy in ("false", "no", "0", "FALSE", "No"):
            assert prop.cast_value(falsy) is False

    def test_bool_other_strings_use_python_bool(self):
        # A non-empty string outside the recognised keywords is truthy.
        prop = self._make("bool")
        assert prop.cast_value("anything-else") is True

    def test_bool_non_string_uses_bool(self):
        for raw, expected in ((0, False), (1, True), ([1], True)):
            assert self._make("bool").cast_value(raw) is expected

    def test_int_cast(self):
        prop = self._make("int")
        assert prop.cast_value("7") == 7

    def test_float_cast(self):
        prop = self._make("float")
        assert prop.cast_value("1.5") == pytest.approx(1.5)

    def test_unknown_type_falls_back_to_str(self):
        # An unrecognised property_type reaches the final str() fallback.
        prop = self._make("custom-type")
        assert prop.cast_value(99) == "99"
def test_graph_module_properties_aggregated_from_nodes():
    """Graph.module_properties collects every ModuleProperty node in the
    graph, keyed by property_name and carrying its property_type.

    Nothing here is awaited, so this is a plain synchronous test (the
    asyncio marker and ``async def`` were unnecessary).
    """

    def make_property(name, prop_type, default, num):
        # Build a fully configured ModuleProperty node.
        node = ModuleProperty()
        node.set_property("property_name", name)
        node.set_property("property_type", prop_type)
        node.set_property("default", default)
        node.set_property("num", num)
        node.set_property("choices", [])
        return node

    g = Graph()
    alpha = make_property("alpha", "str", "x", 0)
    beta = make_property("beta", "int", 5, 1)
    g.add_node(alpha)
    g.add_node(beta)

    props = g.module_properties
    assert set(props.keys()) == {"alpha", "beta"}
    assert props["alpha"].type == "str"
    assert props["beta"].type == "int"
def test_graph_style_returns_module_style_when_present():
    """A ModuleStyle node in the graph is exposed through Graph.style."""
    graph = Graph()
    styler = ModuleStyle()
    for prop_name, prop_value in (
        ("node_color", "#ff0000"),
        ("title_color", "#00ff00"),
        ("auto_title", ""),
        ("icon", ""),
    ):
        styler.set_property(prop_name, prop_value)
    graph.add_node(styler)

    resolved = graph.style
    assert resolved is not None
    assert resolved.node_color == "#ff0000"
    assert resolved.title_color == "#00ff00"
def test_graph_style_returns_none_when_absent():
    """Without a ModuleStyle node, Graph.style is None."""
    graph = Graph()
    graph.add_node(ValueNode(value=1))
    assert graph.style is None
# ---------------------------------------------------------------------------
# Graph stage priority ordering
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_stage_priority_orders_chains():
    """A chain rooted at a lower stage number runs before one rooted at a
    higher number, even when the higher one was added first."""
    execution_log = []

    class OrderedNode(Node):
        def __init__(self, label, **kwargs):
            super().__init__(title=label, **kwargs)
            self._label = label

        def setup(self):
            self.add_input("trigger", optional=True)
            self.add_output("done")

        async def run(self, state):
            execution_log.append(self._label)
            self.set_output_values({"done": True})

    graph = Graph()

    def add_chain(stage_number, label):
        # Wire Stage(stage_number) -> OrderedNode(label) into the graph.
        stage = Stage()
        stage.set_property("stage", stage_number)
        worker = OrderedNode(label)
        graph.add_node(stage)
        graph.add_node(worker)
        graph.connect(stage.outputs[0], worker.inputs[0])

    # Inserted first, but belongs to the later stage.
    add_chain(5, "late")
    add_chain(1, "early")

    await graph.execute()
    await cleanup_pending_tasks()

    assert execution_log == ["early", "late"]
# ---------------------------------------------------------------------------
# Stage exit handling inside a chain
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_stage_exit_breaks_current_chain_only():
    """A StageExit in one chain stops that chain but the next still runs."""
    runs = []

    class ChainOneA(Node):
        def setup(self):
            self.add_output("done")

        async def run(self, state):
            runs.append("chain1-a")

    class StageExitNode(Node):
        def setup(self):
            self.add_input("trigger", optional=True)
            self.add_output("done")

        async def run(self, state):
            raise StageExit()

    class ChainTwo(Node):
        def setup(self):
            self.add_output("done")

        async def run(self, state):
            runs.append("chain2")

    g = Graph()
    s1 = Stage()
    s1.set_property("stage", 0)
    a = ChainOneA(title="a")
    boom = StageExitNode(title="boom")
    s2 = Stage()
    s2.set_property("stage", 1)
    c = ChainTwo(title="c")
    g.add_node(s1)
    g.add_node(a)
    g.add_node(boom)
    g.add_node(s2)
    g.add_node(c)

    # Chain 1: s1 -> a -> boom (boom raises StageExit, halting chain 1).
    # ChainOneA declares no input in setup(), so its trigger socket is
    # added here before wiring.
    a_in = a.add_input("trigger")
    g.connect(s1.outputs[0], a_in)
    g.connect(a.outputs[0], boom.inputs[0])

    # Chain 2: s2 -> c (ChainTwo likewise gets its trigger socket here).
    c_in = c.add_input("trigger")
    g.connect(s2.outputs[0], c_in)

    await g.execute()
    await cleanup_pending_tasks()

    # chain1-a ran and chain2 ran: the StageExit raised inside boom ended
    # chain 1 only, leaving chain 2 unaffected.
    assert "chain1-a" in runs
    assert "chain2" in runs
@pytest.mark.asyncio
async def test_stop_graph_execution_halts_entire_graph():
    """StopGraphExecution raised by a node ends execution silently:
    Graph.execute returns normally and nothing after the stopper runs."""
    executed = []

    class StopperNode(Node):
        def setup(self):
            self.add_output("done")

        async def run(self, state):
            executed.append("stop")
            raise StopGraphExecution()

    class NeverRunsInChain1(Node):
        def setup(self):
            self.add_input("trigger")

        async def run(self, state):
            executed.append("never1")

    class ShouldNotRunChain2(Node):
        def setup(self):
            self.add_output("done")

        async def run(self, state):
            executed.append("never2")

    graph = Graph()
    stopper = StopperNode(title="stop")
    downstream = NeverRunsInChain1(title="after")
    # NOTE(review): this node is never connected. The Trigger tests in this
    # file note that check_is_available treats unconnected sockets as
    # deactivated, so it may be skipped regardless of StopGraphExecution.
    # Consider wiring it into its own chain so the "entire graph" claim is
    # actually exercised.
    unrelated = ShouldNotRunChain2(title="other")
    for node in (stopper, downstream, unrelated):
        graph.add_node(node)
    graph.connect(stopper.outputs[0], downstream.inputs[0])

    await graph.execute()
    await cleanup_pending_tasks()

    assert executed == ["stop"]
# ---------------------------------------------------------------------------
# Error handler nodes (catch / attempt_catch_with_node_error_handler)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_unhandled_exception_calls_error_handlers_list():
    """An exception not handled at node level is passed to each callable in
    Graph.error_handlers and then re-raised from Graph.execute."""
    seen = []

    class BoomNode(Node):
        def setup(self):
            self.add_input("trigger")

        async def run(self, state):
            raise RuntimeError("boom-boom")

    async def record(state, exc):
        seen.append((type(exc).__name__, str(exc)))

    graph = Graph()
    start = Entry()
    failing = BoomNode(title="boom")
    graph.add_node(start)
    graph.add_node(failing)
    graph.connect(start.outputs[0], failing.inputs[0])
    graph.error_handlers.append(record)

    with pytest.raises(RuntimeError, match="boom-boom"):
        await graph.execute()
    await cleanup_pending_tasks()

    assert seen == [("RuntimeError", "boom-boom")]
@pytest.mark.asyncio
async def test_handle_error_swallows_handler_exceptions():
    """A misbehaving handler must not mask the original exception.

    The handler records its invocation before raising. That proves it was
    actually called; otherwise the pytest.raises check below would pass
    trivially even if error handlers were never run.
    """
    invoked = []

    class BoomNode(Node):
        def setup(self):
            self.add_input("trigger")

        async def run(self, state):
            raise RuntimeError("original")

    async def bad_handler(state, exc):
        invoked.append(str(exc))
        raise ValueError("handler-died")

    g = Graph()
    entry = Entry()
    boom = BoomNode(title="boom")
    g.add_node(entry)
    g.add_node(boom)
    g.connect(entry.outputs[0], boom.inputs[0])
    g.error_handlers.append(bad_handler)

    # The original exception still propagates, not the handler's ValueError.
    with pytest.raises(RuntimeError, match="original"):
        await g.execute()
    await cleanup_pending_tasks()

    # The handler did run, with the original exception.
    assert invoked == ["original"]
# ---------------------------------------------------------------------------
# Graph.reset / reset_sockets / reset_ephemeral_properties
# ---------------------------------------------------------------------------
def test_reset_clears_socket_values_and_deactivation():
    """Graph.reset_sockets returns output sockets to UNRESOLVED and clears
    their deactivated flag."""
    graph, source, _sink = build_value_capture_graph(1)
    with GraphContext():
        source.outputs[0].value = 42
        source.outputs[0].deactivated = True
        assert source.outputs[0].value == 42

        graph.reset_sockets()

        assert source.outputs[0].value is UNRESOLVED
        assert source.outputs[0].deactivated is False
def test_reset_ephemeral_property_resets_to_default():
    """reset_ephemeral_properties restores a property declared with
    ephemeral=True to its declared default."""

    class EphemeralPropNode(Node):
        class Fields:
            cache = PropertyField(
                name="cache",
                description="ephemeral cache",
                type="str",
                default="default-val",
                ephemeral=True,
            )

    node = EphemeralPropNode()
    node.set_property("cache", "current-value")
    graph = Graph()
    graph.add_node(node)

    graph.reset_ephemeral_properties()

    assert node.get_property("cache") == "default-val"
# ---------------------------------------------------------------------------
# Loop execution paths
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_loop_continue_skips_rest_of_iteration_chain():
    """A LoopContinue raised from a node abandons the current iteration but
    not the loop, so the body runs again on the next iteration.

    The body raises LoopExit on its second run, which makes the loop
    terminate deterministically after exactly two iterations.
    """
    from talemate.game.engine.nodes.core import LoopContinue

    iteration_count = {"n": 0}

    class IncrementThenContinueNode(Node):
        def setup(self):
            self.add_input("trigger")
            self.add_output("done")

        async def run(self, state):
            iteration_count["n"] += 1
            # Stop iterating once we've run twice.
            if iteration_count["n"] >= 2:
                raise LoopExit()
            raise LoopContinue()

    entry = Entry(title="entry")
    body = IncrementThenContinueNode(title="body")
    loop = Loop()
    loop.add_node(entry)
    loop.add_node(body)
    loop.connect(entry.outputs[0], body.inputs[0])

    outer_entry = Entry(title="outer-entry")
    outer = Graph()
    outer.add_node(outer_entry)
    outer.add_node(loop)
    outer.connect(outer_entry.outputs[0], loop.inputs[0])

    callback_seen = []

    async def cb(state):
        callback_seen.append(True)

    loop.callbacks.append(cb)

    await outer.execute()
    await cleanup_pending_tasks()

    assert iteration_count["n"] == 2
    # The loop's finally block fires the callbacks exactly once.
    assert callback_seen == [True]
@pytest.mark.asyncio
async def test_loop_exit_terminates_immediately_without_callbacks_after_break():
    """LoopExit leaves the loop entirely rather than just ending the iteration.

    Despite the test name, registered loop callbacks still fire exactly once,
    from the loop's finally block.
    """
    exits = []

    class ExitNode(Node):
        def setup(self):
            self.add_input("trigger")

        async def run(self, state):
            exits.append("exit")
            raise LoopExit()

    entry = Entry()
    exit_node = ExitNode(title="exitnode")
    loop = Loop()
    loop.add_node(entry)
    loop.add_node(exit_node)
    loop.connect(entry.outputs[0], exit_node.inputs[0])

    outer_entry = Entry(title="outer-entry")
    outer = Graph()
    outer.add_node(outer_entry)
    outer.add_node(loop)
    outer.connect(outer_entry.outputs[0], loop.inputs[0])

    callback_hits = []

    async def on_done(state):
        callback_hits.append(True)

    loop.callbacks.append(on_done)

    await outer.execute()
    await cleanup_pending_tasks()

    assert exits == ["exit"]
    assert callback_hits == [True]
@pytest.mark.asyncio
async def test_loop_exit_condition_terminates_loop():
    """exit_condition is checked after each node's run; the loop stops as
    soon as it returns True."""
    counter = {"n": 0}

    class IncNode(Node):
        def setup(self):
            self.add_input("trigger")
            self.add_output("done")

        async def run(self, state):
            counter["n"] += 1

    def reached_three(state):
        return counter["n"] >= 3

    entry = Entry()
    inc = IncNode(title="inc")
    loop = Loop(exit_condition=reached_three)
    loop.add_node(entry)
    loop.add_node(inc)
    loop.connect(entry.outputs[0], inc.inputs[0])

    outer_entry = Entry(title="outer-entry")
    outer = Graph()
    outer.add_node(outer_entry)
    outer.add_node(loop)
    outer.connect(outer_entry.outputs[0], loop.inputs[0])

    await outer.execute()
    await cleanup_pending_tasks()

    assert counter["n"] == 3
@pytest.mark.asyncio
async def test_loop_cycle_raises():
    """Loop.execute rejects a body whose connections form a cycle."""
    loop = Loop()
    first = Node(title="A")
    second = Node(title="B")
    first_out = first.add_output("o")
    first_in = first.add_input("i")
    second_out = second.add_output("o")
    second_in = second.add_input("i")
    loop.add_node(first)
    loop.add_node(second)
    # A -> B -> A
    loop.connect(first_out, second_in)
    loop.connect(second_out, first_in)

    # outer_state is not really exercised when execution never gets past
    # the cycle check; a bare GraphState is enough.
    with pytest.raises(ValueError, match="cycles"):
        await loop.execute(outer_state=GraphState())
# ---------------------------------------------------------------------------
# Listen / Trigger event nodes
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_listen_execute_from_event_runs_inner_graph():
    """Listen.execute_from_event stores the event in state.data['event'] and
    runs the Listen node's body."""
    seen = {}

    class CaptureEventNode(Node):
        def setup(self):
            # Endpoint node (no outputs), so it is always available.
            self.add_input("trigger", optional=True)

        async def run(self, state):
            seen["event"] = state.data.get("event")

    listen = Listen(title="listen")
    listen.set_property("event_name", "my_event")
    start = Entry()
    catcher = CaptureEventNode()
    listen.add_node(start)
    listen.add_node(catcher)
    listen.connect(start.outputs[0], catcher.inputs[0])

    marker = object()
    # execute_from_event needs either an active GraphContext or a scene
    # with nodegraph_state; an active GraphContext is used here.
    with GraphContext():
        await listen.execute_from_event(marker)

    assert seen["event"] is marker
@pytest.mark.asyncio
async def test_listen_execute_from_event_failsafe_skips_recent_failure():
    """If Listen recently failed, the next call within ~1.3s is skipped."""
    import time

    listen = Listen(title="listen")
    listen.set_property("event_name", "evt")
    # Mark the listener as having failed just now.
    listen._failed = time.time()

    # The call is silently skipped and the failure marker is cleared.
    with GraphContext():
        await listen.execute_from_event("ignored")

    assert listen._failed is None
@pytest.mark.asyncio
async def test_listen_execute_from_event_outside_state_returns_silently():
    """With neither an active GraphContext nor an active scene that has a
    nodegraph_state, execute_from_event logs and returns None instead of
    raising."""
    listener = Listen(title="listen")
    listener.set_property("event_name", "evt")

    outcome = await listener.execute_from_event("payload")

    assert outcome is None
@pytest.mark.asyncio
async def test_trigger_run_emits_signal():
    """Trigger.run looks up the signal named by event_name, sends the object
    from make_event_object through it, and exposes the same object on its
    `event` output."""

    class MyTrigger(Trigger):
        def make_event_object(self, state):
            return {"hello": "world"}

    import talemate.emit.async_signals as async_signals

    signal_name = "test_trigger_signal__core"
    async_signals.register(signal_name)
    delivered = []

    async def listener(event):
        delivered.append(event)

    async_signals.get(signal_name).connect(listener)
    try:
        trigger = MyTrigger()
        trigger.set_property("event_name", signal_name)

        # Run through a real graph so state is set up properly. Both the
        # `trigger` input and the `event` output are wired, because
        # check_is_available treats unconnected sockets as deactivated.
        graph = Graph()
        start = Entry()
        sink = CaptureNode()
        graph.add_node(start)
        graph.add_node(trigger)
        graph.add_node(sink)
        graph.connect(start.outputs[0], trigger.get_input_socket("trigger"))
        graph.connect(trigger.get_output_socket("event"), sink.inputs[0])

        await graph.execute()
        await cleanup_pending_tasks()

        assert delivered == [{"hello": "world"}]
        # The Trigger's output carries the same event object.
        assert sink.captured == [{"hello": "world"}]
    finally:
        async_signals.get(signal_name).disconnect(listener)
@pytest.mark.asyncio
async def test_trigger_run_with_no_event_name_logs_and_returns():
    """With an empty event_name, Trigger.run returns early and emits nothing."""

    class MyTrigger(Trigger):
        def make_event_object(self, state):
            return None

    trigger = MyTrigger()
    trigger.set_property("event_name", "")  # deliberately empty
    trigger.set_property("trigger", "x")

    graph = Graph()
    start = Entry()
    sink = CaptureNode()
    graph.add_node(start)
    graph.add_node(trigger)
    graph.add_node(sink)
    graph.connect(start.outputs[0], trigger.get_input_socket("trigger"))
    graph.connect(trigger.get_output_socket("event"), sink.inputs[0])

    # Must not raise.
    await graph.execute()
    await cleanup_pending_tasks()

    # Nothing reached the event output.
    assert sink.captured in ([], [UNRESOLVED])
@pytest.mark.asyncio
async def test_trigger_run_with_unknown_signal_returns_silently():
    """An event_name that names no registered signal makes Trigger.run log
    and return rather than raise."""

    class MyTrigger(Trigger):
        def make_event_object(self, state):
            return None

    trigger = MyTrigger()
    trigger.set_property("event_name", "this_signal_does_not_exist")
    trigger.set_property("trigger", "x")

    graph = Graph()
    start = Entry()
    sink = CaptureNode()
    graph.add_node(start)
    graph.add_node(trigger)
    graph.add_node(sink)
    graph.connect(start.outputs[0], trigger.get_input_socket("trigger"))
    graph.connect(trigger.get_output_socket("event"), sink.inputs[0])

    # Passing means nothing was raised.
    await graph.execute()
    await cleanup_pending_tasks()
def test_trigger_make_event_object_default_raises():
    """The base Trigger leaves make_event_object to subclasses."""
    base_trigger = Trigger()
    with pytest.raises(NotImplementedError):
        base_trigger.make_event_object(None)
# ---------------------------------------------------------------------------
# validate_node WrapValidator
# ---------------------------------------------------------------------------
def _identity_handler(v):
return v
def test_validate_node_returns_existing_nodebase():
    """An already-constructed node passes through validate_node unchanged."""
    node = Node(title="N")
    fake_info = type("I", (), {})()
    assert validate_node(node, _identity_handler, fake_info) is node
def test_validate_node_raises_on_unrecognised_dict():
    """A dict that does not describe any known node is rejected."""
    fake_info = type("I", (), {})()
    bogus_payload = {"foo": "bar"}
    with pytest.raises(ValueError, match="Could not validate"):
        validate_node(bogus_payload, _identity_handler, fake_info)
# ---------------------------------------------------------------------------
# check_is_available — additional deactivation paths
# ---------------------------------------------------------------------------
def test_check_is_available_returns_false_when_required_input_missing():
    """An unresolved required input makes the node unavailable and
    deactivates its outputs."""
    node = Node(title="N")
    node.add_input("req")  # no source and no property -> UNRESOLVED
    node.add_output("o")
    graph = Graph()
    graph.add_node(node)

    with GraphContext() as state:
        assert node.check_is_available(state) is False
        assert node.outputs[0].deactivated is True
def test_check_is_available_endpoint_node_only_needs_inputs():
    """A node with inputs but no outputs (an endpoint) is available once its
    required inputs are satisfied."""
    endpoint = Node(title="endpoint")
    endpoint.add_input("v")
    endpoint.set_property("v", "x")  # a property value satisfies the input
    graph = Graph()
    graph.add_node(endpoint)

    with GraphContext() as state:
        assert endpoint.check_is_available(state) is True
def test_check_is_available_grouped_inputs_one_satisfied_is_enough():
    """Grouped inputs: at least one must be available — none being available
    deactivates the node.

    First check: neither grouped input is set, so the node is unavailable.
    Second check: `a` gets a property value and the output gets a downstream
    consumer, so the node is available.
    """
    n = Node(title="grouped")
    n.add_input("a", group="g1")
    n.add_input("b", group="g1")
    n.add_output("o")
    g = Graph()
    g.add_node(n)
    # Neither set -> unavailable
    with GraphContext() as state:
        assert n.check_is_available(state) is False
    # Property on `a` -> available. Also clear the deactivated flag, which
    # the failed check above presumably left on the output.
    n.set_property("a", "x")
    n.outputs[0].deactivated = False
    with GraphContext() as state:
        # Need a non-deactivated downstream output to satisfy the path check.
        # Add a downstream consumer.
        downstream = Node(title="downstream")
        downstream.add_input("v")
        g.add_node(downstream)
        g.connect(n.outputs[0], downstream.inputs[0])
        assert n.check_is_available(state) is True
def test_check_is_available_isolated_node_returns_false():
    """Nodes flagged _isolated report themselves unavailable, which opts
    them out of normal dispatch."""

    class IsolatedThing(Node):
        _isolated: ClassVar[bool] = True

        def setup(self):
            self.add_output("v")

    isolated = IsolatedThing()
    graph = Graph()
    graph.add_node(isolated)

    with GraphContext() as state:
        assert isolated.check_is_available(state) is False
# ---------------------------------------------------------------------------
# Stage node default behaviour
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_stage_node_unconnected_state_input_defaults_true():
    """A Stage whose `state` input is left unconnected emits True."""
    received = []

    class Recorder(Node):
        def setup(self):
            self.add_input("v")

        async def run(self, state):
            received.append(self.get_input_value("v"))

    graph = Graph()
    stage = Stage()
    stage.set_property("stage", 0)
    recorder = Recorder(title="rec")
    graph.add_node(stage)
    graph.add_node(recorder)
    graph.connect(stage.outputs[0], recorder.inputs[0])

    await graph.execute()
    await cleanup_pending_tasks()

    assert received == [True]
# ---------------------------------------------------------------------------
# Watch / Route / Null / Entry pass-through
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_watch_passes_value_through():
    """A Watch node forwards its input unchanged to its output."""
    graph = Graph()
    source = ValueNode(value="watched", title="src")
    watcher = Watch(title="watcher")
    sink = CaptureNode(title="sink")
    for node in (source, watcher, sink):
        graph.add_node(node)
    graph.connect(source.outputs[0], watcher.inputs[0])
    graph.connect(watcher.outputs[0], sink.inputs[0])

    await graph.execute()
    await cleanup_pending_tasks()

    assert sink.captured == ["watched"]
@pytest.mark.asyncio
async def test_route_passes_value_through():
    """A Route node forwards its input unchanged to its output."""
    graph = Graph()
    source = ValueNode(value=999, title="src")
    router = Route()
    sink = CaptureNode(title="sink")
    for node in (source, router, sink):
        graph.add_node(node)
    graph.connect(source.outputs[0], router.inputs[0])
    graph.connect(router.outputs[0], sink.inputs[0])

    await graph.execute()
    await cleanup_pending_tasks()

    assert sink.captured == [999]
# ---------------------------------------------------------------------------
# Group / Comment
# ---------------------------------------------------------------------------
def test_group_and_comment_defaults():
    """Group and Comment keep the given title/text and are not inherited by
    default."""
    group = Group(title="hello")
    comment = Comment(text="note")

    assert group.title == "hello"
    assert group.inherited is False

    assert comment.text == "note"
    assert comment.inherited is False
# ---------------------------------------------------------------------------
# Creative-mode node-state tracking
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_node_state_push_pop_returns_node_state_in_creative_mode():
    """With state.shared['creative_mode'] set, push/pop build and return
    NodeState snapshots. Flushing of the stack itself is debounced by
    signal_note_state and is not checked here."""

    class TouchNode(Node):
        def setup(self):
            self.add_input("trigger", optional=True)

        async def run(self, state):
            pass

    graph = Graph()
    node = TouchNode(title="touch", id="touch-id")
    graph.add_node(node)
    creative_state = GraphState()
    creative_state.shared["creative_mode"] = True

    pushed = await graph.node_state_push(node, creative_state)
    assert pushed is not None
    assert pushed.node_id == "touch-id"

    popped = await graph.node_state_pop(pushed, node, creative_state)
    assert popped is not None
    assert popped.end_time is not None  # pop always stamps end_time
    assert popped.node_id == "touch-id"

    # An error string passed to pop is recorded on the snapshot.
    failed = await graph.node_state_pop(
        pushed, node, creative_state, error="boom-trace"
    )
    assert failed.error == "boom-trace"
@pytest.mark.asyncio
async def test_node_state_push_pop_noop_outside_creative_mode():
    """Without creative_mode in state.shared, push and pop return None and
    leave the stack empty."""
    graph = Graph()
    node = Node(title="N")
    plain_state = GraphState()  # shared starts empty: creative_mode unset

    snapshot = await graph.node_state_push(node, plain_state)
    assert snapshot is None

    closed = await graph.node_state_pop(snapshot, node, plain_state)
    assert closed is None
    assert plain_state.stack == []
@pytest.mark.asyncio
async def test_node_state_push_inactive_marks_node_state_deactivated():
    """Pushing with inactive=True in creative mode yields a snapshot marked
    deactivated."""
    node = Node(title="N")
    graph = Graph()
    graph.add_node(node)
    creative_state = GraphState()
    creative_state.shared["creative_mode"] = True

    snapshot = await graph.node_state_push(node, creative_state, inactive=True)

    assert snapshot.deactivated is True
# ---------------------------------------------------------------------------
# Loop on_loop_start / on_loop_end / on_loop_error subclass hooks
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_loop_subclass_lifecycle_hooks_invoked():
    """A Loop subclass's on_loop_start hook fires when iteration begins.

    The body raises LoopExit on the first iteration, and LoopExit returns
    from inside the node loop before on_loop_end is reached. The expected
    event log is therefore just ["start"]: this test pins that on_loop_end
    is skipped on exit, not a full start/end ordering.
    """
    events = []

    class LifecycleLoop(Loop):
        async def on_loop_start(self, state):
            events.append("start")

        async def on_loop_end(self, state):
            events.append("end")

    class StopAfterFirst(Node):
        ran: bool = False

        def setup(self):
            self.add_input("trigger")

        async def run(self, state):
            self.ran = True
            raise LoopExit()

    entry = Entry()
    body = StopAfterFirst(title="body")
    loop = LifecycleLoop()
    loop.add_node(entry)
    loop.add_node(body)
    loop.connect(entry.outputs[0], body.inputs[0])

    outer = Graph()
    outer_entry = Entry(title="outer")
    outer.add_node(outer_entry)
    outer.add_node(loop)
    outer.connect(outer_entry.outputs[0], loop.inputs[0])

    await outer.execute()
    await cleanup_pending_tasks()

    # on_loop_start fires before any chains. LoopExit returns inside the
    # node loop, so on_loop_end does NOT fire on that iteration.
    assert events == ["start"]
    assert body.ran is True
@pytest.mark.asyncio
async def test_loop_on_loop_error_invoked_on_exception():
    """When a body raises a non-control exception, the loop's error_handlers
    and on_loop_error hook both fire, and the loop keeps iterating until
    LoopExit."""
    import asyncio

    handled = []
    on_loop_errors = []
    iteration = {"n": 0}

    class ErrorThenExit(Node):
        def setup(self):
            self.add_input("trigger")

        async def run(self, state):
            iteration["n"] += 1
            if iteration["n"] == 1:
                raise RuntimeError("first-iteration-fail")
            raise LoopExit()

    class TrackingLoop(Loop):
        # Zero delay in case Loop reads a `sleep` setting after an error;
        # asyncio.sleep is additionally patched below.
        sleep: float = 0.0

        async def on_loop_error(self, state, exc):
            on_loop_errors.append(type(exc).__name__)

    async def handler(state, exc):
        handled.append(type(exc).__name__)

    entry = Entry()
    body = ErrorThenExit(title="body")
    loop = TrackingLoop()
    # Redundant with the class default unless the Loop constructor
    # overrides it; kept as a guard.
    loop.sleep = 0.0
    loop.error_handlers.append(handler)
    loop.add_node(entry)
    loop.add_node(body)
    loop.connect(entry.outputs[0], body.inputs[0])

    outer_entry = Entry(title="outer")
    outer = Graph()
    outer.add_node(outer_entry)
    outer.add_node(loop)
    outer.connect(outer_entry.outputs[0], loop.inputs[0])

    real_sleep = asyncio.sleep

    async def fast_sleep(delay, result=None):
        # Mirror asyncio.sleep's signature, but only yield once instead of
        # waiting out the post-error delay.
        return await real_sleep(0, result)

    # MonkeyPatch.context() restores asyncio.sleep when the block exits,
    # even if execute() raises. That replaces the hand-rolled
    # save/assign/restore of a module global.
    with pytest.MonkeyPatch.context() as patcher:
        patcher.setattr(asyncio, "sleep", fast_sleep)
        await outer.execute()
    await cleanup_pending_tasks()

    assert handled == ["RuntimeError"]
    assert on_loop_errors == ["RuntimeError"]
    assert iteration["n"] == 2  # second iteration ran and raised LoopExit
# ---------------------------------------------------------------------------
# Verbose state mode
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_graph_executes_with_verbose_state():
    """Graph execution completes normally when a graph callback switches
    the state to NodeVerbosity.VERBOSE.

    NOTE(review): the intent is to exercise the verbose-only log paths.
    Loop callbacks in this file fire from the loop's finally block. If
    Graph callbacks likewise run after the node chains, verbosity is only
    raised after execution and those paths are never reached. Confirm this,
    and set verbosity before execution if that coverage matters.
    """
    g, src, sink = build_value_capture_graph("v")

    async def set_verbose(state):
        state.verbosity = NodeVerbosity.VERBOSE

    g.callbacks.append(set_verbose)

    await g.execute()
    await cleanup_pending_tasks()

    assert sink.captured == ["v"]
# ---------------------------------------------------------------------------
# Loop used as a sub-node (Loop.run delegates to Loop.execute via outer Graph.execute)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_loop_run_method_calls_execute():
    """Loop is itself a Node. When it is nested in an outer graph,
    Graph._execute_inner calls Loop.run, which delegates to Loop.execute."""
    hits = []

    class StopOnce(Node):
        def setup(self):
            self.add_input("trigger")

        async def run(self, state):
            hits.append(1)
            raise LoopExit()

    inner_entry = Entry()
    body = StopOnce(title="body")
    loop = Loop()
    for node in (inner_entry, body):
        loop.add_node(node)
    loop.connect(inner_entry.outputs[0], body.inputs[0])

    outer = Graph()
    outer_entry = Entry(title="outer-entry")
    outer.add_node(outer_entry)
    outer.add_node(loop)
    outer.connect(outer_entry.outputs[0], loop.inputs[0])

    await outer.execute()
    await cleanup_pending_tasks()

    assert hits == [1]
# ---------------------------------------------------------------------------
# Trigger.after hook called after signal send
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_trigger_after_hook_called_with_event():
    """Trigger.after runs with the event object after the signal is
    dispatched."""
    after_events = []

    class MyTrigger(Trigger):
        def make_event_object(self, state):
            return "evt"

        async def after(self, state, event):
            after_events.append(event)

    import talemate.emit.async_signals as async_signals

    signal_name = "test_trigger_after_signal"
    async_signals.register(signal_name)

    trigger = MyTrigger()
    trigger.set_property("event_name", signal_name)
    trigger.set_property("trigger", "x")

    graph = Graph()
    start = Entry()
    sink = CaptureNode()
    graph.add_node(start)
    graph.add_node(trigger)
    graph.add_node(sink)
    graph.connect(start.outputs[0], trigger.get_input_socket("trigger"))
    graph.connect(trigger.get_output_socket("event"), sink.inputs[0])

    await graph.execute()
    await cleanup_pending_tasks()

    assert after_events == ["evt"]