fix: add lang param in knet retriever

This commit is contained in:
trducng
2024-11-18 03:09:47 +00:00
parent e3c4e39db5
commit 7e2ad7af7a
2 changed files with 19 additions and 2 deletions

View File

@@ -74,6 +74,7 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
"""
print("searching in doc_ids", doc_ids)
conv_id = kwargs.get("conv_id", "")
lang = kwargs.get("lang", None)
if not doc_ids:
return []
@@ -87,6 +88,7 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
"assert_query_type": self.pipeline_name,
"retrieval_expansion": self.retrieval_expansion,
"user_id": self.user_id,
"lang": lang,
}
print(params)
params["meta_filters"] = json.dumps(params["meta_filters"])

View File

@@ -523,12 +523,14 @@ class FullQAPipeline(BaseReasoning):
add_query_context: AddQueryContextPipeline = AddQueryContextPipeline.withx()
trigger_context: int = 150
use_rewrite: bool = False
lang: str | None = None
def retrieve(
self,
message: str,
history: list,
conv_id: str | None = None,
lang: str | None = None,
) -> tuple[list[RetrievedDocument], list[Document]]:
"""Retrieve the documents based on the message"""
# if len(message) < self.trigger_context:
@@ -551,7 +553,11 @@ class FullQAPipeline(BaseReasoning):
for idx, retriever in enumerate(self.retrievers):
retriever_node = self._prepare_child(retriever, f"retriever_{idx}")
retriever_docs = retriever_node(text=query, conv_id=conv_id)
retriever_docs = retriever_node(
text=query,
conv_id=conv_id,
lang=lang,
)
retriever_docs_text = []
retriever_docs_plot = []
@@ -575,6 +581,7 @@ class FullQAPipeline(BaseReasoning):
content=Render.collapsible_with_header(doc, open_collapsible=True),
)
for doc in docs
if doc.metadata.get("type") not in FINAL_ANSWER_CONTENT_TYPES
] + [
Document(
channel="plot",
@@ -762,7 +769,12 @@ class FullQAPipeline(BaseReasoning):
print(f"Retrievers {self.retrievers}")
# should populate the context
docs, infos = self.retrieve(message, history, conv_id=conv_id)
docs, infos = self.retrieve(
message,
history,
conv_id=conv_id,
lang=self.lang,
)
print(f"Got {len(docs)} retrieved documents")
yield from infos
@@ -840,6 +852,9 @@ class FullQAPipeline(BaseReasoning):
f"{prefix}.n_last_interactions"
]
pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
settings["reasoning.lang"], "English"
)
pipeline.trigger_context = settings[f"{prefix}.trigger_context"]
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
if pipeline.rewrite_pipeline: