mirror of
https://github.com/Cinnamon/kotaemon.git
synced 2026-02-23 19:49:37 +01:00
feat: add RAG setting panel
This commit is contained in:
@@ -213,3 +213,8 @@ pdfjs-viewer-element {
|
||||
flex: 1;
|
||||
overflow: auto;
|
||||
}
|
||||
|
||||
#chat-input {
|
||||
resize: vertical;
|
||||
max-height: 200px;
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from ktem.index.file import FileIndex
|
||||
@@ -5,6 +6,8 @@ from ktem.index.file import FileIndex
|
||||
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
|
||||
from .pipelines import KnetIndexingPipeline, KnetRetrievalPipeline
|
||||
|
||||
DEFAULT_RAG_PIPELINE = os.environ.get("KNET_RAG_PIPELINE", "Direct_Query")
|
||||
|
||||
|
||||
class KnowledgeNetworkFileIndex(FileIndex):
|
||||
@classmethod
|
||||
@@ -43,5 +46,11 @@ class KnowledgeNetworkFileIndex(FileIndex):
|
||||
# also set the collection_name for API call
|
||||
obj.VS = None
|
||||
obj.collection_name = f"kh_index_{self.id}"
|
||||
obj.retrieval_expansion = settings.get("rag_settings", {}).get(
|
||||
"retrieval_expansion", False
|
||||
)
|
||||
obj.pipeline_name = settings.get("rag_settings", {}).get(
|
||||
"pipeline", DEFAULT_RAG_PIPELINE
|
||||
)
|
||||
|
||||
return retrievers
|
||||
|
||||
@@ -47,6 +47,8 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
|
||||
|
||||
collection_name: str = "default"
|
||||
rerankers: Sequence[BaseReranking] = [LLMReranking.withx()]
|
||||
retrieval_expansion: bool = False
|
||||
pipeline_name: str | None = None
|
||||
|
||||
def encode_image_base64(self, image_path: str | Path) -> bytes | str:
|
||||
"""Convert image to base64"""
|
||||
@@ -81,7 +83,10 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
|
||||
"collection": self.collection_name,
|
||||
"meta_filters": {"doc_name": doc_ids},
|
||||
"conv_id": conv_id,
|
||||
"assert_query_type": self.pipeline_name,
|
||||
"retrieval_expansion": self.retrieval_expansion,
|
||||
}
|
||||
print(params)
|
||||
params["meta_filters"] = json.dumps(params["meta_filters"])
|
||||
response = requests.get(self.DEFAULT_KNET_ENDPOINT, params=params)
|
||||
metadata_translation = {
|
||||
|
||||
@@ -24,6 +24,7 @@ from .chat_panel import ChatPanel
|
||||
from .chat_suggestion import ChatSuggestion
|
||||
from .common import STATE
|
||||
from .control import ConversationControl
|
||||
from .rag_setting import RAGSetting
|
||||
from .report import ReportIssue
|
||||
|
||||
DEFAULT_SETTING = "(default)"
|
||||
@@ -95,8 +96,11 @@ class ChatPage(BasePage):
|
||||
self._indices_input.append(gr_index)
|
||||
setattr(self, f"_index_{index.id}", index_ui)
|
||||
|
||||
with gr.Accordion(label="Retrieval Settings") as _:
|
||||
self.retrieval_settings = RAGSetting(self._app)
|
||||
|
||||
if len(self._app.index_manager.indices) > 0:
|
||||
with gr.Accordion(label="Quick Upload") as _:
|
||||
with gr.Accordion(label="Quick Upload", open=False) as _:
|
||||
self.quick_file_upload = File(
|
||||
file_types=list(KH_DEFAULT_FILE_EXTRACTORS.keys()),
|
||||
file_count="multiple",
|
||||
@@ -180,6 +184,7 @@ class ChatPage(BasePage):
|
||||
self._reasoning_type,
|
||||
self._llm_type,
|
||||
self.chat_state,
|
||||
self.retrieval_settings.setting_state,
|
||||
self._app.user_id,
|
||||
]
|
||||
+ self._indices_input,
|
||||
@@ -251,6 +256,7 @@ class ChatPage(BasePage):
|
||||
self._reasoning_type,
|
||||
self._llm_type,
|
||||
self.chat_state,
|
||||
self.retrieval_settings.setting_state,
|
||||
self._app.user_id,
|
||||
]
|
||||
+ self._indices_input,
|
||||
@@ -648,6 +654,7 @@ class ChatPage(BasePage):
|
||||
session_reasoning_type: str,
|
||||
session_llm: str,
|
||||
state: dict,
|
||||
rag_state: dict,
|
||||
user_id: int,
|
||||
*selecteds,
|
||||
):
|
||||
@@ -665,6 +672,8 @@ class ChatPage(BasePage):
|
||||
# override reasoning_mode by temporary chat page state
|
||||
print("Session reasoning type", session_reasoning_type)
|
||||
print("Session LLM", session_llm)
|
||||
print("RAG settings", rag_state)
|
||||
|
||||
reasoning_mode = (
|
||||
settings["reasoning.use"]
|
||||
if session_reasoning_type in (DEFAULT_SETTING, None)
|
||||
@@ -675,6 +684,8 @@ class ChatPage(BasePage):
|
||||
reasoning_id = reasoning_cls.get_info()["id"]
|
||||
|
||||
settings = deepcopy(settings)
|
||||
settings["rag_settings"] = rag_state
|
||||
|
||||
llm_setting_key = f"reasoning.options.{reasoning_id}.llm"
|
||||
if llm_setting_key in settings and session_llm not in (DEFAULT_SETTING, None):
|
||||
settings[llm_setting_key] = session_llm
|
||||
@@ -711,6 +722,7 @@ class ChatPage(BasePage):
|
||||
reasoning_type,
|
||||
llm_type,
|
||||
state,
|
||||
rag_state,
|
||||
user_id,
|
||||
*selecteds,
|
||||
):
|
||||
@@ -722,7 +734,7 @@ class ChatPage(BasePage):
|
||||
|
||||
# construct the pipeline
|
||||
pipeline, reasoning_state = self.create_pipeline(
|
||||
settings, reasoning_type, llm_type, state, user_id, *selecteds
|
||||
settings, reasoning_type, llm_type, state, rag_state, user_id, *selecteds
|
||||
)
|
||||
print("Reasoning state", reasoning_state)
|
||||
pipeline.set_output_queue(queue)
|
||||
@@ -785,6 +797,7 @@ class ChatPage(BasePage):
|
||||
reasoning_type,
|
||||
llm_type,
|
||||
state,
|
||||
rag_state,
|
||||
user_id,
|
||||
*selecteds,
|
||||
):
|
||||
@@ -802,6 +815,7 @@ class ChatPage(BasePage):
|
||||
reasoning_type,
|
||||
llm_type,
|
||||
state,
|
||||
rag_state,
|
||||
user_id,
|
||||
*selecteds,
|
||||
):
|
||||
|
||||
@@ -26,6 +26,7 @@ class ChatPanel(BasePage):
|
||||
scale=15,
|
||||
container=False,
|
||||
max_lines=10,
|
||||
elem_id="chat-input",
|
||||
)
|
||||
self.submit_btn = gr.Button(
|
||||
value="Send",
|
||||
|
||||
70
libs/ktem/ktem/pages/chat/rag_setting.py
Normal file
70
libs/ktem/ktem/pages/chat/rag_setting.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
import requests
|
||||
from ktem.app import BasePage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_KNET_ENDPOINT = "http://127.0.0.1:8081"
|
||||
KNET_ENDPOINT = os.environ.get("KN_ENDPOINT", DEFAULT_KNET_ENDPOINT)
|
||||
|
||||
|
||||
class RAGSetting(BasePage):
|
||||
"""Manage RAG settings from KNet"""
|
||||
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
self.setting_state = gr.State(value={})
|
||||
self.on_building_ui()
|
||||
self.on_register_events()
|
||||
|
||||
def get_pipelines(self):
|
||||
"""Retrieve pipeline list from KNet endpoint"""
|
||||
try:
|
||||
response = requests.get(KNET_ENDPOINT + "/query_type", timeout=5)
|
||||
if response.status_code == 200:
|
||||
output = response.json()
|
||||
return [item["name"] for item in output["pipelines"]]
|
||||
else:
|
||||
raise IOError(f"{response.status_code}: {response.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve KNet pipelines: {e}")
|
||||
return []
|
||||
|
||||
def on_building_ui(self):
|
||||
pipline_options = self.get_pipelines()
|
||||
self.pipeline_select = gr.Dropdown(
|
||||
label="Pipeline",
|
||||
choices=pipline_options,
|
||||
value=pipline_options[0] if pipline_options else None,
|
||||
container=False,
|
||||
interactive=True,
|
||||
)
|
||||
self.retrieval_expansion = gr.Checkbox(
|
||||
label="Enable retrieval expansion",
|
||||
value=False,
|
||||
container=False,
|
||||
)
|
||||
|
||||
def store_setting_state(self, pipeline, retrieval_expansion):
|
||||
return {
|
||||
"pipeline": pipeline,
|
||||
"retrieval_expansion": retrieval_expansion,
|
||||
}
|
||||
|
||||
def on_register_events(self):
|
||||
gr.on(
|
||||
triggers=[
|
||||
self.pipeline_select.change,
|
||||
self.retrieval_expansion.change,
|
||||
],
|
||||
fn=self.store_setting_state,
|
||||
inputs=[
|
||||
self.pipeline_select,
|
||||
self.retrieval_expansion,
|
||||
],
|
||||
outputs=self.setting_state,
|
||||
)
|
||||
Reference in New Issue
Block a user