diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index a756d66e..9edebcce 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -24,6 +24,7 @@ from .chat_panel import ChatPanel from .chat_suggestion import ChatSuggestion from .common import STATE from .control import ConversationControl +from .memory import MemoryPage from .rag_setting import RAGSetting from .report import ReportIssue @@ -99,6 +100,9 @@ class ChatPage(BasePage): with gr.Accordion(label="Retrieval Settings") as _: self.retrieval_settings = RAGSetting(self._app) + # long-term memory page + self.long_term_memory_btn = gr.Button("Long-term Memory") + if len(self._app.index_manager.indices) > 0: with gr.Accordion(label="Quick Upload", open=False) as _: self.quick_file_upload = File( @@ -147,6 +151,11 @@ class ChatPage(BasePage): with gr.Column( scale=INFO_PANEL_SCALES[False], elem_id="chat-info-panel" ) as self.info_column: + with gr.Accordion( + label="Long-term Memory", open=True, visible=False + ) as self.memory_accordion: + self.memory_page = MemoryPage(self._app) + with gr.Accordion(label="Information panel", open=True): self.modal = gr.HTML("
") self.plot_panel = gr.Plot(visible=False) @@ -433,6 +442,26 @@ class ChatPage(BasePage): fn=None, inputs=None, outputs=None, js=pdfview_js ) + # long-term memory page toggle + self.long_term_memory_btn.click( + lambda: (gr.update(visible=True), gr.update(scale=INFO_PANEL_SCALES[True])), + outputs=[ + self.memory_accordion, + self.info_column, + ], + ).then( + self.memory_page.list_memories, + inputs=[self._app.user_id], + outputs=[self.memory_page.memories], + ) + self.memory_page.close_button.click( + lambda: ( + gr.update(visible=False), + gr.update(visible=False), + ), + outputs=[self.memory_accordion, self.memory_page.memory_detail_panel], + ) + # evidence display on message selection self.chat_panel.chatbot.select( self.message_selected, diff --git a/libs/ktem/ktem/pages/chat/memory.py b/libs/ktem/ktem/pages/chat/memory.py new file mode 100644 index 00000000..82cba6d6 --- /dev/null +++ b/libs/ktem/ktem/pages/chat/memory.py @@ -0,0 +1,157 @@ +import logging +import os + +import gradio as gr +import pandas as pd +import requests +from ktem.app import BasePage + +logger = logging.getLogger(__name__) + + +DEFAULT_KNET_ENDPOINT = "http://127.0.0.1:8081" +KNET_ENDPOINT = os.environ.get("KN_ENDPOINT", DEFAULT_KNET_ENDPOINT) +LONG_TERM_MEMORY_COLLECTION = "long_term_memory" + + +class MemoryPage(BasePage): + """Manage RAG long-term memory from KNet""" + + def __init__(self, app): + self._app = app + self.memory_list = gr.State(value=[]) + self.on_building_ui() + self.on_register_events() + + def list_memories(self, user_id): + """Retrieve memory list from KNet endpoint""" + memory_list = [] + try: + params = {"user_id": user_id, "collection": LONG_TERM_MEMORY_COLLECTION} + + response = requests.get( + KNET_ENDPOINT + "/retrieve_long_term_memory", params=params, timeout=5 + ) + if response.status_code == 200: + output = response.json() + memory_list = output["knowledges"] + print(params, memory_list) + else: + raise IOError(f"{response.status_code}: {response.text}") + except Exception as e: + logger.error(f"Failed to retrieve KNet memory: {e}") + + if memory_list: + memory_list = pd.DataFrame.from_records( + memory_list, + columns=["id", "text"], + ) + else: + memory_list = pd.DataFrame.from_records( + [ + { + "id": "-", + "text": "-", + } + ], + columns=["id", "text"], + ) + return memory_list + + def delete_memory(self, memory_id): + """Delete memory from KNet endpoint""" + try: + params = {"ids": [memory_id], "collection": LONG_TERM_MEMORY_COLLECTION} + + response = requests.delete( + KNET_ENDPOINT + "/delete_long_term_memory", params=params, timeout=20 + ) + print(params, response.text) + if response.status_code == 200: + gr.Info("Memory deleted successfully") + else: + raise IOError(f"{response.status_code}: {response.text}") + except Exception as e: + raise gr.Error(f"Failed to delete KNet memory: {e}") + + def interact_memory_list(self, memory_list, ev: gr.SelectData): + if (ev.value == "-" and ev.index[0] == 0) or not ev.selected: + return "", "", gr.update(visible=False) + + return ( + memory_list["id"][ev.index[0]], + memory_list["text"][ev.index[0]], + gr.update(visible=True), + ) + + def on_building_ui(self): + self.memories = gr.Dataframe( + headers=[ + "id", + "text", + ], + column_widths=["20%", "80%"], + interactive=False, + ) + + with gr.Row(visible=False) as self.memory_detail_panel: + self.selected_memory_id = gr.State(value="") + self.selected_memory_text = gr.Textbox( + "Memory", + interactive=False, + container=False, + scale=2, + ) + self.delete_button = gr.Button( + "Delete", + variant="stop", + scale=1, + ) + + with gr.Row(): + self.refresh_button = gr.Button( + "Refresh", + variant="secondary", + ) + self.close_button = gr.Button( + "Close", + variant="secondary", + ) + + def on_register_events(self): + gr.on( + triggers=[ + self.memories.select, + ], + fn=self.interact_memory_list, + inputs=[ + self.memories, + ], + outputs=[ + self.selected_memory_id, + self.selected_memory_text, + self.memory_detail_panel, + ], + ) + + gr.on( + triggers=[ + self.delete_button.click, + ], + fn=self.delete_memory, + inputs=[ + self.selected_memory_id, + ], + ).then( + self.list_memories, + inputs=[ + self._app.user_id, + ], + outputs=[ + self.memories, + ], + ) + + self.refresh_button.click( + self.list_memories, inputs=[self._app.user_id], outputs=[self.memories] + )