From 7e2ad7af7ada4cf71dedd45260840ac3163a3a29 Mon Sep 17 00:00:00 2001 From: trducng Date: Mon, 18 Nov 2024 03:09:47 +0000 Subject: [PATCH] fix: add lang param in knet retriever --- libs/ktem/ktem/index/file/knet/pipelines.py | 2 ++ libs/ktem/ktem/reasoning/simple.py | 19 +++++++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/libs/ktem/ktem/index/file/knet/pipelines.py b/libs/ktem/ktem/index/file/knet/pipelines.py index 4b060d35..2820d372 100644 --- a/libs/ktem/ktem/index/file/knet/pipelines.py +++ b/libs/ktem/ktem/index/file/knet/pipelines.py @@ -74,6 +74,7 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever): """ print("searching in doc_ids", doc_ids) conv_id = kwargs.get("conv_id", "") + lang = kwargs.get("lang", None) if not doc_ids: return [] @@ -87,6 +88,7 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever): "assert_query_type": self.pipeline_name, "retrieval_expansion": self.retrieval_expansion, "user_id": self.user_id, + "lang": lang, } print(params) params["meta_filters"] = json.dumps(params["meta_filters"]) diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 973c5522..9be854a6 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -523,12 +523,14 @@ class FullQAPipeline(BaseReasoning): add_query_context: AddQueryContextPipeline = AddQueryContextPipeline.withx() trigger_context: int = 150 use_rewrite: bool = False + lang: str | None = None def retrieve( self, message: str, history: list, conv_id: str | None = None, + lang: str | None = None, ) -> tuple[list[RetrievedDocument], list[Document]]: """Retrieve the documents based on the message""" # if len(message) < self.trigger_context: @@ -551,7 +553,11 @@ class FullQAPipeline(BaseReasoning): for idx, retriever in enumerate(self.retrievers): retriever_node = self._prepare_child(retriever, f"retriever_{idx}") - retriever_docs = retriever_node(text=query, conv_id=conv_id) + retriever_docs = retriever_node( + text=query, + conv_id=conv_id, + lang=lang, + ) retriever_docs_text = [] retriever_docs_plot = [] @@ -575,6 +581,7 @@ class FullQAPipeline(BaseReasoning): content=Render.collapsible_with_header(doc, open_collapsible=True), ) for doc in docs + if doc.metadata.get("type") not in FINAL_ANSWER_CONTENT_TYPES ] + [ Document( channel="plot", @@ -762,7 +769,12 @@ class FullQAPipeline(BaseReasoning): print(f"Retrievers {self.retrievers}") # should populate the context - docs, infos = self.retrieve(message, history, conv_id=conv_id) + docs, infos = self.retrieve( + message, + history, + conv_id=conv_id, + lang=self.lang, + ) print(f"Got {len(docs)} retrieved documents") yield from infos @@ -840,6 +852,9 @@ class FullQAPipeline(BaseReasoning): f"{prefix}.n_last_interactions" ] + pipeline.lang = SUPPORTED_LANGUAGE_MAP.get( + settings["reasoning.lang"], "English" + ) pipeline.trigger_context = settings[f"{prefix}.trigger_context"] pipeline.use_rewrite = states.get("app", {}).get("regen", False) if pipeline.rewrite_pipeline: