diff --git a/libs/ktem/ktem/index/file/graph/pipelines.py b/libs/ktem/ktem/index/file/graph/pipelines.py index 30167533..a43ea313 100644 --- a/libs/ktem/ktem/index/file/graph/pipelines.py +++ b/libs/ktem/ktem/index/file/graph/pipelines.py @@ -315,6 +315,8 @@ class GraphRAGRetrieverPipeline(BaseFileIndexRetriever): def run( self, text: str, + *args, + **kwargs, ) -> list[RetrievedDocument]: if not self.file_ids: return [] diff --git a/libs/ktem/ktem/index/file/knet/pipelines.py b/libs/ktem/ktem/index/file/knet/pipelines.py index 9286481c..5cb4be8b 100644 --- a/libs/ktem/ktem/index/file/knet/pipelines.py +++ b/libs/ktem/ktem/index/file/knet/pipelines.py @@ -70,6 +70,8 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever): doc_ids: list of document ids to constraint the retrieval """ print("searching in doc_ids", doc_ids) + conv_id = kwargs.get("conv_id", "") + if not doc_ids: return [] @@ -78,6 +80,7 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever): "query": text, "collection": self.collection_name, "meta_filters": {"doc_name": doc_ids}, + "conv_id": conv_id, } params["meta_filters"] = json.dumps(params["meta_filters"]) response = requests.get(self.DEFAULT_KNET_ENDPOINT, params=params) diff --git a/libs/ktem/ktem/pages/chat/report.py b/libs/ktem/ktem/pages/chat/report.py index f4047437..d17c44c5 100644 --- a/libs/ktem/ktem/pages/chat/report.py +++ b/libs/ktem/ktem/pages/chat/report.py @@ -1,6 +1,9 @@ +import json +import os from typing import Optional import gradio as gr +import requests from ktem.app import BasePage from ktem.db.models import IssueReport, engine from sqlmodel import Session @@ -8,6 +11,10 @@ from sqlmodel import Session class ReportIssue(BasePage): def __init__(self, app): + self.knet_endpoint = ( + os.environ.get("KN_ENDPOINT", "http://127.0.0.1:8081") + "/feedback" + ) + self._app = app self.on_building_ui() @@ -64,13 +71,14 @@ class ReportIssue(BasePage): else: print(f"Unknown selector type: {index.selector}") + issue_dict = { + "correctness": correctness, + "issues": issues, + "more_detail": more_detail, + } with Session(engine) as session: issue = IssueReport( - issues={ - "correctness": correctness, - "issues": issues, - "more_detail": more_detail, - }, + issues=issue_dict, chat={ "conv_id": conv_id, "chat_history": chat_history, @@ -83,4 +91,18 @@ class ReportIssue(BasePage): ) session.add(issue) session.commit() + + # forward feedback to KNet service + try: + data = { + "feedback": json.dumps(issue_dict), + "conv_id": conv_id, + } + print(data) + response = requests.post(self.knet_endpoint, data=data) + response.raise_for_status() + print(response.text) + except Exception as e: + print("Error submitting Knet feedback:", e) + gr.Info("Thank you for your feedback") diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 9353616c..eee3b44a 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -516,7 +516,10 @@ class FullQAPipeline(BaseReasoning): use_rewrite: bool = False def retrieve( - self, message: str, history: list + self, + message: str, + history: list, + conv_id: str | None = None, ) -> tuple[list[RetrievedDocument], list[Document]]: """Retrieve the documents based on the message""" # if len(message) < self.trigger_context: @@ -539,7 +542,7 @@ class FullQAPipeline(BaseReasoning): for idx, retriever in enumerate(self.retrievers): retriever_node = self._prepare_child(retriever, f"retriever_{idx}") - retriever_docs = retriever_node(text=query) + retriever_docs = retriever_node(text=query, conv_id=conv_id) retriever_docs_text = [] retriever_docs_plot = [] @@ -701,7 +704,7 @@ class FullQAPipeline(BaseReasoning): rewrite = await self.rewrite_pipeline(question=message) message = rewrite.text - docs, infos = self.retrieve(message, history) + docs, infos = self.retrieve(message, history, conv_id=conv_id) for _ in infos: self.report_output(_) await asyncio.sleep(0.1) @@ -741,7 +744,7 @@ class FullQAPipeline(BaseReasoning): print(f"Retrievers {self.retrievers}") # should populate the context - docs, infos = self.retrieve(message, history) + docs, infos = self.retrieve(message, history, conv_id=conv_id) print(f"Got {len(docs)} retrieved documents") yield from infos @@ -896,7 +899,7 @@ class FullDecomposeQAPipeline(FullQAPipeline): f"
{message}
Answer
", ) # should populate the context - docs, infos = self.retrieve(message, history) + docs, infos = self.retrieve(message, history, conv_id=conv_id) print(f"Got {len(docs)} retrieved documents") yield from infos @@ -946,7 +949,7 @@ class FullDecomposeQAPipeline(FullQAPipeline): ) # should populate the context - docs, infos = self.retrieve(message, history) + docs, infos = self.retrieve(message, history, conv_id=conv_id) print(f"Got {len(docs)} retrieved documents") yield from infos