feat: add RAG setting panel

This commit is contained in:
trducng
2024-10-18 04:58:37 +00:00
parent 1aaa75e5c1
commit d1b0b1e831
6 changed files with 106 additions and 2 deletions

View File

@@ -213,3 +213,8 @@ pdfjs-viewer-element {
flex: 1;
overflow: auto;
}
#chat-input {
  /* Let the user drag-resize the chat textbox vertically, capped at 200px. */
  max-height: 200px;
  resize: vertical;
}

View File

@@ -1,3 +1,4 @@
import os
from typing import Any
from ktem.index.file import FileIndex
@@ -5,6 +6,8 @@ from ktem.index.file import FileIndex
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .pipelines import KnetIndexingPipeline, KnetRetrievalPipeline
DEFAULT_RAG_PIPELINE = os.environ.get("KNET_RAG_PIPELINE", "Direct_Query")
class KnowledgeNetworkFileIndex(FileIndex):
@classmethod
@@ -43,5 +46,11 @@ class KnowledgeNetworkFileIndex(FileIndex):
# also set the collection_name for API call
obj.VS = None
obj.collection_name = f"kh_index_{self.id}"
obj.retrieval_expansion = settings.get("rag_settings", {}).get(
"retrieval_expansion", False
)
obj.pipeline_name = settings.get("rag_settings", {}).get(
"pipeline", DEFAULT_RAG_PIPELINE
)
return retrievers

View File

@@ -47,6 +47,8 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
collection_name: str = "default"
rerankers: Sequence[BaseReranking] = [LLMReranking.withx()]
retrieval_expansion: bool = False
pipeline_name: str | None = None
def encode_image_base64(self, image_path: str | Path) -> bytes | str:
"""Convert image to base64"""
@@ -81,7 +83,10 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever):
"collection": self.collection_name,
"meta_filters": {"doc_name": doc_ids},
"conv_id": conv_id,
"assert_query_type": self.pipeline_name,
"retrieval_expansion": self.retrieval_expansion,
}
print(params)
params["meta_filters"] = json.dumps(params["meta_filters"])
response = requests.get(self.DEFAULT_KNET_ENDPOINT, params=params)
metadata_translation = {

View File

@@ -24,6 +24,7 @@ from .chat_panel import ChatPanel
from .chat_suggestion import ChatSuggestion
from .common import STATE
from .control import ConversationControl
from .rag_setting import RAGSetting
from .report import ReportIssue
DEFAULT_SETTING = "(default)"
@@ -95,8 +96,11 @@ class ChatPage(BasePage):
self._indices_input.append(gr_index)
setattr(self, f"_index_{index.id}", index_ui)
with gr.Accordion(label="Retrieval Settings") as _:
self.retrieval_settings = RAGSetting(self._app)
if len(self._app.index_manager.indices) > 0:
with gr.Accordion(label="Quick Upload") as _:
with gr.Accordion(label="Quick Upload", open=False) as _:
self.quick_file_upload = File(
file_types=list(KH_DEFAULT_FILE_EXTRACTORS.keys()),
file_count="multiple",
@@ -180,6 +184,7 @@ class ChatPage(BasePage):
self._reasoning_type,
self._llm_type,
self.chat_state,
self.retrieval_settings.setting_state,
self._app.user_id,
]
+ self._indices_input,
@@ -251,6 +256,7 @@ class ChatPage(BasePage):
self._reasoning_type,
self._llm_type,
self.chat_state,
self.retrieval_settings.setting_state,
self._app.user_id,
]
+ self._indices_input,
@@ -648,6 +654,7 @@ class ChatPage(BasePage):
session_reasoning_type: str,
session_llm: str,
state: dict,
rag_state: dict,
user_id: int,
*selecteds,
):
@@ -665,6 +672,8 @@ class ChatPage(BasePage):
# override reasoning_mode by temporary chat page state
print("Session reasoning type", session_reasoning_type)
print("Session LLM", session_llm)
print("RAG settings", rag_state)
reasoning_mode = (
settings["reasoning.use"]
if session_reasoning_type in (DEFAULT_SETTING, None)
@@ -675,6 +684,8 @@ class ChatPage(BasePage):
reasoning_id = reasoning_cls.get_info()["id"]
settings = deepcopy(settings)
settings["rag_settings"] = rag_state
llm_setting_key = f"reasoning.options.{reasoning_id}.llm"
if llm_setting_key in settings and session_llm not in (DEFAULT_SETTING, None):
settings[llm_setting_key] = session_llm
@@ -711,6 +722,7 @@ class ChatPage(BasePage):
reasoning_type,
llm_type,
state,
rag_state,
user_id,
*selecteds,
):
@@ -722,7 +734,7 @@ class ChatPage(BasePage):
# construct the pipeline
pipeline, reasoning_state = self.create_pipeline(
settings, reasoning_type, llm_type, state, user_id, *selecteds
settings, reasoning_type, llm_type, state, rag_state, user_id, *selecteds
)
print("Reasoning state", reasoning_state)
pipeline.set_output_queue(queue)
@@ -785,6 +797,7 @@ class ChatPage(BasePage):
reasoning_type,
llm_type,
state,
rag_state,
user_id,
*selecteds,
):
@@ -802,6 +815,7 @@ class ChatPage(BasePage):
reasoning_type,
llm_type,
state,
rag_state,
user_id,
*selecteds,
):

