This commit is contained in:
Timothy Jaeryang Baek
2024-12-11 18:46:29 -08:00
parent 3bda1a8b88
commit ccdf51588e
2 changed files with 70 additions and 68 deletions

View File

@@ -39,6 +39,13 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse
from open_webui.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
get_event_call,
get_event_emitter,
)
from open_webui.routers import (
audio,
images,
@@ -63,35 +70,19 @@ from open_webui.routers import (
users,
utils,
)
from open_webui.retrieval.utils import get_sources_from_files
from open_webui.routers.retrieval import (
get_embedding_function,
update_embedding_model,
update_reranking_model,
get_ef,
get_rf,
)
from open_webui.retrieval.utils import get_sources_from_files
from open_webui.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
get_event_call,
get_event_emitter,
)
from open_webui.internal.db import Session
from open_webui.routers.webui import (
app as webui_app,
generate_function_chat_completion,
get_all_models as get_open_webui_models,
)
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.models.users import UserModel, Users
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.constants import TASKS
@@ -279,7 +270,7 @@ from open_webui.env import (
OFFLINE_MODE,
)
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.misc import (
add_or_update_system_message,
get_last_user_message,
@@ -528,8 +519,8 @@ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
app.state.EMBEDDING_FUNCTION = None
app.state.sentence_transformer_ef = None
app.state.sentence_transformer_rf = None
app.state.ef = None
app.state.rf = None
app.state.YOUTUBE_LOADER_TRANSLATION = None
@@ -537,29 +528,34 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.ef,
(
app.state.config.OPENAI_API_BASE_URL
app.state.config.RAG_OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_BASE_URL
else app.state.config.RAG_OLLAMA_BASE_URL
),
(
app.state.config.OPENAI_API_KEY
app.state.config.RAG_OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_API_KEY
else app.state.config.RAG_OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
update_embedding_model(
app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
try:
app.state.ef = get_ef(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
app.state.config.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
app.state.rf = get_rf(
app.state.config.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
except Exception as e:
log.error(f"Error updating models: {e}")
pass
########################################
@@ -990,11 +986,11 @@ async def chat_completion_files_handler(
sources = get_sources_from_files(
files=files,
queries=queries,
embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
k=retrieval_app.state.config.TOP_K,
reranking_function=retrieval_app.state.sentence_transformer_rf,
r=retrieval_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=app.state.config.TOP_K,
reranking_function=app.state.rf,
r=app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
log.debug(f"rag_contexts:sources: {sources}")