diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index 9edebcce..63a0c781 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -27,6 +27,7 @@ from .control import ConversationControl from .memory import MemoryPage from .rag_setting import RAGSetting from .report import ReportIssue +from .verification import VerificationPage DEFAULT_SETTING = "(default)" INFO_PANEL_SCALES = {True: 8, False: 4} @@ -102,6 +103,7 @@ class ChatPage(BasePage): # long-term memory page self.long_term_memory_btn = gr.Button("Long-term Memory") + self.verify_answer_btn = gr.Button("Verify Answer") if len(self._app.index_manager.indices) > 0: with gr.Accordion(label="Quick Upload", open=False) as _: @@ -156,6 +158,8 @@ class ChatPage(BasePage): ) as self.memory_accordion: self.memory_page = MemoryPage(self._app) + self.verification_page = VerificationPage(self._app) + with gr.Accordion(label="Information panel", open=True): self.modal = gr.HTML("
") self.plot_panel = gr.Plot(visible=False) @@ -462,6 +466,18 @@ class ChatPage(BasePage): outputs=[self.memory_accordion, self.memory_page.memory_detail_panel], ) + self.verify_answer_btn.click( + self.verification_page.verify_answer, + inputs=[ + self.chat_panel.chatbot, + self.original_retrieval_history, + ], + outputs=[ + self.verification_page.verification_ui, + self.verification_page.verification_result, + ], + ) + # evidence display on message selection self.chat_panel.chatbot.select( self.message_selected, diff --git a/libs/ktem/ktem/pages/chat/verification.py b/libs/ktem/ktem/pages/chat/verification.py new file mode 100644 index 00000000..aea3c470 --- /dev/null +++ b/libs/ktem/ktem/pages/chat/verification.py @@ -0,0 +1,94 @@ +import logging + +import gradio as gr +from ktem.app import BasePage +from ktem.reasoning.prompt_optimization.verify_answer import ( + verify_answer_groundedness_azure, +) +from ktem.reasoning.simple import AnswerWithContextPipeline +from ktem.utils.render import Render + +from kotaemon.base import Document + +logger = logging.getLogger(__name__) + + +class VerificationPage(BasePage): + """Verify the groundedness of the answer""" + + def __init__(self, app): + self._app = app + + self.on_building_ui() + + def on_building_ui(self): + with gr.Accordion( + "Verification Result", + visible=False, + ) as self.verification_ui: + self.verification_result = gr.HTML() + self.close_button = gr.Button( + "Close", + variant="secondary", + ) + + def on_register_events(self): + self.close_button.click( + fn=lambda: gr.update(visible=False), + outputs=[self.verification_ui], + ) + + def highlight_spans(self, text, spans): + spans = sorted(spans, key=lambda x: x["start"]) + highlighted_text = text[: spans[0]["start"]] + for idx, span in enumerate(spans): + to_highlight = text[span["start"] : span["end"]] + highlighted_text += Render.highlight(to_highlight) + if idx < len(spans) - 1: + highlighted_text += text[span["end"] : spans[idx + 1]["start"]] + 
highlighted_text += text[spans[-1]["end"] :] + + return highlighted_text + + def verify_answer(self, chat_history, retrieval_history): + if len(chat_history) < 1: + raise gr.Error("Empty chat.") + + query = chat_history[-1][0] + answer = chat_history[-1][1] + + last_evidence = retrieval_history[-1] + text_only_evidence, _ = AnswerWithContextPipeline.extract_evidence_images( + last_evidence + ) + + gr.Info("Verifying the groundedness of the answer. Please wait...") + result = verify_answer_groundedness_azure(query, answer, [text_only_evidence]) + + verification_output = "
<h3>Trust score: {:.2f}</h3>".format( +        1 - result["ungroundedPercentage"] +    ) +    verification_output += "<h4>Claims that might be incorrect</h4>" +    spans = [ +        { +            "start": claim["offset"]["codePoint"], +            "end": claim["offset"]["codePoint"] + claim["length"]["codePoint"], +        } +        for claim in result["ungroundedDetails"] +    ] +    highlighted_text = self.highlight_spans(answer, spans) +    highlighted_text = highlighted_text.replace("\n", "<br>") +    verification_output += f"<p>{highlighted_text}</p>" + +    verification_output += "<h4>Rationale</h4>" +    print(verification_output) +    rationale = "" + +    for claim in result["ungroundedDetails"]: +        rationale += Render.collapsible_with_header( +            Document(text=claim["reason"], metadata={"file_name": claim["text"]}) +        ) + +    verification_output += f"<div>{rationale}</div>
" + + return gr.update(visible=True), verification_output diff --git a/libs/ktem/ktem/reasoning/prompt_optimization/verify_answer.py b/libs/ktem/ktem/reasoning/prompt_optimization/verify_answer.py new file mode 100644 index 00000000..27f9e8c4 --- /dev/null +++ b/libs/ktem/ktem/reasoning/prompt_optimization/verify_answer.py @@ -0,0 +1,48 @@ +import http.client +import json +from typing import Any + +from decouple import config + +AZURE_CONTENT_SAFETY_ENDPOINT = config("AZURE_CONTENT_SAFETY_ENDPOINT", default="") +AZURE_CONTENT_SAFETY_KEY = config("AZURE_CONTENT_SAFETY_KEY", default="") +AZURE_OPENAI_ENDPOINT_REASONING = config("AZURE_OPENAI_ENDPOINT_REASONING", default="") +AZURE_OPENAI_DEPLOYMENT_NAME_REASONING = config( + "AZURE_OPENAI_DEPLOYMENT_NAME_REASONING", default="" +) + + +def verify_answer_groundedness_azure( + query: str, + answer: str, + docs: list[str], +) -> dict[str, Any]: + conn = http.client.HTTPSConnection(AZURE_CONTENT_SAFETY_ENDPOINT) + payload = json.dumps( + { + "domain": "Generic", + "task": "QnA", + "qna": {"query": query}, + "text": answer, + "groundingSources": docs, + "reasoning": True, + "llmResource": { + "resourceType": "AzureOpenAI", + "azureOpenAIEndpoint": AZURE_OPENAI_ENDPOINT_REASONING, + "azureOpenAIDeploymentName": AZURE_OPENAI_DEPLOYMENT_NAME_REASONING, + }, + } + ) + headers = { + "Ocp-Apim-Subscription-Key": AZURE_CONTENT_SAFETY_KEY, + "Content-Type": "application/json", + } + conn.request( + "POST", + "/contentsafety/text:detectGroundedness?api-version=2024-09-15-preview", + payload, + headers, + ) + res = conn.getresponse() + data = res.read().decode("utf-8") + return json.loads(data) diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 001d7bf1..65e921ed 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -456,9 +456,11 @@ class AnswerWithContextPipeline(BaseComponent): return answer - def extract_evidence_images(self, evidence: str): + 
@staticmethod + def extract_evidence_images(evidence: str): """Util function to extract and isolate images from context/evidence""" - image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'" + image_pattern = r"(data:image\/[^;]+;base64[^']+)" + matches = re.findall(image_pattern, evidence) context = re.sub(image_pattern, "", evidence) print(f"Got {len(matches)} images")