View File

@@ -26,6 +26,7 @@ class ChatPanel(BasePage):
scale=15,
container=False,
max_lines=10,
elem_id="chat-input",
)
self.submit_btn = gr.Button(
value="Send",

View File

@@ -0,0 +1,70 @@
import logging
import os
import gradio as gr
import requests
from ktem.app import BasePage
logger = logging.getLogger(__name__)
DEFAULT_KNET_ENDPOINT = "http://127.0.0.1:8081"
KNET_ENDPOINT = os.environ.get("KN_ENDPOINT", DEFAULT_KNET_ENDPOINT)
class RAGSetting(BasePage):
    """Panel for managing RAG settings fetched from KNet.

    Builds a pipeline dropdown and a retrieval-expansion checkbox, and
    mirrors the user's choices into ``setting_state`` — a ``gr.State``
    holding a dict with keys ``"pipeline"`` and ``"retrieval_expansion"``
    that the chat page reads when constructing the reasoning pipeline.
    """

    def __init__(self, app):
        self._app = app
        # Shared state consumed elsewhere (e.g. passed as ``rag_state``).
        self.setting_state = gr.State(value={})
        self.on_building_ui()
        self.on_register_events()

    def get_pipelines(self) -> list[str]:
        """Retrieve the pipeline name list from the KNet endpoint.

        Returns:
            The list of pipeline names, or an empty list when the endpoint
            is unreachable or replies with a non-200 status (the error is
            logged), so UI construction never fails on a missing backend.
        """
        try:
            response = requests.get(KNET_ENDPOINT + "/query_type", timeout=5)
            # Guard instead of raise-then-catch: fail into the handler below.
            if response.status_code != 200:
                raise IOError(f"{response.status_code}: {response.text}")
            output = response.json()
            return [item["name"] for item in output["pipelines"]]
        except Exception as e:
            # Lazy %-formatting per logging convention.
            logger.error("Failed to retrieve KNet pipelines: %s", e)
            return []

    def on_building_ui(self):
        """Build the pipeline dropdown and retrieval-expansion checkbox."""
        # Fixed local-variable typo: was ``pipline_options``.
        pipeline_options = self.get_pipelines()
        self.pipeline_select = gr.Dropdown(
            label="Pipeline",
            choices=pipeline_options,
            # Default to the first available pipeline, if any.
            value=pipeline_options[0] if pipeline_options else None,
            container=False,
            interactive=True,
        )
        self.retrieval_expansion = gr.Checkbox(
            label="Enable retrieval expansion",
            value=False,
            container=False,
        )

    def store_setting_state(self, pipeline, retrieval_expansion):
        """Pack the current widget values into the shared-state dict."""
        return {
            "pipeline": pipeline,
            "retrieval_expansion": retrieval_expansion,
        }

    def on_register_events(self):
        """Sync any widget change into ``setting_state``."""
        gr.on(
            triggers=[
                self.pipeline_select.change,
                self.retrieval_expansion.change,
            ],
            fn=self.store_setting_state,
            inputs=[
                self.pipeline_select,
                self.retrieval_expansion,
            ],
            outputs=self.setting_state,
        )