feat: add websearch endpoint to RAG API

fix: google PSE endpoint uses GET

fix: google PSE returns link, not url

fix: serper result parsed with wrong field
commit 99e4edd364 (parent 501ff7a98b)
Author: Jun Siang Cheah
Date: 2024-05-06 16:39:25 +08:00
4 changed files with 76 additions and 38 deletions
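
The two PSE fixes and the serper fix all concern the shape of each provider's API: Google PSE is a GET endpoint, and each result item carries its URL under "link" rather than "url". A minimal sketch of that request shape, assuming an API key and engine id; the helper name and return shape here are illustrative, not the repo's exact code:

import requests

def search_google_pse(api_key: str, engine_id: str, query: str, count: int = 10):
    # Google's Custom Search JSON API expects GET with key/cx/q as query params
    response = requests.get(
        "https://www.googleapis.com/customsearch/v1",
        params={"key": api_key, "cx": engine_id, "q": query, "num": count},
    )
    response.raise_for_status()
    # Each result item exposes its URL under "link", not "url"
    return [item["link"] for item in response.json().get("items", [])]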


@@ -11,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import os, shutil, logging, re
 from pathlib import Path
-from typing import List
+from typing import List, Union, Sequence
 from chromadb.utils.batch_utils import create_batches
@@ -58,6 +58,7 @@ from apps.rag.utils import (
     query_doc_with_hybrid_search,
     query_collection,
     query_collection_with_hybrid_search,
+    search_web,
 )
 from utils.misc import (
@@ -186,6 +187,10 @@ class UrlForm(CollectionNameForm):
     url: str
 
 
+class SearchForm(CollectionNameForm):
+    query: str
+
+
 @app.get("/")
 async def get_status():
     return {
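
SearchForm layers a query on top of the collection_name inherited from CollectionNameForm, so a request body for the new endpoint carries both fields. For example, with hypothetical values:

form = SearchForm(collection_name="", query="open source rag pipelines")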
@@ -506,26 +511,37 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
         )
 
 
-def get_web_loader(url: str):
+def get_web_loader(url: Union[str, Sequence[str]]):
     # Check if the URL is valid
-    if isinstance(validators.url(url), validators.ValidationError):
+    if not validate_url(url):
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    if not ENABLE_LOCAL_WEB_FETCH:
-        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
-        parsed_url = urllib.parse.urlparse(url)
-        # Get IPv4 and IPv6 addresses
-        ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
-        # Check if any of the resolved addresses are private
-        # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
-        for ip in ipv4_addresses:
-            if validators.ipv4(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
-        for ip in ipv6_addresses:
-            if validators.ipv6(ip, private=True):
-                raise ValueError(ERROR_MESSAGES.INVALID_URL)
     return WebBaseLoader(url)
 
 
+def validate_url(url: Union[str, Sequence[str]]):
+    if isinstance(url, str):
+        if isinstance(validators.url(url), validators.ValidationError):
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        if not ENABLE_LOCAL_WEB_FETCH:
+            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+            parsed_url = urllib.parse.urlparse(url)
+            # Get IPv4 and IPv6 addresses
+            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+            # Check if any of the resolved addresses are private
+            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+            for ip in ipv4_addresses:
+                if validators.ipv4(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+            for ip in ipv6_addresses:
+                if validators.ipv6(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        return True
+    elif isinstance(url, Sequence):
+        return all(validate_url(u) for u in url)
+    else:
+        return False
+
+
 def resolve_hostname(hostname):
     # Get address information
     addr_info = socket.getaddrinfo(hostname, None)
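
Since validate_url now recurses over sequences, both a single URL string and a list of URLs are accepted; note that for strings an invalid URL raises ValueError rather than returning False, so the falsy branch in get_web_loader mainly covers non-string, non-sequence input. A quick check with hypothetical URLs, assuming ENABLE_LOCAL_WEB_FETCH is on so no DNS resolution is attempted:

assert validate_url("https://example.com") is True
assert validate_url(["https://example.com", "https://example.org"]) is True
assert validate_url(42) is False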
@@ -537,6 +553,32 @@ def resolve_hostname(hostname):
     return ipv4_addresses, ipv6_addresses
 
 
+@app.post("/websearch")
+def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
+    try:
+        web_results = search_web(form_data.query)
+        urls = [result.link for result in web_results]
+
+        loader = get_web_loader(urls)
+        data = loader.load()
+
+        collection_name = form_data.collection_name
+        if collection_name == "":
+            collection_name = calculate_sha256_string(form_data.query)[:63]
+
+        store_data_in_vector_db(data, collection_name, overwrite=True)
+        return {
+            "status": True,
+            "collection_name": collection_name,
+            "filenames": urls,
+        }
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
 def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
     text_splitter = RecursiveCharacterTextSplitter(
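
With the app mounted, the new endpoint can be exercised with a plain HTTP call. The sketch below assumes a local instance on port 8080 with the RAG sub-app at a /rag/api/v1 prefix and a placeholder bearer token; an empty collection_name makes the server derive one from the SHA-256 of the query, truncated to 63 characters (ChromaDB's collection-name limit):

import requests

resp = requests.post(
    "http://localhost:8080/rag/api/v1/websearch",  # mount prefix is an assumption
    headers={"Authorization": "Bearer <token>"},   # placeholder token
    json={"collection_name": "", "query": "open source rag pipelines"},
)
print(resp.json())  # {"status": true, "collection_name": "...", "filenames": [...]}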