diff --git a/libs/ktem/ktem/assets/css/main.css b/libs/ktem/ktem/assets/css/main.css index 9986d49e..4fe2019d 100644 --- a/libs/ktem/ktem/assets/css/main.css +++ b/libs/ktem/ktem/assets/css/main.css @@ -213,3 +213,8 @@ pdfjs-viewer-element { flex: 1; overflow: auto; } + +#chat-input { + resize: vertical; + max-height: 200px; +} diff --git a/libs/ktem/ktem/index/file/knet/knet_index.py b/libs/ktem/ktem/index/file/knet/knet_index.py index f98e7a09..f1cf4e59 100644 --- a/libs/ktem/ktem/index/file/knet/knet_index.py +++ b/libs/ktem/ktem/index/file/knet/knet_index.py @@ -1,3 +1,4 @@ +import os from typing import Any from ktem.index.file import FileIndex @@ -5,6 +6,8 @@ from ktem.index.file import FileIndex from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever from .pipelines import KnetIndexingPipeline, KnetRetrievalPipeline +DEFAULT_RAG_PIPELINE = os.environ.get("KNET_RAG_PIPELINE", "Direct_Query") + class KnowledgeNetworkFileIndex(FileIndex): @classmethod @@ -43,5 +46,11 @@ class KnowledgeNetworkFileIndex(FileIndex): # also set the collection_name for API call obj.VS = None obj.collection_name = f"kh_index_{self.id}" + obj.retrieval_expansion = settings.get("rag_settings", {}).get( + "retrieval_expansion", False + ) + obj.pipeline_name = settings.get("rag_settings", {}).get( + "pipeline", DEFAULT_RAG_PIPELINE + ) return retrievers diff --git a/libs/ktem/ktem/index/file/knet/pipelines.py b/libs/ktem/ktem/index/file/knet/pipelines.py index 5cb4be8b..0e2f0cee 100644 --- a/libs/ktem/ktem/index/file/knet/pipelines.py +++ b/libs/ktem/ktem/index/file/knet/pipelines.py @@ -47,6 +47,8 @@ class KnetRetrievalPipeline(BaseFileIndexRetriever): collection_name: str = "default" rerankers: Sequence[BaseReranking] = [LLMReranking.withx()] + retrieval_expansion: bool = False + pipeline_name: str | None = None def encode_image_base64(self, image_path: str | Path) -> bytes | str: """Convert image to base64""" @@ -81,7 +83,10 @@ class 
KnetRetrievalPipeline(BaseFileIndexRetriever): "collection": self.collection_name, "meta_filters": {"doc_name": doc_ids}, "conv_id": conv_id, + "assert_query_type": self.pipeline_name, + "retrieval_expansion": self.retrieval_expansion, } + print(params) params["meta_filters"] = json.dumps(params["meta_filters"]) response = requests.get(self.DEFAULT_KNET_ENDPOINT, params=params) metadata_translation = { diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index f5a3cf37..a756d66e 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -24,6 +24,7 @@ from .chat_panel import ChatPanel from .chat_suggestion import ChatSuggestion from .common import STATE from .control import ConversationControl +from .rag_setting import RAGSetting from .report import ReportIssue DEFAULT_SETTING = "(default)" @@ -95,8 +96,11 @@ class ChatPage(BasePage): self._indices_input.append(gr_index) setattr(self, f"_index_{index.id}", index_ui) + with gr.Accordion(label="Retrieval Settings") as _: + self.retrieval_settings = RAGSetting(self._app) + if len(self._app.index_manager.indices) > 0: - with gr.Accordion(label="Quick Upload") as _: + with gr.Accordion(label="Quick Upload", open=False) as _: self.quick_file_upload = File( file_types=list(KH_DEFAULT_FILE_EXTRACTORS.keys()), file_count="multiple", @@ -180,6 +184,7 @@ class ChatPage(BasePage): self._reasoning_type, self._llm_type, self.chat_state, + self.retrieval_settings.setting_state, self._app.user_id, ] + self._indices_input, @@ -251,6 +256,7 @@ class ChatPage(BasePage): self._reasoning_type, self._llm_type, self.chat_state, + self.retrieval_settings.setting_state, self._app.user_id, ] + self._indices_input, @@ -648,6 +654,7 @@ class ChatPage(BasePage): session_reasoning_type: str, session_llm: str, state: dict, + rag_state: dict, user_id: int, *selecteds, ): @@ -665,6 +672,8 @@ class ChatPage(BasePage): # override reasoning_mode by temporary chat page state 
print("Session reasoning type", session_reasoning_type) print("Session LLM", session_llm) + print("RAG settings", rag_state) + reasoning_mode = ( settings["reasoning.use"] if session_reasoning_type in (DEFAULT_SETTING, None) @@ -675,6 +684,8 @@ class ChatPage(BasePage): reasoning_id = reasoning_cls.get_info()["id"] settings = deepcopy(settings) + settings["rag_settings"] = rag_state + llm_setting_key = f"reasoning.options.{reasoning_id}.llm" if llm_setting_key in settings and session_llm not in (DEFAULT_SETTING, None): settings[llm_setting_key] = session_llm @@ -711,6 +722,7 @@ class ChatPage(BasePage): reasoning_type, llm_type, state, + rag_state, user_id, *selecteds, ): @@ -722,7 +734,7 @@ class ChatPage(BasePage): # construct the pipeline pipeline, reasoning_state = self.create_pipeline( - settings, reasoning_type, llm_type, state, user_id, *selecteds + settings, reasoning_type, llm_type, state, rag_state, user_id, *selecteds ) print("Reasoning state", reasoning_state) pipeline.set_output_queue(queue) @@ -785,6 +797,7 @@ class ChatPage(BasePage): reasoning_type, llm_type, state, + rag_state, user_id, *selecteds, ): @@ -802,6 +815,7 @@ class ChatPage(BasePage): reasoning_type, llm_type, state, + rag_state, user_id, *selecteds, ): diff --git a/libs/ktem/ktem/pages/chat/chat_panel.py b/libs/ktem/ktem/pages/chat/chat_panel.py index 80700b0f..d84b5287 100644 --- a/libs/ktem/ktem/pages/chat/chat_panel.py +++ b/libs/ktem/ktem/pages/chat/chat_panel.py @@ -26,6 +26,7 @@ class ChatPanel(BasePage): scale=15, container=False, max_lines=10, + elem_id="chat-input", ) self.submit_btn = gr.Button( value="Send", diff --git a/libs/ktem/ktem/pages/chat/rag_setting.py b/libs/ktem/ktem/pages/chat/rag_setting.py new file mode 100644 index 00000000..3c4420d5 --- /dev/null +++ b/libs/ktem/ktem/pages/chat/rag_setting.py @@ -0,0 +1,70 @@ +import logging +import os + +import gradio as gr +import requests +from ktem.app import BasePage + +logger = logging.getLogger(__name__) + + 
+DEFAULT_KNET_ENDPOINT = "http://127.0.0.1:8081"
+KNET_ENDPOINT = os.environ.get("KN_ENDPOINT", DEFAULT_KNET_ENDPOINT)
+
+
+class RAGSetting(BasePage):
+    """Manage RAG settings from KNet"""
+
+    def __init__(self, app):
+        self._app = app
+        self.setting_state = gr.State(value={})
+        self.on_building_ui()
+        self.on_register_events()
+
+    def get_pipelines(self):
+        """Retrieve pipeline list from KNet endpoint"""
+        try:
+            response = requests.get(KNET_ENDPOINT + "/query_type", timeout=5)
+            if response.status_code == 200:
+                output = response.json()
+                return [item["name"] for item in output["pipelines"]]
+            else:
+                raise IOError(f"{response.status_code}: {response.text}")
+        except Exception as e:
+            logger.error(f"Failed to retrieve KNet pipelines: {e}")
+            return []
+
+    def on_building_ui(self):
+        pipeline_options = self.get_pipelines()
+        self.pipeline_select = gr.Dropdown(
+            label="Pipeline",
+            choices=pipeline_options,
+            value=pipeline_options[0] if pipeline_options else None,
+            container=False,
+            interactive=True,
+        )
+        self.retrieval_expansion = gr.Checkbox(
+            label="Enable retrieval expansion",
+            value=False,
+            container=False,
+        )
+
+    def store_setting_state(self, pipeline, retrieval_expansion):
+        return {
+            "pipeline": pipeline,
+            "retrieval_expansion": retrieval_expansion,
+        }
+
+    def on_register_events(self):
+        gr.on(
+            triggers=[
+                self.pipeline_select.change,
+                self.retrieval_expansion.change,
+            ],
+            fn=self.store_setting_state,
+            inputs=[
+                self.pipeline_select,
+                self.retrieval_expansion,
+            ],
+            outputs=self.setting_state,
+        )