Merge branch 'open-webui:dev' into dev

This commit is contained in:
smonux
2024-09-28 06:01:26 +02:00
committed by GitHub
50 changed files with 1113 additions and 1091 deletions

View File

@@ -16,37 +16,45 @@ from typing import Optional
import aiohttp
import requests
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import app as ollama_app
from open_webui.apps.ollama.main import (
GenerateChatCompletionForm,
app as ollama_app,
get_all_models as get_ollama_models,
generate_chat_completion as generate_ollama_chat_completion,
generate_openai_chat_completion as generate_ollama_openai_chat_completion,
GenerateChatCompletionForm,
)
from open_webui.apps.ollama.main import get_all_models as get_ollama_models
from open_webui.apps.openai.main import app as openai_app
from open_webui.apps.openai.main import (
app as openai_app,
generate_chat_completion as generate_openai_chat_completion,
get_all_models as get_openai_models,
)
from open_webui.apps.openai.main import get_all_models as get_openai_models
from open_webui.apps.rag.main import app as rag_app
from open_webui.apps.rag.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import app as socket_app, periodic_usage_pool_cleanup
from open_webui.apps.socket.main import get_event_call, get_event_emitter
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.main import app as webui_app
from open_webui.apps.retrieval.main import app as retrieval_app
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
get_event_call,
get_event_emitter,
)
from open_webui.apps.webui.main import (
app as webui_app,
generate_function_chat_completion,
get_pipe_models,
)
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
@@ -492,11 +500,11 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
contexts, citations = get_rag_context(
files=files,
messages=body["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
k=retrieval_app.state.config.TOP_K,
reranking_function=retrieval_app.state.sentence_transformer_rf,
r=retrieval_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
@@ -609,7 +617,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if prompt is None:
raise Exception("No user message found")
if (
rag_app.state.config.RELEVANCE_THRESHOLD == 0
retrieval_app.state.config.RELEVANCE_THRESHOLD == 0
and context_string.strip() == ""
):
log.debug(
@@ -621,14 +629,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if model["owned_by"] == "ollama":
body["messages"] = prepend_to_first_user_message_content(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
else:
body["messages"] = add_or_update_system_message(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
@@ -762,10 +770,22 @@ class PipelineMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
user = get_current_user(
request,
get_http_authorization_cred(request.headers["Authorization"]),
)
try:
user = get_current_user(
request,
get_http_authorization_cred(request.headers["Authorization"]),
)
except KeyError as e:
if len(e.args) > 1:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
else:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Not authenticated"},
)
try:
data = filter_pipeline(data, user)
@@ -838,7 +858,7 @@ async def check_url(request: Request, call_next):
async def update_embedding_function(request: Request, call_next):
response = await call_next(request)
if "/embedding/update" in request.url.path:
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
return response
@@ -866,11 +886,12 @@ app.mount("/openai", openai_app)
app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)
app.mount("/retrieval/api/v1", retrieval_app)
app.mount("/api/v1", webui_app)
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
async def get_all_models():
@@ -2055,7 +2076,7 @@ async def get_app_config(request: Request):
"enable_login_form": webui_app.state.config.ENABLE_LOGIN_FORM,
**(
{
"enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_image_generation": images_app.state.config.ENABLED,
"enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_message_rating": webui_app.state.config.ENABLE_MESSAGE_RATING,
@@ -2081,8 +2102,8 @@ async def get_app_config(request: Request):
},
},
"file": {
"max_size": rag_app.state.config.FILE_MAX_SIZE,
"max_count": rag_app.state.config.FILE_MAX_COUNT,
"max_size": retrieval_app.state.config.FILE_MAX_SIZE,
"max_count": retrieval_app.state.config.FILE_MAX_COUNT,
},
"permissions": {**webui_app.state.config.USER_PERMISSIONS},
}
@@ -2154,7 +2175,8 @@ async def get_app_changelog():
@app.get("/api/version/updates")
async def get_app_latest_release_version():
try:
async with aiohttp.ClientSession(trust_env=True) as session:
timeout = aiohttp.ClientTimeout(total=1)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
) as response:
@@ -2164,10 +2186,7 @@ async def get_app_latest_release_version():
return {"current": VERSION, "latest": latest_version[1:]}
except aiohttp.ClientError:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
)
return {"current": VERSION, "latest": VERSION}
############################