diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 79965596a2..5a9844c067 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -2713,6 +2713,12 @@ RAG_EMBEDDING_BATCH_SIZE = PersistentConfig( ), ) +ENABLE_ASYNC_EMBEDDING = PersistentConfig( + "ENABLE_ASYNC_EMBEDDING", + "rag.enable_async_embedding", + os.environ.get("ENABLE_ASYNC_EMBEDDING", "True").lower() == "true", +) + RAG_EMBEDDING_QUERY_PREFIX = os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", None) RAG_EMBEDDING_CONTENT_PREFIX = os.environ.get("RAG_EMBEDDING_CONTENT_PREFIX", None) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 8219943408..af8e670a53 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -230,6 +230,7 @@ from open_webui.config import ( RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, RAG_EMBEDDING_ENGINE, RAG_EMBEDDING_BATCH_SIZE, + ENABLE_ASYNC_EMBEDDING, RAG_TOP_K, RAG_TOP_K_RERANKER, RAG_RELEVANCE_THRESHOLD, @@ -884,6 +885,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE +app.state.config.ENABLE_ASYNC_EMBEDDING = ENABLE_ASYNC_EMBEDDING app.state.config.RAG_RERANKING_ENGINE = RAG_RERANKING_ENGINE app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index e8dc97209e..b041a00471 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -782,6 +782,7 @@ def get_embedding_function( key, embedding_batch_size, azure_api_version=None, + enable_async=True, ) -> Awaitable: if embedding_engine == "": # Sentence transformers: CPU-bound sync operation @@ -816,16 +817,26 @@ def get_embedding_function( query[i : i + embedding_batch_size] for i in range(0, len(query), embedding_batch_size) ] - 
log.debug( - f"generate_multiple_async: Processing {len(batches)} batches in parallel" - ) - # Execute all batches in parallel - tasks = [ - embedding_function(batch, prefix=prefix, user=user) - for batch in batches - ] - batch_results = await asyncio.gather(*tasks) + if enable_async: + log.debug( + f"generate_multiple_async: Processing {len(batches)} batches in parallel" + ) + # Execute all batches in parallel + tasks = [ + embedding_function(batch, prefix=prefix, user=user) + for batch in batches + ] + batch_results = await asyncio.gather(*tasks) + else: + log.debug( + f"generate_multiple_async: Processing {len(batches)} batches sequentially" + ) + batch_results = [] + for batch in batches: + batch_results.append( + await embedding_function(batch, prefix=prefix, user=user) + ) # Flatten results embeddings = [] diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index e3191eb0bd..ab93054dab 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -248,6 +248,7 @@ async def get_status(request: Request): "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, "reranking_model": request.app.state.config.RAG_RERANKING_MODEL, "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "ENABLE_ASYNC_EMBEDDING": request.app.state.config.ENABLE_ASYNC_EMBEDDING, } @@ -258,6 +259,7 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)): "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "ENABLE_ASYNC_EMBEDDING": request.app.state.config.ENABLE_ASYNC_EMBEDDING, "openai_config": { "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, "key": request.app.state.config.RAG_OPENAI_API_KEY, @@ -297,6 +299,7 @@ class EmbeddingModelUpdateForm(BaseModel): embedding_engine: str 
embedding_model: str embedding_batch_size: Optional[int] = 1 + ENABLE_ASYNC_EMBEDDING: Optional[bool] = True @router.post("/embedding/update") @@ -358,6 +361,10 @@ async def update_embedding_config( form_data.embedding_batch_size ) + request.app.state.config.ENABLE_ASYNC_EMBEDDING = ( + form_data.ENABLE_ASYNC_EMBEDDING + ) + request.app.state.ef = get_ef( request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_MODEL, @@ -391,6 +398,7 @@ async def update_embedding_config( if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" else None ), + enable_async=request.app.state.config.ENABLE_ASYNC_EMBEDDING, ) return { @@ -398,6 +406,7 @@ async def update_embedding_config( "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE, "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL, "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + "ENABLE_ASYNC_EMBEDDING": request.app.state.config.ENABLE_ASYNC_EMBEDDING, "openai_config": { "url": request.app.state.config.RAG_OPENAI_API_BASE_URL, "key": request.app.state.config.RAG_OPENAI_API_KEY, diff --git a/src/lib/components/admin/Settings/Documents.svelte b/src/lib/components/admin/Settings/Documents.svelte index 5c449fc869..b837308635 100644 --- a/src/lib/components/admin/Settings/Documents.svelte +++ b/src/lib/components/admin/Settings/Documents.svelte @@ -41,6 +41,8 @@ let embeddingEngine = ''; let embeddingModel = ''; let embeddingBatchSize = 1; + let ENABLE_ASYNC_EMBEDDING = true; + let rerankingModel = ''; let OpenAIUrl = ''; @@ -105,6 +107,7 @@ embedding_engine: embeddingEngine, embedding_model: embeddingModel, embedding_batch_size: embeddingBatchSize, + ENABLE_ASYNC_EMBEDDING: ENABLE_ASYNC_EMBEDDING, ollama_config: { key: OllamaKey, url: OllamaUrl @@ -237,6 +240,7 @@ embeddingEngine = embeddingConfig.embedding_engine; embeddingModel = embeddingConfig.embedding_model; embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 
1; + ENABLE_ASYNC_EMBEDDING = embeddingConfig.ENABLE_ASYNC_EMBEDDING ?? true; OpenAIKey = embeddingConfig.openai_config.key; OpenAIUrl = embeddingConfig.openai_config.url; @@ -927,6 +931,22 @@ /> + +
+
+ + {$i18n.t('Async Embedding Processing')} + +
+
+ +
+
{/if}