mirror of
https://github.com/Cinnamon/kotaemon.git
synced 2026-02-23 19:49:37 +01:00
feat: add knet feedback
This commit is contained in:
@@ -315,6 +315,8 @@ class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
|
||||
def run(
|
||||
self,
|
||||
text: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> list[RetrievedDocument]:
|
||||
if not self.file_ids:
|
||||
return []
|
||||
|
||||
@@ -70,6 +70,8 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
|
||||
doc_ids: list of document ids to constraint the retrieval
|
||||
"""
|
||||
print("searching in doc_ids", doc_ids)
|
||||
conv_id = kwargs.get("conv_id", "")
|
||||
|
||||
if not doc_ids:
|
||||
return []
|
||||
|
||||
@@ -78,6 +80,7 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
|
||||
"query": text,
|
||||
"collection": self.collection_name,
|
||||
"meta_filters": {"doc_name": doc_ids},
|
||||
"conv_id": conv_id,
|
||||
}
|
||||
params["meta_filters"] = json.dumps(params["meta_filters"])
|
||||
response = requests.get(self.DEFAULT_KNET_ENDPOINT, params=params)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import gradio as gr
|
||||
import requests
|
||||
from ktem.app import BasePage
|
||||
from ktem.db.models import IssueReport, engine
|
||||
from sqlmodel import Session
|
||||
@@ -8,6 +11,10 @@ from sqlmodel import Session
|
||||
|
||||
class ReportIssue(BasePage):
|
||||
def __init__(self, app):
|
||||
self.knet_endpoint = (
|
||||
os.environ.get("KN_ENDPOINT", "http://127.0.0.1:8081") + "/feedback"
|
||||
)
|
||||
|
||||
self._app = app
|
||||
self.on_building_ui()
|
||||
|
||||
@@ -64,13 +71,14 @@ class ReportIssue(BasePage):
|
||||
else:
|
||||
print(f"Unknown selector type: {index.selector}")
|
||||
|
||||
issue_dict = {
|
||||
"correctness": correctness,
|
||||
"issues": issues,
|
||||
"more_detail": more_detail,
|
||||
}
|
||||
with Session(engine) as session:
|
||||
issue = IssueReport(
|
||||
issues={
|
||||
"correctness": correctness,
|
||||
"issues": issues,
|
||||
"more_detail": more_detail,
|
||||
},
|
||||
issues=issue_dict,
|
||||
chat={
|
||||
"conv_id": conv_id,
|
||||
"chat_history": chat_history,
|
||||
@@ -83,4 +91,18 @@ class ReportIssue(BasePage):
|
||||
)
|
||||
session.add(issue)
|
||||
session.commit()
|
||||
|
||||
# forward feedback to KNet service
|
||||
try:
|
||||
data = {
|
||||
"feedback": json.dumps(issue_dict),
|
||||
"conv_id": conv_id,
|
||||
}
|
||||
print(data)
|
||||
response = requests.post(self.knet_endpoint, data=data)
|
||||
response.raise_for_status()
|
||||
print(response.text)
|
||||
except Exception as e:
|
||||
print("Error submitting Knet feedback:", e)
|
||||
|
||||
gr.Info("Thank you for your feedback")
|
||||
|
||||
@@ -516,7 +516,10 @@ class FullQAPipeline(BaseReasoning):
|
||||
use_rewrite: bool = False
|
||||
|
||||
def retrieve(
|
||||
self, message: str, history: list
|
||||
self,
|
||||
message: str,
|
||||
history: list,
|
||||
conv_id: str | None = None,
|
||||
) -> tuple[list[RetrievedDocument], list[Document]]:
|
||||
"""Retrieve the documents based on the message"""
|
||||
# if len(message) < self.trigger_context:
|
||||
@@ -539,7 +542,7 @@ class FullQAPipeline(BaseReasoning):
|
||||
|
||||
for idx, retriever in enumerate(self.retrievers):
|
||||
retriever_node = self._prepare_child(retriever, f"retriever_{idx}")
|
||||
retriever_docs = retriever_node(text=query)
|
||||
retriever_docs = retriever_node(text=query, conv_id=conv_id)
|
||||
|
||||
retriever_docs_text = []
|
||||
retriever_docs_plot = []
|
||||
@@ -701,7 +704,7 @@ class FullQAPipeline(BaseReasoning):
|
||||
rewrite = await self.rewrite_pipeline(question=message)
|
||||
message = rewrite.text
|
||||
|
||||
docs, infos = self.retrieve(message, history)
|
||||
docs, infos = self.retrieve(message, history, conv_id=conv_id)
|
||||
for _ in infos:
|
||||
self.report_output(_)
|
||||
await asyncio.sleep(0.1)
|
||||
@@ -741,7 +744,7 @@ class FullQAPipeline(BaseReasoning):
|
||||
|
||||
print(f"Retrievers {self.retrievers}")
|
||||
# should populate the context
|
||||
docs, infos = self.retrieve(message, history)
|
||||
docs, infos = self.retrieve(message, history, conv_id=conv_id)
|
||||
print(f"Got {len(docs)} retrieved documents")
|
||||
yield from infos
|
||||
|
||||
@@ -896,7 +899,7 @@ class FullDecomposeQAPipeline(FullQAPipeline):
|
||||
f"<br>{message}<br><b>Answer</b><br>",
|
||||
)
|
||||
# should populate the context
|
||||
docs, infos = self.retrieve(message, history)
|
||||
docs, infos = self.retrieve(message, history, conv_id=conv_id)
|
||||
print(f"Got {len(docs)} retrieved documents")
|
||||
|
||||
yield from infos
|
||||
@@ -946,7 +949,7 @@ class FullDecomposeQAPipeline(FullQAPipeline):
|
||||
)
|
||||
|
||||
# should populate the context
|
||||
docs, infos = self.retrieve(message, history)
|
||||
docs, infos = self.retrieve(message, history, conv_id=conv_id)
|
||||
print(f"Got {len(docs)} retrieved documents")
|
||||
yield from infos
|
||||
|
||||
|
||||
Reference in New Issue
Block a user