mirror of
https://github.com/Cinnamon/kotaemon.git
synced 2026-02-23 19:49:37 +01:00
fix: add lang param in knet retriever
This commit is contained in:
@@ -74,6 +74,7 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
|
||||
"""
|
||||
print("searching in doc_ids", doc_ids)
|
||||
conv_id = kwargs.get("conv_id", "")
|
||||
lang = kwargs.get("lang", None)
|
||||
|
||||
if not doc_ids:
|
||||
return []
|
||||
@@ -87,6 +88,7 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
|
||||
"assert_query_type": self.pipeline_name,
|
||||
"retrieval_expansion": self.retrieval_expansion,
|
||||
"user_id": self.user_id,
|
||||
"lang": lang,
|
||||
}
|
||||
print(params)
|
||||
params["meta_filters"] = json.dumps(params["meta_filters"])
|
||||
|
||||
@@ -523,12 +523,14 @@ class FullQAPipeline(BaseReasoning):
|
||||
add_query_context: AddQueryContextPipeline = AddQueryContextPipeline.withx()
|
||||
trigger_context: int = 150
|
||||
use_rewrite: bool = False
|
||||
lang: str | None = None
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
message: str,
|
||||
history: list,
|
||||
conv_id: str | None = None,
|
||||
lang: str | None = None,
|
||||
) -> tuple[list[RetrievedDocument], list[Document]]:
|
||||
"""Retrieve the documents based on the message"""
|
||||
# if len(message) < self.trigger_context:
|
||||
@@ -551,7 +553,11 @@ class FullQAPipeline(BaseReasoning):
|
||||
|
||||
for idx, retriever in enumerate(self.retrievers):
|
||||
retriever_node = self._prepare_child(retriever, f"retriever_{idx}")
|
||||
retriever_docs = retriever_node(text=query, conv_id=conv_id)
|
||||
retriever_docs = retriever_node(
|
||||
text=query,
|
||||
conv_id=conv_id,
|
||||
lang=lang,
|
||||
)
|
||||
|
||||
retriever_docs_text = []
|
||||
retriever_docs_plot = []
|
||||
@@ -575,6 +581,7 @@ class FullQAPipeline(BaseReasoning):
|
||||
content=Render.collapsible_with_header(doc, open_collapsible=True),
|
||||
)
|
||||
for doc in docs
|
||||
if doc.metadata.get("type") not in FINAL_ANSWER_CONTENT_TYPES
|
||||
] + [
|
||||
Document(
|
||||
channel="plot",
|
||||
@@ -762,7 +769,12 @@ class FullQAPipeline(BaseReasoning):
|
||||
|
||||
print(f"Retrievers {self.retrievers}")
|
||||
# should populate the context
|
||||
docs, infos = self.retrieve(message, history, conv_id=conv_id)
|
||||
docs, infos = self.retrieve(
|
||||
message,
|
||||
history,
|
||||
conv_id=conv_id,
|
||||
lang=self.lang,
|
||||
)
|
||||
print(f"Got {len(docs)} retrieved documents")
|
||||
yield from infos
|
||||
|
||||
@@ -840,6 +852,9 @@ class FullQAPipeline(BaseReasoning):
|
||||
f"{prefix}.n_last_interactions"
|
||||
]
|
||||
|
||||
pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
|
||||
settings["reasoning.lang"], "English"
|
||||
)
|
||||
pipeline.trigger_context = settings[f"{prefix}.trigger_context"]
|
||||
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
|
||||
if pipeline.rewrite_pipeline:
|
||||
|
||||
Reference in New Issue
Block a user