diff --git a/libs/ktem/ktem/index/file/graph/pipelines.py b/libs/ktem/ktem/index/file/graph/pipelines.py
index 30167533..a43ea313 100644
--- a/libs/ktem/ktem/index/file/graph/pipelines.py
+++ b/libs/ktem/ktem/index/file/graph/pipelines.py
@@ -315,6 +315,8 @@ class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
def run(
self,
text: str,
+ *args,
+ **kwargs,
) -> list[RetrievedDocument]:
if not self.file_ids:
return []
diff --git a/libs/ktem/ktem/index/file/knet/pipelines.py b/libs/ktem/ktem/index/file/knet/pipelines.py
index 9286481c..5cb4be8b 100644
--- a/libs/ktem/ktem/index/file/knet/pipelines.py
+++ b/libs/ktem/ktem/index/file/knet/pipelines.py
@@ -70,6 +70,8 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
doc_ids: list of document ids to constraint the retrieval
"""
print("searching in doc_ids", doc_ids)
+ conv_id = kwargs.get("conv_id", "")
+
if not doc_ids:
return []
@@ -78,6 +80,7 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
"query": text,
"collection": self.collection_name,
"meta_filters": {"doc_name": doc_ids},
+ "conv_id": conv_id,
}
params["meta_filters"] = json.dumps(params["meta_filters"])
response = requests.get(self.DEFAULT_KNET_ENDPOINT, params=params)
diff --git a/libs/ktem/ktem/pages/chat/report.py b/libs/ktem/ktem/pages/chat/report.py
index f4047437..d17c44c5 100644
--- a/libs/ktem/ktem/pages/chat/report.py
+++ b/libs/ktem/ktem/pages/chat/report.py
@@ -1,6 +1,9 @@
+import json
+import os
from typing import Optional
import gradio as gr
+import requests
from ktem.app import BasePage
from ktem.db.models import IssueReport, engine
from sqlmodel import Session
@@ -8,6 +11,10 @@ from sqlmodel import Session
class ReportIssue(BasePage):
def __init__(self, app):
+ self.knet_endpoint = (
+ os.environ.get("KN_ENDPOINT", "http://127.0.0.1:8081") + "/feedback"
+ )
+
self._app = app
self.on_building_ui()
@@ -64,13 +71,14 @@ class ReportIssue(BasePage):
else:
print(f"Unknown selector type: {index.selector}")
+ issue_dict = {
+ "correctness": correctness,
+ "issues": issues,
+ "more_detail": more_detail,
+ }
with Session(engine) as session:
issue = IssueReport(
- issues={
- "correctness": correctness,
- "issues": issues,
- "more_detail": more_detail,
- },
+ issues=issue_dict,
chat={
"conv_id": conv_id,
"chat_history": chat_history,
@@ -83,4 +91,18 @@ class ReportIssue(BasePage):
)
session.add(issue)
session.commit()
+
+ # forward feedback to KNet service
+ try:
+ data = {
+ "feedback": json.dumps(issue_dict),
+ "conv_id": conv_id,
+ }
+ print(data)
+ response = requests.post(self.knet_endpoint, data=data)
+ response.raise_for_status()
+ print(response.text)
+ except Exception as e:
+ print("Error submitting Knet feedback:", e)
+
gr.Info("Thank you for your feedback")
diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py
index 9353616c..eee3b44a 100644
--- a/libs/ktem/ktem/reasoning/simple.py
+++ b/libs/ktem/ktem/reasoning/simple.py
@@ -516,7 +516,10 @@ class FullQAPipeline(BaseReasoning):
use_rewrite: bool = False
def retrieve(
- self, message: str, history: list
+ self,
+ message: str,
+ history: list,
+ conv_id: str | None = None,
) -> tuple[list[RetrievedDocument], list[Document]]:
"""Retrieve the documents based on the message"""
# if len(message) < self.trigger_context:
@@ -539,7 +542,7 @@ class FullQAPipeline(BaseReasoning):
for idx, retriever in enumerate(self.retrievers):
retriever_node = self._prepare_child(retriever, f"retriever_{idx}")
- retriever_docs = retriever_node(text=query)
+ retriever_docs = retriever_node(text=query, conv_id=conv_id)
retriever_docs_text = []
retriever_docs_plot = []
@@ -701,7 +704,7 @@ class FullQAPipeline(BaseReasoning):
rewrite = await self.rewrite_pipeline(question=message)
message = rewrite.text
- docs, infos = self.retrieve(message, history)
+ docs, infos = self.retrieve(message, history, conv_id=conv_id)
for _ in infos:
self.report_output(_)
await asyncio.sleep(0.1)
@@ -741,7 +744,7 @@ class FullQAPipeline(BaseReasoning):
print(f"Retrievers {self.retrievers}")
# should populate the context
- docs, infos = self.retrieve(message, history)
+ docs, infos = self.retrieve(message, history, conv_id=conv_id)
print(f"Got {len(docs)} retrieved documents")
yield from infos
@@ -896,7 +899,7 @@ class FullDecomposeQAPipeline(FullQAPipeline):
f"
{message}
Answer
",
)
# should populate the context
- docs, infos = self.retrieve(message, history)
+ docs, infos = self.retrieve(message, history, conv_id=conv_id)
print(f"Got {len(docs)} retrieved documents")
yield from infos
@@ -946,7 +949,7 @@ class FullDecomposeQAPipeline(FullQAPipeline):
)
# should populate the context
- docs, infos = self.retrieve(message, history)
+ docs, infos = self.retrieve(message, history, conv_id=conv_id)
print(f"Got {len(docs)} retrieved documents")
yield from infos