feat: add knet feedback

This commit is contained in:
trducng
2024-10-08 07:52:32 +00:00
parent 485c6a5510
commit 351ec7fd16
4 changed files with 41 additions and 11 deletions

View File

@@ -315,6 +315,8 @@ class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
def run(
self,
text: str,
*args,
**kwargs,
) -> list[RetrievedDocument]:
if not self.file_ids:
return []

View File

@@ -70,6 +70,8 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
doc_ids: list of document ids to constraint the retrieval
"""
print("searching in doc_ids", doc_ids)
conv_id = kwargs.get("conv_id", "")
if not doc_ids:
return []
@@ -78,6 +80,7 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
"query": text,
"collection": self.collection_name,
"meta_filters": {"doc_name": doc_ids},
"conv_id": conv_id,
}
params["meta_filters"] = json.dumps(params["meta_filters"])
response = requests.get(self.DEFAULT_KNET_ENDPOINT, params=params)

View File

@@ -1,6 +1,9 @@
import json
import os
from typing import Optional
import gradio as gr
import requests
from ktem.app import BasePage
from ktem.db.models import IssueReport, engine
from sqlmodel import Session
@@ -8,6 +11,10 @@ from sqlmodel import Session
class ReportIssue(BasePage):
def __init__(self, app):
self.knet_endpoint = (
os.environ.get("KN_ENDPOINT", "http://127.0.0.1:8081") + "/feedback"
)
self._app = app
self.on_building_ui()
@@ -64,13 +71,14 @@ class ReportIssue(BasePage):
else:
print(f"Unknown selector type: {index.selector}")
issue_dict = {
"correctness": correctness,
"issues": issues,
"more_detail": more_detail,
}
with Session(engine) as session:
issue = IssueReport(
issues={
"correctness": correctness,
"issues": issues,
"more_detail": more_detail,
},
issues=issue_dict,
chat={
"conv_id": conv_id,
"chat_history": chat_history,
@@ -83,4 +91,18 @@ class ReportIssue(BasePage):
)
session.add(issue)
session.commit()
# forward feedback to KNet service
try:
data = {
"feedback": json.dumps(issue_dict),
"conv_id": conv_id,
}
print(data)
response = requests.post(self.knet_endpoint, data=data)
response.raise_for_status()
print(response.text)
except Exception as e:
print("Error submitting Knet feedback:", e)
gr.Info("Thank you for your feedback")

View File

@@ -516,7 +516,10 @@ class FullQAPipeline(BaseReasoning):
use_rewrite: bool = False
def retrieve(
self, message: str, history: list
self,
message: str,
history: list,
conv_id: str | None = None,
) -> tuple[list[RetrievedDocument], list[Document]]:
"""Retrieve the documents based on the message"""
# if len(message) < self.trigger_context:
@@ -539,7 +542,7 @@ class FullQAPipeline(BaseReasoning):
for idx, retriever in enumerate(self.retrievers):
retriever_node = self._prepare_child(retriever, f"retriever_{idx}")
retriever_docs = retriever_node(text=query)
retriever_docs = retriever_node(text=query, conv_id=conv_id)
retriever_docs_text = []
retriever_docs_plot = []
@@ -701,7 +704,7 @@ class FullQAPipeline(BaseReasoning):
rewrite = await self.rewrite_pipeline(question=message)
message = rewrite.text
docs, infos = self.retrieve(message, history)
docs, infos = self.retrieve(message, history, conv_id=conv_id)
for _ in infos:
self.report_output(_)
await asyncio.sleep(0.1)
@@ -741,7 +744,7 @@ class FullQAPipeline(BaseReasoning):
print(f"Retrievers {self.retrievers}")
# should populate the context
docs, infos = self.retrieve(message, history)
docs, infos = self.retrieve(message, history, conv_id=conv_id)
print(f"Got {len(docs)} retrieved documents")
yield from infos
@@ -896,7 +899,7 @@ class FullDecomposeQAPipeline(FullQAPipeline):
f"<br>{message}<br><b>Answer</b><br>",
)
# should populate the context
docs, infos = self.retrieve(message, history)
docs, infos = self.retrieve(message, history, conv_id=conv_id)
print(f"Got {len(docs)} retrieved documents")
yield from infos
@@ -946,7 +949,7 @@ class FullDecomposeQAPipeline(FullQAPipeline):
)
# should populate the context
docs, infos = self.retrieve(message, history)
docs, infos = self.retrieve(message, history, conv_id=conv_id)
print(f"Got {len(docs)} retrieved documents")
yield from infos