mirror of
https://github.com/vegu-ai/talemate.git
synced 2026-05-18 05:05:39 +02:00
211 lines
6.7 KiB
Python
211 lines
6.7 KiB
Python
|
|
"""
|
||
|
|
Unit tests for template_uid tracking in PromptData.
|
||
|
|
|
||
|
|
Tests that when a Prompt is sent via Prompt.send(), the template_uid is correctly
|
||
|
|
passed through to the PromptData emitted via the prompt_sent signal.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import AsyncMock
|
||
|
|
|
||
|
|
from talemate.client.base import PromptData
|
||
|
|
from talemate.prompts.base import Prompt, active_template_uid
|
||
|
|
|
||
|
|
|
||
|
|
class TestPromptDataTemplateUid:
    """Exercise the optional ``template_uid`` attribute of PromptData."""

    @staticmethod
    def _make(**extra):
        """Build a PromptData from fixed dummy values plus any ``extra`` fields."""
        fields = dict(
            kind="test",
            prompt="test prompt",
            response="test response",
            prompt_tokens=10,
            response_tokens=5,
            client_name="test-client",
            client_type="test",
            time=1.0,
        )
        fields.update(extra)
        return PromptData(**fields)

    def test_prompt_data_has_template_uid_field(self):
        """An explicitly supplied template_uid is stored on the model."""
        record = self._make(template_uid="narrator.narrate-scene")
        assert record.template_uid == "narrator.narrate-scene"

    def test_prompt_data_template_uid_defaults_to_none(self):
        """Omitting template_uid leaves the attribute as None."""
        record = self._make()
        assert record.template_uid is None

    def test_prompt_data_serializes_template_uid(self):
        """model_dump() output carries both the template_uid key and its value."""
        payload = self._make(template_uid="director.guide-scene").model_dump()
        assert "template_uid" in payload
        assert payload["template_uid"] == "director.guide-scene"
|
||
|
|
|
||
|
|
|
||
|
|
class TestActiveTemplateUidContext:
    """Tests for the active_template_uid context variable.

    Every ``set`` is paired with its ``reset`` in a ``finally`` clause. If an
    assertion fails mid-test, the value is still removed. Otherwise it would
    linger in the shared context and break the ``is None`` precondition at the
    start of later tests.
    """

    def test_active_template_uid_defaults_to_none(self):
        """Verify active_template_uid defaults to None."""
        assert active_template_uid.get() is None

    def test_active_template_uid_can_be_set_and_reset(self):
        """Verify context variable can be set and properly reset."""
        assert active_template_uid.get() is None

        token = active_template_uid.set("narrator.test")
        try:
            assert active_template_uid.get() == "narrator.test"
        finally:
            active_template_uid.reset(token)

        assert active_template_uid.get() is None

    def test_active_template_uid_nested_contexts(self):
        """Verify nested set/reset restores each enclosing value in turn."""
        assert active_template_uid.get() is None

        outer_token = active_template_uid.set("outer.template")
        try:
            assert active_template_uid.get() == "outer.template"

            inner_token = active_template_uid.set("inner.template")
            try:
                assert active_template_uid.get() == "inner.template"
            finally:
                active_template_uid.reset(inner_token)

            # Resetting the inner token must restore the outer value.
            assert active_template_uid.get() == "outer.template"
        finally:
            active_template_uid.reset(outer_token)

        assert active_template_uid.get() is None
|
||
|
|
|
||
|
|
|
||
|
|
class TestPromptSendSetsTemplateUid:
    """Tests that Prompt.send() properly sets the active_template_uid context.

    Each test swaps ``client.send_prompt`` for a stub that records the value of
    ``active_template_uid`` it observes. Values are appended to a list instead
    of being stored in a variable pre-initialised to None. An empty list
    therefore means "never called", which keeps tests that expect ``None``
    from passing vacuously.
    """

    @pytest.fixture
    def mock_client(self):
        """Create a mock LLM client."""
        client = AsyncMock()
        client.send_prompt = AsyncMock(return_value="test response")
        client.can_be_coerced = True
        client.data_format = "json"
        return client

    @staticmethod
    def _capture_uid_on(client, error=None):
        """Install a recording ``send_prompt`` stub on ``client``.

        Returns the list that receives ``active_template_uid.get()`` on every
        call. If ``error`` is given, the stub raises it after recording;
        otherwise it returns ``"test response"``.
        """
        captured = []

        async def _send_prompt(*args, **kwargs):
            captured.append(active_template_uid.get())
            if error is not None:
                raise error
            return "test response"

        client.send_prompt = _send_prompt
        return captured

    @pytest.mark.asyncio
    async def test_prompt_send_sets_context_variable(self, mock_client):
        """Verify Prompt.send() sets active_template_uid during client call."""
        captured = self._capture_uid_on(mock_client)

        prompt = Prompt(
            uid="narrator.test-template",
            agent_type="narrator",
            name="test-template",
            template="Test prompt content",
        )

        await prompt.send(mock_client, kind="test")

        assert captured, "client.send_prompt was never called"
        assert all(uid == "narrator.test-template" for uid in captured), captured

    @pytest.mark.asyncio
    async def test_prompt_send_resets_context_after_completion(self, mock_client):
        """Verify the uid is active during send() and reset once it completes."""
        assert active_template_uid.get() is None
        captured = self._capture_uid_on(mock_client)

        prompt = Prompt(
            uid="narrator.test",
            agent_type="narrator",
            name="test",
            template="Test prompt",
        )

        await prompt.send(mock_client, kind="test")

        # The uid must actually have been set during the call; otherwise the
        # reset check below would pass trivially.
        assert captured, "client.send_prompt was never called"
        assert all(uid == "narrator.test" for uid in captured), captured
        # Context should be reset after send completes
        assert active_template_uid.get() is None

    @pytest.mark.asyncio
    async def test_prompt_send_resets_context_on_exception(self, mock_client):
        """Verify context is reset even if client.send_prompt raises."""
        assert active_template_uid.get() is None
        captured = self._capture_uid_on(mock_client, error=Exception("Test error"))

        prompt = Prompt(
            uid="narrator.test",
            agent_type="narrator",
            name="test",
            template="Test prompt",
        )

        with pytest.raises(Exception, match="Test error"):
            await prompt.send(mock_client, kind="test")

        # The uid was active when the client raised ...
        assert captured, "client.send_prompt was never called"
        assert all(uid == "narrator.test" for uid in captured), captured
        # ... and the context should still be reset
        assert active_template_uid.get() is None

    @pytest.mark.asyncio
    async def test_prompt_send_with_empty_uid(self, mock_client):
        """Verify empty uid is handled gracefully."""
        captured = self._capture_uid_on(mock_client)

        prompt = Prompt(
            uid="",
            agent_type="",
            name="",
            template="Test prompt",
        )

        await prompt.send(mock_client, kind="test")

        # Empty string should be converted to None (and the client must
        # actually have been called for that to be meaningful).
        assert captured, "client.send_prompt was never called"
        assert all(uid is None for uid in captured), captured

    @pytest.mark.asyncio
    async def test_prompt_from_text_has_no_uid(self, mock_client):
        """Verify Prompt.from_text() results in None template_uid."""
        captured = self._capture_uid_on(mock_client)

        prompt = Prompt.from_text("Some raw prompt text")

        await prompt.send(mock_client, kind="test")

        # from_text creates a prompt with empty uid
        assert captured, "client.send_prompt was never called"
        assert all(uid is None for uid in captured), captured
|