Merge pull request #2785 from cheahjs/feat/openai-embeddings-batch
feat: add RAG_EMBEDDING_OPENAI_BATCH_SIZE to batch multiple embeddings
@@ -78,6 +78,7 @@ from utils.misc import (
 from utils.utils import get_current_user, get_admin_user
 
 from config import (
+    AppConfig,
     ENV,
     SRC_LOG_LEVELS,
     UPLOAD_DIR,
@@ -114,7 +115,7 @@ from config import (
     SERPER_API_KEY,
     RAG_WEB_SEARCH_RESULT_COUNT,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-    AppConfig,
+    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
 )
 
 from constants import ERROR_MESSAGES
@@ -139,6 +140,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
 
 app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
 app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
 
@@ -212,6 +214,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
     app.state.sentence_transformer_ef,
     app.state.config.OPENAI_API_KEY,
     app.state.config.OPENAI_API_BASE_URL,
+    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
 )
 
 origins = ["*"]
@@ -248,6 +251,7 @@ async def get_status():
         "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
         "reranking_model": app.state.config.RAG_RERANKING_MODEL,
+        "openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
     }
 
 
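For context, the status handler above now surfaces the configured batch size. A hedged sketch of the payload shape it returns (key names come from the diff; the concrete values here are assumed defaults):

# Illustrative only: values are placeholders, not guaranteed defaults.
status = {
    "embedding_engine": "openai",
    "embedding_model": "text-embedding-ada-002",
    "reranking_model": "",
    "openai_batch_size": 1,
}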
@@ -260,6 +264,7 @@ async def get_embedding_config(user=Depends(get_admin_user)):
         "openai_config": {
             "url": app.state.config.OPENAI_API_BASE_URL,
             "key": app.state.config.OPENAI_API_KEY,
+            "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
         },
     }
 
@@ -275,6 +280,7 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
 class OpenAIConfigForm(BaseModel):
     url: str
     key: str
+    batch_size: Optional[int] = None
 
 
 class EmbeddingModelUpdateForm(BaseModel):
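Because batch_size is Optional with a None default, older clients that omit it keep working; the update handler below falls back to 1. A minimal sketch of building the new form (EmbeddingModelUpdateForm's full definition is outside this diff, so the embedding_engine field is an assumption based on the handler code):

form = EmbeddingModelUpdateForm(
    embedding_engine="openai",          # assumed field; not shown in this diff
    embedding_model="text-embedding-ada-002",
    openai_config=OpenAIConfigForm(
        url="https://api.openai.com/v1",
        key="sk-placeholder",           # placeholder credential
        batch_size=16,                  # omit or pass None to fall back to 1
    ),
)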
@@ -295,9 +301,14 @@ async def update_embedding_config(
     app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
 
     if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
-        if form_data.openai_config != None:
+        if form_data.openai_config is not None:
             app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
             app.state.config.OPENAI_API_KEY = form_data.openai_config.key
+            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
+                form_data.openai_config.batch_size
+                if form_data.openai_config.batch_size
+                else 1
+            )
 
     update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
 
@@ -307,6 +318,7 @@ async def update_embedding_config(
         app.state.sentence_transformer_ef,
         app.state.config.OPENAI_API_KEY,
         app.state.config.OPENAI_API_BASE_URL,
+        app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
     )
 
     return {
@@ -316,6 +328,7 @@ async def update_embedding_config(
             "openai_config": {
                 "url": app.state.config.OPENAI_API_BASE_URL,
                 "key": app.state.config.OPENAI_API_KEY,
+                "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
             },
         }
     except Exception as e:
@@ -881,6 +894,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
         app.state.sentence_transformer_ef,
         app.state.config.OPENAI_API_KEY,
         app.state.config.OPENAI_API_BASE_URL,
+        app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
     )
 
     embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
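The hunks above wire the new setting through the FastAPI app; the remaining hunks rework the embedding helpers themselves. For orientation, a hedged sketch of how the function built by get_embedding_function is consumed (call sites are assumed from surrounding code, not shown in this diff):

# Assumed usage: the embedding function accepts either one query string or a
# list of chunk texts; the batching added below only applies to the list case.
query_vector = app.state.EMBEDDING_FUNCTION("What changed in this PR?")
chunk_vectors = app.state.EMBEDDING_FUNCTION(embedding_texts)  # list of texts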
@@ -2,7 +2,7 @@ import os
 import logging
 import requests
 
-from typing import List
+from typing import List, Union
 
 from apps.ollama.main import (
     generate_ollama_embeddings,
@@ -21,17 +21,7 @@ from langchain.retrievers import (
 from typing import Optional
 
 
-from config import (
-    SRC_LOG_LEVELS,
-    CHROMA_CLIENT,
-    SEARXNG_QUERY_URL,
-    GOOGLE_PSE_API_KEY,
-    GOOGLE_PSE_ENGINE_ID,
-    BRAVE_SEARCH_API_KEY,
-    SERPSTACK_API_KEY,
-    SERPSTACK_HTTPS,
-    SERPER_API_KEY,
-)
+from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -209,6 +199,7 @@ def get_embedding_function(
     embedding_function,
     openai_key,
     openai_url,
+    batch_size,
 ):
     if embedding_engine == "":
         return lambda query: embedding_function.encode(query).tolist()
@@ -232,7 +223,13 @@
 
     def generate_multiple(query, f):
         if isinstance(query, list):
-            return [f(q) for q in query]
+            if embedding_engine == "openai":
+                embeddings = []
+                for i in range(0, len(query), batch_size):
+                    embeddings.extend(f(query[i : i + batch_size]))
+                return embeddings
+            else:
+                return [f(q) for q in query]
         else:
             return f(query)
 
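The slice arithmetic above is the heart of the change: for the OpenAI engine, a list of N texts now goes out in ceil(N / batch_size) requests instead of N. A self-contained sketch of the same pattern, with fake_embed standing in for the real per-batch embedding call:

# Stand-in for the per-batch embedding call; returns one dummy vector per text.
def fake_embed(batch: list[str]) -> list[list[float]]:
    return [[float(len(t))] for t in batch]

texts = ["one", "two", "three", "four", "five", "six", "seven", "eight"]
batch_size = 3
embeddings = []
for i in range(0, len(texts), batch_size):
    # slices of size 3, 3, 2; range() stops cleanly at the tail
    embeddings.extend(fake_embed(texts[i : i + batch_size]))
assert len(embeddings) == len(texts)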
@@ -413,8 +410,22 @@ def get_model_path(model: str, update_model: bool = False):
 
 
 def generate_openai_embeddings(
-    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
+    model: str,
+    text: Union[str, list[str]],
+    key: str,
+    url: str = "https://api.openai.com/v1",
 ):
+    if isinstance(text, list):
+        embeddings = generate_openai_batch_embeddings(model, text, key, url)
+    else:
+        embeddings = generate_openai_batch_embeddings(model, [text], key, url)
+
+    return embeddings[0] if isinstance(text, str) else embeddings
+
+
+def generate_openai_batch_embeddings(
+    model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
+) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
             f"{url}/embeddings",
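With the reworked signature, callers can keep passing a single string and get a single vector back, or pass a list and get one vector per element; both paths go through the batch helper. A hedged usage sketch (model name and key are placeholders):

one_vector = generate_openai_embeddings(
    "text-embedding-ada-002", "hello world", "sk-placeholder"
)
many_vectors = generate_openai_embeddings(
    "text-embedding-ada-002", ["hello", "world"], "sk-placeholder"
)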
@@ -422,12 +433,12 @@ def generate_openai_embeddings(
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
             },
-            json={"input": text, "model": model},
+            json={"input": texts, "model": model},
         )
         r.raise_for_status()
         data = r.json()
         if "data" in data:
-            return data["data"][0]["embedding"]
+            return [elem["embedding"] for elem in data["data"]]
         else:
             raise "Something went wrong :/"
     except Exception as e:
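This final hunk is the other half of the batch support: the request now posts the whole list of texts, and the parser collects every returned embedding rather than only the first. A sketch of the response shape being parsed (values invented; real vectors have far more dimensions):

# Abbreviated, hand-written example of an OpenAI embeddings response body.
data = {
    "object": "list",
    "data": [
        {"object": "embedding", "index": 0, "embedding": [0.01, -0.02]},
        {"object": "embedding", "index": 1, "embedding": [0.03, 0.04]},
    ],
    "model": "text-embedding-ada-002",
}
vectors = [elem["embedding"] for elem in data["data"]]  # one vector per input text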