mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
feat: add websearch endpoint to RAG API
fix: google PSE endpoint uses GET
fix: google PSE returns link, not url
fix: serper wrong field
This commit is contained in:
@@ -11,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
import os, shutil, logging, re
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Union, Sequence
|
||||
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
@@ -58,6 +58,7 @@ from apps.rag.utils import (
|
||||
query_doc_with_hybrid_search,
|
||||
query_collection,
|
||||
query_collection_with_hybrid_search,
|
||||
search_web,
|
||||
)
|
||||
|
||||
from utils.misc import (
|
||||
@@ -186,6 +187,10 @@ class UrlForm(CollectionNameForm):
|
||||
url: str
|
||||
|
||||
|
||||
class SearchForm(CollectionNameForm):
    """Request body carrying a web-search query.

    Extends ``CollectionNameForm`` with the search text; the ``/websearch``
    handler reads both ``query`` and ``collection_name`` from it.
    """

    # Free-text query handed to ``search_web``.
    query: str
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def get_status():
|
||||
return {
|
||||
@@ -506,26 +511,37 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
|
||||
)
|
||||
|
||||
|
||||
def get_web_loader(url: Union[str, Sequence[str]]):
    """Build a ``WebBaseLoader`` for one URL or a sequence of URLs.

    All checks (URL syntax and, when ``ENABLE_LOCAL_WEB_FETCH`` is off, the
    private-address filter) are delegated to ``validate_url`` so that they live
    in a single place and also work when ``url`` is a sequence. The previous
    inline check only handled a single string.

    Args:
        url: A single URL string, or a sequence of URL strings.

    Returns:
        A ``WebBaseLoader`` constructed from ``url``.

    Raises:
        ValueError: With ``ERROR_MESSAGES.INVALID_URL`` when ``validate_url``
            rejects the input, whether it raises itself or returns ``False``.
    """
    # validate_url raises for a bad string URL and returns False for input it
    # cannot handle; both outcomes are reported as an invalid URL.
    if not validate_url(url):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
    return WebBaseLoader(url)
|
||||
|
||||
|
||||
def validate_url(url: Union[str, Sequence[str]]):
    """Check that a URL (or every URL in a sequence) is acceptable to fetch.

    Args:
        url: A single URL string, or a sequence whose items are validated
            one by one with this same function.

    Returns:
        ``True`` when the URL passes every check (for a sequence: when all
        items do). ``False`` when ``url`` is neither a string nor a sequence.

    Raises:
        ValueError: With ``ERROR_MESSAGES.INVALID_URL`` when a string URL is
            malformed, or when local web fetch is disabled and the host
            resolves to a private IPv4/IPv6 address.
    """
    # Non-string input: recurse over sequences, reject everything else.
    if not isinstance(url, str):
        if isinstance(url, Sequence):
            return all(validate_url(candidate) for candidate in url)
        return False

    # Syntactic check first; nothing is resolved for a malformed URL.
    if isinstance(validators.url(url), validators.ValidationError):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)

    # With local fetch enabled there is no address filtering to do.
    if ENABLE_LOCAL_WEB_FETCH:
        return True

    # Local web fetch is disabled: refuse any URL whose host resolves to a
    # private address.
    host = urllib.parse.urlparse(url).hostname
    ipv4_addresses, ipv6_addresses = resolve_hostname(host)

    # This is technically still vulnerable to DNS rebinding attacks, as we
    # don't control WebBaseLoader.
    if any(validators.ipv4(address, private=True) for address in ipv4_addresses):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
    if any(validators.ipv6(address, private=True) for address in ipv6_addresses):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
    return True
|
||||
|
||||
|
||||
def resolve_hostname(hostname):
|
||||
# Get address information
|
||||
addr_info = socket.getaddrinfo(hostname, None)
|
||||
@@ -537,6 +553,32 @@ def resolve_hostname(hostname):
|
||||
return ipv4_addresses, ipv6_addresses
|
||||
|
||||
|
||||
@app.post("/websearch")
def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
    """Search the web for a query and store the fetched pages in a collection.

    The query is passed to ``search_web``; the ``link`` of every result is
    loaded through ``get_web_loader`` and the loaded data is written with
    ``store_data_in_vector_db`` using ``overwrite=True``.

    Args:
        form_data: Supplies ``query`` and ``collection_name``. When
            ``collection_name`` is the empty string, the first 63 characters
            of ``calculate_sha256_string(query)`` are used instead.
        user: Resolved by the ``get_current_user`` dependency; not used in the
            body.

    Returns:
        A dict with ``status``, the ``collection_name`` that was written, and
        the result links under ``filenames``.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.DEFAULT(e)`` for any exception
            raised while searching, loading or storing.
    """
    try:
        links = [hit.link for hit in search_web(form_data.query)]
        documents = get_web_loader(links).load()

        # Fall back to a name derived from the query when none was given.
        requested_name = form_data.collection_name
        target_collection = (
            requested_name
            if requested_name != ""
            else calculate_sha256_string(form_data.query)[:63]
        )

        store_data_in_vector_db(documents, target_collection, overwrite=True)
    except Exception as e:
        # Endpoint boundary: log the full traceback, answer with a 400.
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
    else:
        return {
            "status": True,
            "collection_name": target_collection,
            "filenames": links,
        }
|
||||
|
||||
|
||||
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
|
||||
Reference in New Issue
Block a user