feat: add RAG_EMBEDDING_OPENAI_BATCH_SIZE to batch multiple embeddings

This commit is contained in:
Jun Siang Cheah
2024-06-02 15:34:31 +01:00
parent 207e25035a
commit 0cb8163321
39 changed files with 112 additions and 19 deletions

View File

@@ -78,6 +78,7 @@ from utils.misc import (
from utils.utils import get_current_user, get_admin_user
from config import (
AppConfig,
ENV,
SRC_LOG_LEVELS,
UPLOAD_DIR,
@@ -114,7 +115,7 @@ from config import (
SERPER_API_KEY,
RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
AppConfig,
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
from constants import ERROR_MESSAGES
@@ -139,6 +140,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
@@ -212,6 +214,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
origins = ["*"]
@@ -248,6 +251,7 @@ async def get_status():
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
"openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
}
@@ -260,6 +264,7 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
@@ -275,6 +280,7 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
class OpenAIConfigForm(BaseModel):
url: str
key: str
batch_size: Optional[int] = None
class EmbeddingModelUpdateForm(BaseModel):
@@ -295,9 +301,14 @@ async def update_embedding_config(
app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config != None:
if form_data.openai_config is not None:
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
form_data.openai_config.batch_size
if form_data.openai_config.batch_size
else 1
)
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@@ -307,6 +318,7 @@ async def update_embedding_config(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
return {
@@ -316,6 +328,7 @@ async def update_embedding_config(
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
except Exception as e:
@@ -881,6 +894,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))