mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 20:07:49 +01:00
feat: add RAG_EMBEDDING_OPENAI_BATCH_SIZE to batch multiple embeddings
This commit is contained in:
@@ -78,6 +78,7 @@ from utils.misc import (
|
||||
from utils.utils import get_current_user, get_admin_user
|
||||
|
||||
from config import (
|
||||
AppConfig,
|
||||
ENV,
|
||||
SRC_LOG_LEVELS,
|
||||
UPLOAD_DIR,
|
||||
@@ -114,7 +115,7 @@ from config import (
|
||||
SERPER_API_KEY,
|
||||
RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
AppConfig,
|
||||
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
)
|
||||
|
||||
from constants import ERROR_MESSAGES
|
||||
@@ -139,6 +140,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
|
||||
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
||||
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
|
||||
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
||||
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
|
||||
@@ -212,6 +214,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
)
|
||||
|
||||
origins = ["*"]
|
||||
@@ -248,6 +251,7 @@ async def get_status():
|
||||
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
|
||||
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
|
||||
"openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
}
|
||||
|
||||
|
||||
@@ -260,6 +264,7 @@ async def get_embedding_config(user=Depends(get_admin_user)):
|
||||
"openai_config": {
|
||||
"url": app.state.config.OPENAI_API_BASE_URL,
|
||||
"key": app.state.config.OPENAI_API_KEY,
|
||||
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -275,6 +280,7 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
|
||||
class OpenAIConfigForm(BaseModel):
|
||||
url: str
|
||||
key: str
|
||||
batch_size: Optional[int] = None
|
||||
|
||||
|
||||
class EmbeddingModelUpdateForm(BaseModel):
|
||||
@@ -295,9 +301,14 @@ async def update_embedding_config(
|
||||
app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
|
||||
if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
||||
if form_data.openai_config != None:
|
||||
if form_data.openai_config is not None:
|
||||
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
|
||||
form_data.openai_config.batch_size
|
||||
if form_data.openai_config.batch_size
|
||||
else 1
|
||||
)
|
||||
|
||||
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
||||
|
||||
@@ -307,6 +318,7 @@ async def update_embedding_config(
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -316,6 +328,7 @@ async def update_embedding_config(
|
||||
"openai_config": {
|
||||
"url": app.state.config.OPENAI_API_BASE_URL,
|
||||
"key": app.state.config.OPENAI_API_KEY,
|
||||
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
@@ -881,6 +894,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
)
|
||||
|
||||
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||
|
||||
Reference in New Issue
Block a user