mirror of
https://github.com/Cinnamon/kotaemon.git
synced 2026-02-23 19:49:37 +01:00
feat: add experimental verification API
This commit is contained in:
@@ -27,6 +27,7 @@ from .control import ConversationControl
|
||||
from .memory import MemoryPage
|
||||
from .rag_setting import RAGSetting
|
||||
from .report import ReportIssue
|
||||
from .verification import VerificationPage
|
||||
|
||||
DEFAULT_SETTING = "(default)"
|
||||
INFO_PANEL_SCALES = {True: 8, False: 4}
|
||||
@@ -102,6 +103,7 @@ class ChatPage(BasePage):
|
||||
|
||||
# long-term memory page
|
||||
self.long_term_memory_btn = gr.Button("Long-term Memory")
|
||||
self.verify_answer_btn = gr.Button("Verify Answer")
|
||||
|
||||
if len(self._app.index_manager.indices) > 0:
|
||||
with gr.Accordion(label="Quick Upload", open=False) as _:
|
||||
@@ -156,6 +158,8 @@ class ChatPage(BasePage):
|
||||
) as self.memory_accordion:
|
||||
self.memory_page = MemoryPage(self._app)
|
||||
|
||||
self.verification_page = VerificationPage(self._app)
|
||||
|
||||
with gr.Accordion(label="Information panel", open=True):
|
||||
self.modal = gr.HTML("<div id='pdf-modal'></div>")
|
||||
self.plot_panel = gr.Plot(visible=False)
|
||||
@@ -462,6 +466,18 @@ class ChatPage(BasePage):
|
||||
outputs=[self.memory_accordion, self.memory_page.memory_detail_panel],
|
||||
)
|
||||
|
||||
self.verify_answer_btn.click(
|
||||
self.verification_page.verify_answer,
|
||||
inputs=[
|
||||
self.chat_panel.chatbot,
|
||||
self.original_retrieval_history,
|
||||
],
|
||||
outputs=[
|
||||
self.verification_page.verification_ui,
|
||||
self.verification_page.verification_result,
|
||||
],
|
||||
)
|
||||
|
||||
# evidence display on message selection
|
||||
self.chat_panel.chatbot.select(
|
||||
self.message_selected,
|
||||
|
||||
94
libs/ktem/ktem/pages/chat/verification.py
Normal file
94
libs/ktem/ktem/pages/chat/verification.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import logging
|
||||
|
||||
import gradio as gr
|
||||
from ktem.app import BasePage
|
||||
from ktem.reasoning.prompt_optimization.verify_answer import (
|
||||
verify_answer_groundedness_azure,
|
||||
)
|
||||
from ktem.reasoning.simple import AnswerWithContextPipeline
|
||||
from ktem.utils.render import Render
|
||||
|
||||
from kotaemon.base import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VerificationPage(BasePage):
    """Gradio sub-page that checks how well the latest chat answer is
    grounded in the retrieved evidence (via Azure Content Safety) and
    renders the result as an HTML report."""

    def __init__(self, app):
        self._app = app

        self.on_building_ui()

    def on_building_ui(self):
        """Build the (initially hidden) verification-result accordion."""
        with gr.Accordion(
            "Verification Result",
            visible=False,
        ) as self.verification_ui:
            self.verification_result = gr.HTML()
            self.close_button = gr.Button(
                "Close",
                variant="secondary",
            )

    def on_register_events(self):
        """Hide the verification accordion when the user clicks Close."""
        self.close_button.click(
            fn=lambda: gr.update(visible=False),
            outputs=[self.verification_ui],
        )

    def highlight_spans(self, text, spans):
        """Return *text* with each span wrapped in ``Render.highlight`` markup.

        Args:
            text: the full answer text.
            spans: list of ``{"start": int, "end": int}`` dicts; may be
                unsorted and may be empty.

        Returns:
            The text with the given spans highlighted, or the unmodified
            text when there is nothing to highlight.
        """
        # Bug fix: the original indexed spans[0] unconditionally, so a fully
        # grounded answer (no ungrounded claims -> empty span list) crashed
        # with IndexError instead of showing a clean result.
        if not spans:
            return text

        spans = sorted(spans, key=lambda x: x["start"])
        highlighted_text = text[: spans[0]["start"]]
        for idx, span in enumerate(spans):
            to_highlight = text[span["start"] : span["end"]]
            highlighted_text += Render.highlight(to_highlight)
            # Append the plain text lying between this span and the next one.
            if idx < len(spans) - 1:
                highlighted_text += text[span["end"] : spans[idx + 1]["start"]]
        highlighted_text += text[spans[-1]["end"] :]

        return highlighted_text

    def verify_answer(self, chat_history, retrieval_history):
        """Verify the groundedness of the most recent answer in the chat.

        Args:
            chat_history: list of ``(user_message, bot_answer)`` pairs.
            retrieval_history: list of retrieval evidence strings
                (presumably one entry per chat turn — TODO confirm against
                the ChatPage wiring).

        Returns:
            Tuple of (``gr.update`` making the panel visible, HTML report).

        Raises:
            gr.Error: when there is no chat message, or no retrieval
                evidence to verify against.
        """
        if len(chat_history) < 1:
            raise gr.Error("Empty chat.")
        # Robustness: without this guard, retrieval_history[-1] below raised a
        # bare IndexError; surface a user-readable error instead.
        if len(retrieval_history) < 1:
            raise gr.Error("No retrieval evidence to verify against.")

        query = chat_history[-1][0]
        answer = chat_history[-1][1]

        last_evidence = retrieval_history[-1]
        # Strip inline base64 images: only text can be checked for groundedness.
        text_only_evidence, _ = AnswerWithContextPipeline.extract_evidence_images(
            last_evidence
        )

        gr.Info("Verifying the groundedness of the answer. Please wait...")
        result = verify_answer_groundedness_azure(query, answer, [text_only_evidence])

        # Trust score is the grounded fraction of the answer text.
        verification_output = "<h4>Trust score: {:.2f}</h4>".format(
            1 - result["ungroundedPercentage"]
        )
        verification_output += "<h4>Claims that might be incorrect</h4>"
        # The service reports each ungrounded claim as code-point offsets into
        # the answer; Python str indexing is code-point based, so they map 1:1.
        spans = [
            {
                "start": claim["offset"]["codePoint"],
                "end": claim["offset"]["codePoint"] + claim["length"]["codePoint"],
            }
            for claim in result["ungroundedDetails"]
        ]
        highlighted_text = self.highlight_spans(answer, spans)
        highlighted_text = highlighted_text.replace("\n", "<br>")
        verification_output += f"<div>{highlighted_text}</div>"

        verification_output += "<h4>Rationale</h4>"
        # Bug fix: replaced a stray debug print() with a proper logger call.
        logger.debug("Verification output: %s", verification_output)
        rationale = ""

        for claim in result["ungroundedDetails"]:
            # Reuse the collapsible renderer; the claim text doubles as header.
            rationale += Render.collapsible_with_header(
                Document(text=claim["reason"], metadata={"file_name": claim["text"]})
            )

        verification_output += f"<div><b>{rationale}</b></div>"

        return gr.update(visible=True), verification_output
|
||||
@@ -0,0 +1,48 @@
|
||||
import http.client
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from decouple import config
|
||||
|
||||
# Azure AI Content Safety resource used for the groundedness check.
# NOTE(review): this value is passed straight to http.client.HTTPSConnection,
# so it must be a bare hostname (no "https://" scheme) — confirm in deployment docs.
AZURE_CONTENT_SAFETY_ENDPOINT = config("AZURE_CONTENT_SAFETY_ENDPOINT", default="")
AZURE_CONTENT_SAFETY_KEY = config("AZURE_CONTENT_SAFETY_KEY", default="")
# Azure OpenAI deployment the Content Safety service calls to produce the
# "reasoning" rationale attached to each ungrounded claim.
AZURE_OPENAI_ENDPOINT_REASONING = config("AZURE_OPENAI_ENDPOINT_REASONING", default="")
AZURE_OPENAI_DEPLOYMENT_NAME_REASONING = config(
    "AZURE_OPENAI_DEPLOYMENT_NAME_REASONING", default=""
)
|
||||
|
||||
|
||||
def verify_answer_groundedness_azure(
    query: str,
    answer: str,
    docs: list[str],
) -> dict[str, Any]:
    """Check how well *answer* is grounded in *docs* for the given *query*
    using the Azure Content Safety "detect groundedness" API.

    Args:
        query: the user question the answer responds to.
        answer: the generated answer text to verify.
        docs: grounding source texts the answer should be supported by.

    Returns:
        The parsed JSON response (includes ``ungroundedPercentage`` and
        ``ungroundedDetails``).

    Raises:
        RuntimeError: when the service responds with an HTTP error status.
        TimeoutError / OSError: on network failure or timeout.
    """
    # Bug fix: no timeout meant a stalled service could hang the UI worker
    # forever; 60s is generous for a single groundedness call.
    conn = http.client.HTTPSConnection(AZURE_CONTENT_SAFETY_ENDPOINT, timeout=60)
    payload = json.dumps(
        {
            "domain": "Generic",
            "task": "QnA",
            "qna": {"query": query},
            "text": answer,
            "groundingSources": docs,
            # Ask the service for per-claim rationale via the configured
            # Azure OpenAI reasoning deployment.
            "reasoning": True,
            "llmResource": {
                "resourceType": "AzureOpenAI",
                "azureOpenAIEndpoint": AZURE_OPENAI_ENDPOINT_REASONING,
                "azureOpenAIDeploymentName": AZURE_OPENAI_DEPLOYMENT_NAME_REASONING,
            },
        }
    )
    headers = {
        "Ocp-Apim-Subscription-Key": AZURE_CONTENT_SAFETY_KEY,
        "Content-Type": "application/json",
    }
    # Bug fix: the connection was never closed (socket leak per call).
    try:
        conn.request(
            "POST",
            "/contentsafety/text:detectGroundedness?api-version=2024-09-15-preview",
            payload,
            headers,
        )
        res = conn.getresponse()
        data = res.read().decode("utf-8")
    finally:
        conn.close()

    # Bug fix: error responses were previously json.loads-ed and handed to the
    # caller, which then failed with an opaque KeyError; fail loudly instead.
    if res.status >= 400:
        raise RuntimeError(
            f"Groundedness check failed with HTTP {res.status}: {data}"
        )
    return json.loads(data)
|
||||
@@ -456,9 +456,11 @@ class AnswerWithContextPipeline(BaseComponent):
|
||||
|
||||
return answer
|
||||
|
||||
def extract_evidence_images(self, evidence: str):
|
||||
@staticmethod
|
||||
def extract_evidence_images(evidence: str):
|
||||
"""Util function to extract and isolate images from context/evidence"""
|
||||
image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
|
||||
image_pattern = r"(data:image\/[^;]+;base64[^']+)"
|
||||
|
||||
matches = re.findall(image_pattern, evidence)
|
||||
context = re.sub(image_pattern, "", evidence)
|
||||
print(f"Got {len(matches)} images")
|
||||
|
||||
Reference in New Issue
Block a user