mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 20:07:49 +01:00
enh/refac: url input handling
This commit is contained in:
@@ -6,6 +6,7 @@ import requests
|
||||
import hashlib
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import time
|
||||
import re
|
||||
|
||||
from urllib.parse import quote
|
||||
from huggingface_hub import snapshot_download
|
||||
@@ -16,6 +17,7 @@ from langchain_core.documents import Document
|
||||
from open_webui.config import VECTOR_DB
|
||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.models.files import Files
|
||||
from open_webui.models.knowledge import Knowledges
|
||||
@@ -27,6 +29,9 @@ from open_webui.retrieval.vector.main import GetResult
|
||||
from open_webui.utils.access_control import has_access
|
||||
from open_webui.utils.misc import get_message_list
|
||||
|
||||
from open_webui.retrieval.web.utils import get_web_loader
|
||||
from open_webui.retrieval.loaders.youtube import YoutubeLoader
|
||||
|
||||
|
||||
from open_webui.env import (
|
||||
SRC_LOG_LEVELS,
|
||||
@@ -49,6 +54,33 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
def is_youtube_url(url: str) -> bool:
    """Report whether ``url`` looks like a YouTube link.

    A URL qualifies when, after an optional ``http://`` or ``https://``
    scheme and an optional ``www.`` prefix, the host is exactly
    ``youtube.com`` or ``youtu.be`` and at least one character follows the
    slash after the host. Other subdomains (for example ``m.youtube.com``)
    are not accepted by this pattern.

    Args:
        url: The candidate URL string.

    Returns:
        ``True`` if the pattern matches from the start of ``url``,
        ``False`` otherwise.
    """
    # The leading ``^`` pins the search to the start of the string, so this
    # accepts exactly the inputs an anchored match would.
    found = re.search(r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$", url)
    return bool(found)
|
||||
|
||||
|
||||
def get_loader(request, url: str):
    """Choose the document loader to use for ``url``.

    YouTube links (as decided by ``is_youtube_url``) get a ``YoutubeLoader``;
    every other URL gets whatever ``get_web_loader`` builds. Loader options
    are read from ``request.app.state.config``.

    Args:
        request: Object exposing the application config at
            ``request.app.state.config``.
        url: The URL whose content should be loaded.

    Returns:
        A ``YoutubeLoader`` instance or the result of ``get_web_loader``.
    """
    # Classify the URL before touching the config, as the branch decides
    # which settings are read.
    youtube = is_youtube_url(url)
    config = request.app.state.config

    if not youtube:
        return get_web_loader(
            url,
            verify_ssl=config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
            # NOTE(review): a "concurrent requests" setting is passed as
            # requests_per_second — confirm this mapping is intended.
            requests_per_second=config.WEB_LOADER_CONCURRENT_REQUESTS,
        )

    return YoutubeLoader(
        url,
        language=config.YOUTUBE_LOADER_LANGUAGE,
        proxy_url=config.YOUTUBE_LOADER_PROXY_URL,
    )
|
||||
|
||||
|
||||
def get_content_from_url(request, url: str) -> "tuple[str, list[Document]]":
    """Load ``url`` and return its text together with the loaded documents.

    The loader is selected by ``get_loader``; the ``page_content`` of every
    document it returns is joined with single spaces.

    Args:
        request: Object exposing the application config at
            ``request.app.state.config`` (passed on to ``get_loader``).
        url: The URL to fetch content from.

    Returns:
        A ``(content, docs)`` pair: the space-joined text of all documents,
        and the documents as returned by ``loader.load()``. ``content`` is
        an empty string when no documents were loaded.
    """
    loader = get_loader(request, url)
    docs = loader.load()
    # Callers unpack two values, so the pair (not just the text) is the
    # real contract of this function.
    content = " ".join(doc.page_content for doc in docs)
    return content, docs
|
||||
|
||||
|
||||
class VectorSearchRetriever(BaseRetriever):
|
||||
collection_name: Any
|
||||
embedding_function: Any
|
||||
@@ -571,6 +603,13 @@ def get_sources_from_items(
|
||||
"metadatas": [[{"file_id": chat.id, "name": chat.title}]],
|
||||
}
|
||||
|
||||
elif item.get("type") == "url":
|
||||
content, docs = get_content_from_url(request, item.get("url"))
|
||||
if docs:
|
||||
query_result = {
|
||||
"documents": [[content]],
|
||||
"metadatas": [[{"url": item.get("url"), "name": item.get("url")}]],
|
||||
}
|
||||
elif item.get("type") == "file":
|
||||
if (
|
||||
item.get("context") == "full"
|
||||
@@ -736,7 +775,6 @@ def get_sources_from_items(
|
||||
sources.append(source)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
return sources
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user