feat: add experimental verification API

This commit is contained in:
trducng
2024-12-09 08:20:00 +00:00
parent 80fea27855
commit 53eea18511
4 changed files with 162 additions and 2 deletions

View File

@@ -27,6 +27,7 @@ from .control import ConversationControl
from .memory import MemoryPage
from .rag_setting import RAGSetting
from .report import ReportIssue
from .verification import VerificationPage
DEFAULT_SETTING = "(default)"
INFO_PANEL_SCALES = {True: 8, False: 4}
@@ -102,6 +103,7 @@ class ChatPage(BasePage):
# long-term memory page
self.long_term_memory_btn = gr.Button("Long-term Memory")
self.verify_answer_btn = gr.Button("Verify Answer")
if len(self._app.index_manager.indices) > 0:
with gr.Accordion(label="Quick Upload", open=False) as _:
@@ -156,6 +158,8 @@ class ChatPage(BasePage):
) as self.memory_accordion:
self.memory_page = MemoryPage(self._app)
self.verification_page = VerificationPage(self._app)
with gr.Accordion(label="Information panel", open=True):
self.modal = gr.HTML("<div id='pdf-modal'></div>")
self.plot_panel = gr.Plot(visible=False)
@@ -462,6 +466,18 @@ class ChatPage(BasePage):
outputs=[self.memory_accordion, self.memory_page.memory_detail_panel],
)
self.verify_answer_btn.click(
self.verification_page.verify_answer,
inputs=[
self.chat_panel.chatbot,
self.original_retrieval_history,
],
outputs=[
self.verification_page.verification_ui,
self.verification_page.verification_result,
],
)
# evidence display on message selection
self.chat_panel.chatbot.select(
self.message_selected,

View File

@@ -0,0 +1,94 @@
import logging
import gradio as gr
from ktem.app import BasePage
from ktem.reasoning.prompt_optimization.verify_answer import (
verify_answer_groundedness_azure,
)
from ktem.reasoning.simple import AnswerWithContextPipeline
from ktem.utils.render import Render
from kotaemon.base import Document
logger = logging.getLogger(__name__)
class VerificationPage(BasePage):
    """Verify the groundedness of an answer against its retrieved evidence.

    Renders an (initially hidden) "Verification Result" accordion and, on
    demand, sends the latest (query, answer, evidence) triple to the Azure
    Content Safety groundedness-detection service, highlighting the claims
    flagged as unsupported by the evidence.
    """

    def __init__(self, app):
        self._app = app
        self.on_building_ui()

    def on_building_ui(self):
        """Build the hidden result accordion with an HTML panel and a Close button."""
        with gr.Accordion(
            "Verification Result",
            visible=False,
        ) as self.verification_ui:
            self.verification_result = gr.HTML()
            self.close_button = gr.Button(
                "Close",
                variant="secondary",
            )

    def on_register_events(self):
        """Wire the Close button to hide the verification accordion."""
        self.close_button.click(
            fn=lambda: gr.update(visible=False),
            outputs=[self.verification_ui],
        )

    def highlight_spans(self, text, spans):
        """Return ``text`` with each ``[start, end)`` span wrapped in a highlight.

        Args:
            text: source string to decorate.
            spans: list of ``{"start": int, "end": int}`` dicts; may be
                unsorted — they are sorted by start offset here.

        Returns:
            The decorated text; ``text`` unchanged when ``spans`` is empty.
            (The previous implementation indexed ``spans[0]`` unconditionally
            and crashed with an IndexError whenever no claim was flagged.)
        """
        if not spans:
            return text
        spans = sorted(spans, key=lambda x: x["start"])
        # text before the first highlighted region
        highlighted_text = text[: spans[0]["start"]]
        for idx, span in enumerate(spans):
            to_highlight = text[span["start"] : span["end"]]
            highlighted_text += Render.highlight(to_highlight)
            if idx < len(spans) - 1:
                # plain text between this span and the next one
                highlighted_text += text[span["end"] : spans[idx + 1]["start"]]
        # text after the last highlighted region
        highlighted_text += text[spans[-1]["end"] :]
        return highlighted_text

    def verify_answer(self, chat_history, retrieval_history):
        """Verify the most recent answer against the most recent evidence.

        Args:
            chat_history: list of ``(user_message, bot_answer)`` pairs.
            retrieval_history: list of evidence strings; the last entry is
                assumed to correspond to the last chat turn.

        Returns:
            Tuple of (``gr.update`` making the panel visible, HTML report).

        Raises:
            gr.Error: when there is no chat turn or no retrieval evidence.
        """
        if not chat_history:
            raise gr.Error("Empty chat.")
        if not retrieval_history:
            raise gr.Error("No retrieval evidence to verify against.")

        query = chat_history[-1][0]
        answer = chat_history[-1][1]
        last_evidence = retrieval_history[-1]
        # strip inline base64 images: the service only takes text evidence
        text_only_evidence, _ = AnswerWithContextPipeline.extract_evidence_images(
            last_evidence
        )
        gr.Info("Verifying the groundedness of the answer. Please wait...")
        result = verify_answer_groundedness_azure(query, answer, [text_only_evidence])

        verification_output = "<h4>Trust score: {:.2f}</h4>".format(
            1 - result["ungroundedPercentage"]
        )
        verification_output += "<h4>Claims that might be incorrect</h4>"
        # offsets are reported in Unicode code points, which matches Python
        # string indexing
        spans = [
            {
                "start": claim["offset"]["codePoint"],
                "end": claim["offset"]["codePoint"] + claim["length"]["codePoint"],
            }
            for claim in result["ungroundedDetails"]
        ]
        highlighted_text = self.highlight_spans(answer, spans)
        highlighted_text = highlighted_text.replace("\n", "<br>")
        verification_output += f"<div>{highlighted_text}</div>"
        verification_output += "<h4>Rationale</h4>"
        # was a stray debug print(); keep the trace but route it to the logger
        logger.debug("verification output: %s", verification_output)
        rationale = ""
        for claim in result["ungroundedDetails"]:
            rationale += Render.collapsible_with_header(
                Document(text=claim["reason"], metadata={"file_name": claim["text"]})
            )
        verification_output += f"<div><b>{rationale}</b></div>"
        return gr.update(visible=True), verification_output

View File

@@ -0,0 +1,48 @@
import http.client
import json
from typing import Any
from decouple import config
# Azure Content Safety resource used for groundedness detection.
# NOTE(review): the endpoint is passed to http.client.HTTPSConnection below,
# so it must be a bare hostname (no "https://" scheme) — confirm in .env.
AZURE_CONTENT_SAFETY_ENDPOINT = config("AZURE_CONTENT_SAFETY_ENDPOINT", default="")
AZURE_CONTENT_SAFETY_KEY = config("AZURE_CONTENT_SAFETY_KEY", default="")
# Azure OpenAI deployment the Content Safety service calls for the
# per-claim reasoning step of the groundedness check.
AZURE_OPENAI_ENDPOINT_REASONING = config("AZURE_OPENAI_ENDPOINT_REASONING", default="")
AZURE_OPENAI_DEPLOYMENT_NAME_REASONING = config(
    "AZURE_OPENAI_DEPLOYMENT_NAME_REASONING", default=""
)
def verify_answer_groundedness_azure(
    query: str,
    answer: str,
    docs: list[str],
) -> dict[str, Any]:
    """Call the Azure Content Safety groundedness-detection API for a QnA pair.

    Args:
        query: the user question.
        answer: the generated answer to verify.
        docs: grounding source texts the answer should be supported by.

    Returns:
        The parsed JSON response body (contains ``ungroundedPercentage``,
        ``ungroundedDetails``, ...).

    Raises:
        RuntimeError: when the service responds with an HTTP error status.
    """
    payload = json.dumps(
        {
            "domain": "Generic",
            "task": "QnA",
            "qna": {"query": query},
            "text": answer,
            "groundingSources": docs,
            # ask the service to include a per-claim rationale in the response
            "reasoning": True,
            "llmResource": {
                "resourceType": "AzureOpenAI",
                "azureOpenAIEndpoint": AZURE_OPENAI_ENDPOINT_REASONING,
                "azureOpenAIDeploymentName": AZURE_OPENAI_DEPLOYMENT_NAME_REASONING,
            },
        }
    )
    headers = {
        "Ocp-Apim-Subscription-Key": AZURE_CONTENT_SAFETY_KEY,
        "Content-Type": "application/json",
    }
    conn = http.client.HTTPSConnection(AZURE_CONTENT_SAFETY_ENDPOINT)
    try:
        conn.request(
            "POST",
            "/contentsafety/text:detectGroundedness?api-version=2024-09-15-preview",
            payload,
            headers,
        )
        res = conn.getresponse()
        data = res.read().decode("utf-8")
        # fail loudly on an error response instead of handing the caller an
        # error body that lacks the expected keys (opaque KeyError later)
        if res.status >= 400:
            raise RuntimeError(
                f"Groundedness detection failed ({res.status}): {data}"
            )
        return json.loads(data)
    finally:
        # the original leaked the connection; always release the socket
        conn.close()

View File

@@ -456,9 +456,11 @@ class AnswerWithContextPipeline(BaseComponent):
return answer
def extract_evidence_images(self, evidence: str):
@staticmethod
def extract_evidence_images(evidence: str):
"""Util function to extract and isolate images from context/evidence"""
image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
image_pattern = r"(data:image\/[^;]+;base64[^']+)"
matches = re.findall(image_pattern, evidence)
context = re.sub(image_pattern, "", evidence)
print(f"Got {len(matches)} images")