Merge branch 'dev' into feat/openai-embeddings-batch
@@ -29,6 +29,8 @@ import time
from urllib.parse import urlparse
from typing import Optional, List, Union

from starlette.background import BackgroundTask

from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES

@@ -75,9 +77,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}


REQUEST_POOL = []


# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization.

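As a hedged illustration of the "weighted round-robin" idea from the TODO above (not part of this commit; the helper name and example URLs are made up for the sketch):

import itertools
import random


def build_weighted_cycle(urls, weights):
    # Expand each URL according to its weight, so a backend with weight 3 is
    # handed out three times as often as one with weight 1.
    expanded = [url for url, weight in zip(urls, weights) for _ in range(weight)]
    random.shuffle(expanded)  # avoid always starting with the same backend
    return itertools.cycle(expanded)


# Example: two Ollama backends, the first taking twice the traffic of the second.
backend_cycle = build_weighted_cycle(
    ["http://ollama-a:11434", "http://ollama-b:11434"], [2, 1]
)
url = next(backend_cycle)
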
@@ -132,16 +131,6 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
    return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}


@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
    if user:
        if request_id in REQUEST_POOL:
            REQUEST_POOL.remove(request_id)
        return True
    else:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)


async def fetch_url(url):
    timeout = aiohttp.ClientTimeout(total=5)
    try:

@@ -154,6 +143,45 @@ async def fetch_url(url):
        return None


async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    if response:
        response.close()
    if session:
        await session.close()


async def post_streaming_url(url: str, payload: str):
    r = None
    try:
        session = aiohttp.ClientSession()
        r = await session.post(url, data=payload)
        r.raise_for_status()

        return StreamingResponse(
            r.content,
            status_code=r.status,
            headers=dict(r.headers),
            background=BackgroundTask(cleanup_response, response=r, session=session),
        )
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status if r else 500,
            detail=error_detail,
        )


def merge_models_lists(model_lists):
    merged_models = {}

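To make the intent of the two new helpers concrete, here is a minimal usage sketch (the endpoint and form are hypothetical, not part of the diff; it assumes the `app`, config, and `post_streaming_url` defined in this module):

from pydantic import BaseModel


class ExampleForm(BaseModel):
    name: str


@app.post("/api/example")
async def example_stream(form_data: ExampleForm, url_idx: int = 0):
    # Pick a backend and hand the raw JSON payload to post_streaming_url;
    # cleanup_response closes the aiohttp response and session only after the
    # client has finished consuming the streamed body.
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    return await post_streaming_url(
        f"{url}/api/example", form_data.model_dump_json(exclude_none=True).encode()
    )
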
@@ -313,65 +341,7 @@ async def pull_model(
    # Admin should be able to pull models from any source
    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}

    def get_request():
        nonlocal url
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            log.warning("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                        if request_id in REQUEST_POOL:
                            REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/api/pull",
                data=json.dumps(payload),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)

    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )
    return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))


class PushModelForm(BaseModel):

@@ -399,50 +369,9 @@ async def push_model(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.debug(f"url: {url}")

    r = None

    def get_request():
        nonlocal url
        nonlocal r
        try:

            def stream_content():
                for chunk in r.iter_content(chunk_size=8192):
                    yield chunk

            r = requests.request(
                method="POST",
                url=f"{url}/api/push",
                data=form_data.model_dump_json(exclude_none=True).encode(),
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )
    return await post_streaming_url(
        f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
    )


class CreateModelForm(BaseModel):

@@ -461,53 +390,9 @@ async def create_model(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None

    def get_request():
        nonlocal url
        nonlocal r
        try:

            def stream_content():
                for chunk in r.iter_content(chunk_size=8192):
                    yield chunk

            r = requests.request(
                method="POST",
                url=f"{url}/api/create",
                data=form_data.model_dump_json(exclude_none=True).encode(),
                stream=True,
            )

            r.raise_for_status()

            log.debug(f"r: {r}")

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )
    return await post_streaming_url(
        f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
    )


class CopyModelForm(BaseModel):

@@ -797,66 +682,9 @@ async def generate_completion(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None

    def get_request():
        nonlocal form_data
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if form_data.stream:
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            log.warning("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                        if request_id in REQUEST_POOL:
                            REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/api/generate",
                data=form_data.model_dump_json(exclude_none=True).encode(),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )
    return await post_streaming_url(
        f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
    )


class ChatMessage(BaseModel):

@@ -906,44 +734,77 @@ async def generate_chat_completion(
        if model_info.params:
            payload["options"] = {}

            payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
            payload["options"]["mirostat_eta"] = model_info.params.get(
                "mirostat_eta", None
            )
            payload["options"]["mirostat_tau"] = model_info.params.get(
                "mirostat_tau", None
            )
            payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
            if model_info.params.get("mirostat", None):
                payload["options"]["mirostat"] = model_info.params.get("mirostat", None)

            payload["options"]["repeat_last_n"] = model_info.params.get(
                "repeat_last_n", None
            )
            payload["options"]["repeat_penalty"] = model_info.params.get(
                "frequency_penalty", None
            )
            if model_info.params.get("mirostat_eta", None):
                payload["options"]["mirostat_eta"] = model_info.params.get(
                    "mirostat_eta", None
                )

            payload["options"]["temperature"] = model_info.params.get(
                "temperature", None
            )
            payload["options"]["seed"] = model_info.params.get("seed", None)
            if model_info.params.get("mirostat_tau", None):

            payload["options"]["stop"] = (
                [
                    bytes(stop, "utf-8").decode("unicode_escape")
                    for stop in model_info.params["stop"]
                ]
                if model_info.params.get("stop", None)
                else None
            )
                payload["options"]["mirostat_tau"] = model_info.params.get(
                    "mirostat_tau", None
                )

            payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
            if model_info.params.get("num_ctx", None):
                payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)

            payload["options"]["num_predict"] = model_info.params.get(
                "max_tokens", None
            )
            payload["options"]["top_k"] = model_info.params.get("top_k", None)
            if model_info.params.get("repeat_last_n", None):
                payload["options"]["repeat_last_n"] = model_info.params.get(
                    "repeat_last_n", None
                )

            payload["options"]["top_p"] = model_info.params.get("top_p", None)
            if model_info.params.get("frequency_penalty", None):
                payload["options"]["repeat_penalty"] = model_info.params.get(
                    "frequency_penalty", None
                )

            if model_info.params.get("temperature", None):
                payload["options"]["temperature"] = model_info.params.get(
                    "temperature", None
                )

            if model_info.params.get("seed", None):
                payload["options"]["seed"] = model_info.params.get("seed", None)

            if model_info.params.get("stop", None):
                payload["options"]["stop"] = (
                    [
                        bytes(stop, "utf-8").decode("unicode_escape")
                        for stop in model_info.params["stop"]
                    ]
                    if model_info.params.get("stop", None)
                    else None
                )

            if model_info.params.get("tfs_z", None):
                payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)

            if model_info.params.get("max_tokens", None):
                payload["options"]["num_predict"] = model_info.params.get(
                    "max_tokens", None
                )

            if model_info.params.get("top_k", None):
                payload["options"]["top_k"] = model_info.params.get("top_k", None)

            if model_info.params.get("top_p", None):
                payload["options"]["top_p"] = model_info.params.get("top_p", None)

            if model_info.params.get("use_mmap", None):
                payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)

            if model_info.params.get("use_mlock", None):
                payload["options"]["use_mlock"] = model_info.params.get(
                    "use_mlock", None
                )

            if model_info.params.get("num_thread", None):
                payload["options"]["num_thread"] = model_info.params.get(
                    "num_thread", None
                )

        if model_info.params.get("system", None):
            # Check if the payload already has a system message

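The hunk above changes the options handling from "always copy every key, even when unset" to "copy only the parameters the user actually set". A compact sketch of the same idea (illustration only, not the committed code; it assumes model_info.params has already been dumped to a plain dict, and it leaves out the special-cased stop sequences shown above):

# Mapping from user-facing parameter names to Ollama option names, as used in
# the commit (e.g. frequency_penalty -> repeat_penalty, max_tokens -> num_predict).
OPTION_KEYS = {
    "mirostat": "mirostat",
    "mirostat_eta": "mirostat_eta",
    "mirostat_tau": "mirostat_tau",
    "num_ctx": "num_ctx",
    "repeat_last_n": "repeat_last_n",
    "frequency_penalty": "repeat_penalty",
    "temperature": "temperature",
    "seed": "seed",
    "tfs_z": "tfs_z",
    "max_tokens": "num_predict",
    "top_k": "top_k",
    "top_p": "top_p",
    "use_mmap": "use_mmap",
    "use_mlock": "use_mlock",
    "num_thread": "num_thread",
}

params = model_info.params  # assumed: a plain dict of user-defined parameters
payload["options"] = {
    target: params.get(source)
    for source, target in OPTION_KEYS.items()
    if params.get(source) is not None  # unset values stay on Ollama's defaults
}
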
@@ -981,67 +842,7 @@ async def generate_chat_completion(
    print(payload)

    r = None

    def get_request():
        nonlocal payload
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if payload.get("stream", None):
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            log.warning("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                        if request_id in REQUEST_POOL:
                            REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/api/chat",
                data=json.dumps(payload),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            log.exception(e)
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )
    return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))


# TODO: we should update this part once Ollama supports other types

@@ -1132,68 +933,7 @@ async def generate_openai_chat_completion(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None

    def get_request():
        nonlocal payload
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if payload.get("stream"):
                        yield json.dumps(
                            {"request_id": request_id, "done": False}
                        ) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            log.warning("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                        if request_id in REQUEST_POOL:
                            REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
                url=f"{url}/v1/chat/completions",
                data=json.dumps(payload),
                stream=True,
            )

            r.raise_for_status()

            return StreamingResponse(
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        except Exception as e:
            raise e

    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=error_detail,
        )
    return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))


@app.get("/v1/models")

@@ -1522,7 +1262,7 @@ async def deprecated_proxy(
                if path == "generate":
                    data = json.loads(body.decode("utf-8"))

                    if not ("stream" in data and data["stream"] == False):
                    if data.get("stream", True):
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                elif path == "chat":

@@ -9,6 +9,7 @@ import json
import logging

from pydantic import BaseModel
from starlette.background import BackgroundTask

from apps.webui.models.models import Models
from apps.webui.models.users import Users

@@ -194,6 +195,16 @@ async def fetch_url(url, key):
        return None


async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    if response:
        response.close()
    if session:
        await session.close()


def merge_models_lists(model_lists):
    log.debug(f"merge_models_lists {model_lists}")
    merged_list = []

@@ -228,6 +239,27 @@ async def get_all_models(raw: bool = False):
    ) or not app.state.config.ENABLE_OPENAI_API:
        models = {"data": []}
    else:
        # Check if API KEYS length is same than API URLS length
        if len(app.state.config.OPENAI_API_KEYS) != len(
            app.state.config.OPENAI_API_BASE_URLS
        ):
            # if there are more keys than urls, remove the extra keys
            if len(app.state.config.OPENAI_API_KEYS) > len(
                app.state.config.OPENAI_API_BASE_URLS
            ):
                app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
                    : len(app.state.config.OPENAI_API_BASE_URLS)
                ]
            # if there are more urls than keys, add empty keys
            else:
                app.state.config.OPENAI_API_KEYS += [
                    ""
                    for _ in range(
                        len(app.state.config.OPENAI_API_BASE_URLS)
                        - len(app.state.config.OPENAI_API_KEYS)
                    )
                ]

        tasks = [
            fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)

@@ -426,40 +458,48 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
        headers["Content-Type"] = "application/json"

    r = None
    session = None
    streaming = False

    try:
        r = requests.request(
        session = aiohttp.ClientSession()
        r = await session.request(
            method=request.method,
            url=target_url,
            data=payload if payload else body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            streaming = True
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            response_data = r.json()
            response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                res = await r.json()
                print(res)
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except:
                error_detail = f"External: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500, detail=error_detail
        )
        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
    finally:
        if not streaming and session:
            if r:
                r.close()
            await session.close()

@@ -1,7 +1,8 @@
import logging

import requests

from typing import List

from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS

@@ -9,20 +10,52 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_searxng(query_url: str, query: str, count: int) -> list[SearchResult]:
    """Search a SearXNG instance for a query and return the results as a list of SearchResult objects.
def search_searxng(query_url: str, query: str, count: int, **kwargs) -> List[SearchResult]:
    """
    Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.

    The function allows passing additional parameters such as language or time_range to tailor the search result.

    Args:
        query_url (str): The URL of the SearXNG instance to search. Must contain "<query>" as a placeholder
        query (str): The query to search for
        query_url (str): The base URL of the SearXNG server with a placeholder for the query "<query>".
        query (str): The search term or question to find in the SearXNG database.
        count (int): The maximum number of results to retrieve from the search.

    Keyword Args:
        language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
        time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
        categories: (Optional[List[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.

    Returns:
        List[SearchResult]: A list of SearchResults sorted by relevance score in descending order.

    Raise:
        requests.exceptions.RequestException: If a request error occurs during the search process.
    """
    url = query_url.replace("<query>", query)
    if "&format=json" not in url:
        url += "&format=json"
    log.debug(f"searching {url}")

    # Default values for optional parameters are provided as empty strings or None when not specified.
    language = kwargs.get('language', 'en-US')
    time_range = kwargs.get('time_range', '')
    categories = ''.join(kwargs.get('categories', []))

    r = requests.get(
        url,
    params = {
        "q": query,
        "format": "json",
        "pageno": 1,
        "results_per_page": count,
        'language': language,
        'time_range': time_range,
        'engines': '',
        'categories': categories,
        'theme': 'simple',
        'image_proxy': 0

    }

    log.debug(f"searching {query_url}")

    response = requests.get(
        query_url,
        headers={
            "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
            "Accept": "text/html",

@@ -30,15 +63,17 @@ def search_searxng(query_url: str, query: str, count: int) -> list[SearchResult]
            "Accept-Language": "en-US,en;q=0.5",
            "Connection": "keep-alive",
        },
        params=params,
    )
    r.raise_for_status()

    json_response = r.json()
    response.raise_for_status() # Raise an exception for HTTP errors.

    json_response = response.json()
    results = json_response.get("results", [])
    sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
    return [
        SearchResult(
            link=result["url"], title=result.get("title"), snippet=result.get("content")
        )
        for result in sorted_results[:count]
        for result in sorted_results
    ]

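A usage sketch of the extended search_searxng signature (the SearXNG instance URL is hypothetical, not from the diff); the keyword arguments are forwarded to SearXNG as query parameters:

results = search_searxng(
    query_url="http://searxng.local/search",
    query="open webui rag",
    count=5,
    language="en-US",        # forwarded as the `language` query parameter
    time_range="",           # empty string = no time filter
    categories=["general"],  # joined into the `categories` parameter
)
for result in results:
    print(result.link, result.title)
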
@@ -298,6 +298,15 @@ class ChatTable:
            # .limit(limit).offset(skip)
        ]

    def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
        return [
            ChatModel(**model_to_dict(chat))
            for chat in Chat.select()
            .where(Chat.archived == True)
            .where(Chat.user_id == user_id)
            .order_by(Chat.updated_at.desc())
        ]

    def delete_chat_by_id(self, id: str) -> bool:
        try:
            query = Chat.delete().where((Chat.id == id))

@@ -113,6 +113,19 @@ async def get_user_chats(user=Depends(get_current_user)):
    ]


############################
# GetArchivedChats
############################


@router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)):
    return [
        ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
        for chat in Chats.get_archived_chats_by_user_id(user.id)
    ]


############################
# GetAllChatsInDB
############################