diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 15dec7bedd..34e34f21c3 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -143,19 +143,16 @@ DEFAULT_CODE_INTERPRETER_TAGS = [("", "")] def get_citation_source_from_tool_result( - tool_name: str, - tool_params: dict, - tool_result: str, - tool_id: str = "" + tool_name: str, tool_params: dict, tool_result: str, tool_id: str = "" ) -> list[dict]: """ Parse a tool's result and convert it to source dicts for citation display. - + Follows the source format conventions from get_sources_from_items: - source: file/item info object with id, name, type - - document: list of document contents + - document: list of document contents - metadata: list of metadata objects with source, file_id, name fields - + Returns a list of sources (usually one, but query_knowledge_bases may return multiple). """ try: @@ -171,56 +168,68 @@ def get_citation_source_from_tool_result( snippet = result.get("snippet", "") documents.append(f"{title}\n{snippet}") - metadata.append({ - "source": link, - "name": title, - "url": link, - }) + metadata.append( + { + "source": link, + "name": title, + "url": link, + } + ) - return [{ - "source": {"name": "search_web", "id": "search_web"}, - "document": documents, - "metadata": metadata, - }] + return [ + { + "source": {"name": "search_web", "id": "search_web"}, + "document": documents, + "metadata": metadata, + } + ] elif tool_name == "view_knowledge_file": file_data = json.loads(tool_result) filename = file_data.get("filename", "Unknown File") file_id = file_data.get("id", "") knowledge_name = file_data.get("knowledge_name", "") - - return [{ - "source": { - "id": file_id, - "name": filename, - "type": "file", - }, - "document": [file_data.get("content", "")], - "metadata": [{ - "file_id": file_id, - "name": filename, - "source": filename, - **({"knowledge_name": knowledge_name} if knowledge_name else {}), - }], - }] + + return [ + { + "source": { + "id": file_id, + "name": filename, + "type": "file", + }, + "document": [file_data.get("content", "")], + "metadata": [ + { + "file_id": file_id, + "name": filename, + "source": filename, + **( + {"knowledge_name": knowledge_name} + if knowledge_name + else {} + ), + } + ], + } + ] elif tool_name == "query_knowledge_bases": chunks = json.loads(tool_result) - + # Group chunks by source for better citation display # Each unique source becomes a separate source entry sources_by_file = {} - + for chunk in chunks: source_name = chunk.get("source", "Unknown") file_id = chunk.get("file_id", "") note_id = chunk.get("note_id", "") chunk_type = chunk.get("type", "file") content = chunk.get("content", "") - + # Use file_id or note_id as the key key = file_id or note_id or source_name - + if key not in sources_by_file: sources_by_file[key] = { "source": { @@ -231,36 +240,46 @@ def get_citation_source_from_tool_result( "document": [], "metadata": [], } - + sources_by_file[key]["document"].append(content) - sources_by_file[key]["metadata"].append({ - "file_id": file_id, - "name": source_name, - "source": source_name, - **({"note_id": note_id} if note_id else {}), - }) - + sources_by_file[key]["metadata"].append( + { + "file_id": file_id, + "name": source_name, + "source": source_name, + **({"note_id": note_id} if note_id else {}), + } + ) + # Return all grouped sources as a list if sources_by_file: return list(sources_by_file.values()) - + # Empty result fallback return [] else: # Fallback for other tools - return [{ - "source": {"name": tool_name, "type": "tool", "id": tool_id or tool_name}, - "document": [str(tool_result)], - "metadata": [{"source": tool_name, "name": tool_name}], - }] + return [ + { + "source": { + "name": tool_name, + "type": "tool", + "id": tool_id or tool_name, + }, + "document": [str(tool_result)], + "metadata": [{"source": tool_name, "name": tool_name}], + } + ] except Exception as e: log.exception(f"Error parsing tool result for {tool_name}: {e}") - return [{ - "source": {"name": tool_name, "type": "tool"}, - "document": [str(tool_result)], - "metadata": [{"source": tool_name}], - }] + return [ + { + "source": {"name": tool_name, "type": "tool"}, + "document": [str(tool_result)], + "metadata": [{"source": tool_name}], + } + ] def apply_source_context_to_messages( @@ -297,13 +316,17 @@ def apply_source_context_to_messages( if RAG_SYSTEM_CONTEXT: return add_or_update_system_message( - rag_template(request.app.state.config.RAG_TEMPLATE, context_string, user_message), + rag_template( + request.app.state.config.RAG_TEMPLATE, context_string, user_message + ), messages, append=True, ) else: return add_or_update_user_message( - rag_template(request.app.state.config.RAG_TEMPLATE, context_string, user_message), + rag_template( + request.app.state.config.RAG_TEMPLATE, context_string, user_message + ), messages, append=False, ) @@ -923,46 +946,52 @@ def get_image_urls(delta_images, request, metadata, user) -> list[str]: return image_urls -def inject_file_context_into_messages(messages: list) -> None: +def add_file_context(messages: list, chat_id: str, user) -> list: """ - Inject file context into each user message that has files. - Modifies messages in-place by prepending file info to message content. + Add file URLs to user messages for native function calling. """ - for message in messages: - if message.get("role") != "user": - continue - - files = message.get("files", []) - if not files: - continue - - # Build XML context for this message's files - file_entries = [] - for file in files: - if not file.get("url"): - continue - - attrs = [f'type="{file.get("type", "file")}"'] - if file.get("content_type"): - attrs.append(f'content_type="{file["content_type"]}"') - if file.get("name"): - attrs.append(f'name="{file["name"]}"') - attrs.append(f'url="{file["url"]}"') - file_entries.append(f'') - - if not file_entries: - continue - - files_context = "\n" + "\n".join(file_entries) + "\n\n\n" - - # Prepend to message content - content = message.get("content", "") - if isinstance(content, str): - message["content"] = files_context + content - elif isinstance(content, list): - # For multimodal content, prepend as text item - message["content"] = [{"type": "text", "text": files_context}] + content + if not chat_id or chat_id.startswith("local:"): + return messages + chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id) + if not chat: + return messages + + history = chat.chat.get("history", {}) + stored_messages = get_message_list( + history.get("messages", {}), history.get("currentId") + ) + stored_user_messages = [msg for msg in stored_messages if msg.get("role") == "user"] + + def format_file_tag(file): + attrs = f'type="{file.get("type", "file")}" url="{file["url"]}"' + if file.get("content_type"): + attrs += f' content_type="{file["content_type"]}"' + if file.get("name"): + attrs += f' name="{file["name"]}"' + return f"" + + user_messages = [msg for msg in messages if msg.get("role") == "user"] + + for message, stored_message in zip(user_messages, stored_user_messages): + files_with_urls = [ + file for file in stored_message.get("files", []) if file.get("url") + ] + if not files_with_urls: + continue + + file_tags = [format_file_tag(file) for file in files_with_urls] + file_context = ( + "\n" + "\n".join(file_tags) + "\n\n\n" + ) + + content = message.get("content", "") + if isinstance(content, list): + message["content"] = [{"type": "text", "text": file_context}] + content + else: + message["content"] = file_context + content + + return messages async def chat_image_generation_handler( @@ -1454,7 +1483,10 @@ async def process_chat_payload(request, form_data, user, metadata, model): user_message = get_last_user_message(form_data["messages"]) model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False) - if model_knowledge and metadata.get("params", {}).get("function_calling") != "native": + if ( + model_knowledge + and metadata.get("params", {}).get("function_calling") != "native" + ): await event_emitter( { "type": "status", @@ -1767,8 +1799,21 @@ async def process_chat_payload(request, form_data, user, metadata, model): # Inject builtin tools for native function calling based on enabled features and model capability # Check if builtin_tools capability is enabled for this model (defaults to True if not specified) - builtin_tools_enabled = model.get("info", {}).get("meta", {}).get("capabilities", {}).get("builtin_tools", True) - if metadata.get("params", {}).get("function_calling") == "native" and builtin_tools_enabled: + builtin_tools_enabled = ( + model.get("info", {}) + .get("meta", {}) + .get("capabilities", {}) + .get("builtin_tools", True) + ) + if ( + metadata.get("params", {}).get("function_calling") == "native" + and builtin_tools_enabled + ): + # Add file context to user messages + chat_id = metadata.get("chat_id") + form_data["messages"] = add_file_context( + form_data.get("messages", []), chat_id, user + ) builtin_tools = get_builtin_tools( request, { @@ -1790,9 +1835,6 @@ async def process_chat_payload(request, form_data, user, metadata, model): {"type": "function", "function": tool.get("spec", {})} for tool in tools_dict.values() ] - # Inject file context into each user message that has files attached - inject_file_context_into_messages(form_data.get("messages", [])) - else: # If the function calling is not native, then call the tools function calling handler @@ -3279,13 +3321,21 @@ async def process_chat_response( ) # Extract citation sources from tool results - if tool_function_name in ["search_web", "view_knowledge_file", "query_knowledge_bases"] and tool_result: + if ( + tool_function_name + in [ + "search_web", + "view_knowledge_file", + "query_knowledge_bases", + ] + and tool_result + ): try: citation_sources = get_citation_source_from_tool_result( tool_name=tool_function_name, tool_params=tool_function_params, tool_result=tool_result, - tool_id=tool.get("tool_id", "") if tool else "" + tool_id=tool.get("tool_id", "") if tool else "", ) tool_call_sources.extend(citation_sources) except Exception as e: @@ -3325,7 +3375,10 @@ async def process_chat_response( user_msg = get_last_user_message(form_data["messages"]) if user_msg: form_data["messages"] = apply_source_context_to_messages( - request, form_data["messages"], tool_call_sources, user_msg + request, + form_data["messages"], + tool_call_sources, + user_msg, ) tool_call_sources.clear()