diff --git a/src/talemate/agents/visual/nodes.py b/src/talemate/agents/visual/nodes.py index fb0ad226..7eaff8e6 100644 --- a/src/talemate/agents/visual/nodes.py +++ b/src/talemate/agents/visual/nodes.py @@ -287,6 +287,10 @@ class UnpackPrompt(Node): self.add_output("instructions", socket_type="str") self.add_output("positive_prompt", socket_type="str") self.add_output("negative_prompt", socket_type="str") + self.add_output("positive_prompt_keywords", socket_type="str") + self.add_output("negative_prompt_keywords", socket_type="str") + self.add_output("positive_prompt_descriptive", socket_type="str") + self.add_output("negative_prompt_descriptive", socket_type="str") async def run(self, state: GraphState): prompt = self.normalized_input_value("prompt") @@ -299,6 +303,10 @@ class UnpackPrompt(Node): "parts": parts, "positive_prompt": prompt.positive_prompt, "negative_prompt": prompt.negative_prompt, + "positive_prompt_keywords": prompt.positive_prompt_keywords, + "negative_prompt_keywords": prompt.negative_prompt_keywords, + "positive_prompt_descriptive": prompt.positive_prompt_descriptive, + "negative_prompt_descriptive": prompt.negative_prompt_descriptive, } ) diff --git a/src/talemate/agents/visual/schema.py b/src/talemate/agents/visual/schema.py index 9923a370..f878ddb8 100644 --- a/src/talemate/agents/visual/schema.py +++ b/src/talemate/agents/visual/schema.py @@ -133,34 +133,42 @@ class VisualPrompt(pydantic.BaseModel): @pydantic.computed_field @property def positive_prompt(self) -> str: - prompt: list[str] = [] - if self.prompt_type == PROMPT_TYPE.KEYWORDS: - for part in self.parts: - prompt.extend(part.positive_keywords) - return ", ".join(dict.fromkeys(prompt)) - elif self.prompt_type == PROMPT_TYPE.DESCRIPTIVE: - for part in self.parts: - if part.positive_descriptive: - prompt.append(part.positive_descriptive) - return "\n\n".join(prompt) - return "" + return self._build_prompt(self.prompt_type, True) @pydantic.computed_field @property def negative_prompt(self) -> str: + return self._build_prompt(self.prompt_type, False) + + @property + def positive_prompt_keywords(self) -> str: + return self._build_prompt(PROMPT_TYPE.KEYWORDS, True) + + @property + def negative_prompt_keywords(self) -> str: + return self._build_prompt(PROMPT_TYPE.KEYWORDS, False) + + @property + def positive_prompt_descriptive(self) -> str: + return self._build_prompt(PROMPT_TYPE.DESCRIPTIVE, True) + + @property + def negative_prompt_descriptive(self) -> str: + return self._build_prompt(PROMPT_TYPE.DESCRIPTIVE, False) + + def _build_prompt(self, prompt_type: PROMPT_TYPE, positive: bool) -> str: prompt: list[str] = [] - if self.prompt_type == PROMPT_TYPE.KEYWORDS: + if prompt_type == PROMPT_TYPE.KEYWORDS: for part in self.parts: - prompt.extend(part.negative_keywords) + prompt.extend(part.positive_keywords if positive else part.negative_keywords) return ", ".join(dict.fromkeys(prompt)) - elif self.prompt_type == PROMPT_TYPE.DESCRIPTIVE: + elif prompt_type == PROMPT_TYPE.DESCRIPTIVE: for part in self.parts: - if part.negative_descriptive: - prompt.append(part.negative_descriptive) + if part.positive_descriptive if positive else part.negative_descriptive: + prompt.append(part.positive_descriptive if positive else part.negative_descriptive) return "\n\n".join(prompt) return "" - class BackendStatus(pydantic.BaseModel): type: BackendStatusType message: str | None = None