Timothy Jaeryang Baek
2024-12-11 18:46:29 -08:00
parent 3bda1a8b88
commit ccdf51588e
2 changed files with 70 additions and 68 deletions


@@ -97,62 +97,58 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 ##########################################
 
 
-def update_embedding_model(
-    request: Request,
+def get_ef(
+    engine: str,
     embedding_model: str,
     auto_update: bool = False,
 ):
-    if embedding_model and request.app.state.config.RAG_EMBEDDING_ENGINE == "":
+    ef = None
+    if embedding_model and engine == "":
         from sentence_transformers import SentenceTransformer
 
         try:
-            request.app.state.sentence_transformer_ef = SentenceTransformer(
+            ef = SentenceTransformer(
                 get_model_path(embedding_model, auto_update),
                 device=DEVICE_TYPE,
                 trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
             )
         except Exception as e:
             log.debug(f"Error loading SentenceTransformer: {e}")
-            request.app.state.sentence_transformer_ef = None
-    else:
-        request.app.state.sentence_transformer_ef = None
+
+    return ef
 
 
-def update_reranking_model(
-    request: Request,
+def get_rf(
     reranking_model: str,
     auto_update: bool = False,
 ):
+    rf = None
     if reranking_model:
         if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
             try:
                 from open_webui.retrieval.models.colbert import ColBERT
 
-                request.app.state.sentence_transformer_rf = ColBERT(
+                rf = ColBERT(
                     get_model_path(reranking_model, auto_update),
                     env="docker" if DOCKER else None,
                 )
             except Exception as e:
                 log.error(f"ColBERT: {e}")
-                request.app.state.sentence_transformer_rf = None
-                request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
+                raise Exception(ERROR_MESSAGES.DEFAULT(e))
         else:
             import sentence_transformers
 
             try:
-                request.app.state.sentence_transformer_rf = (
-                    sentence_transformers.CrossEncoder(
-                        get_model_path(reranking_model, auto_update),
-                        device=DEVICE_TYPE,
-                        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-                    )
+                rf = sentence_transformers.CrossEncoder(
+                    get_model_path(reranking_model, auto_update),
+                    device=DEVICE_TYPE,
+                    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
                 )
             except:
                 log.error("CrossEncoder error")
-                request.app.state.sentence_transformer_rf = None
-                request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
-    else:
-        request.app.state.sentence_transformer_rf = None
+                raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
+
+    return rf
 
 
 ##########################################
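For orientation, the new loaders return the model objects instead of writing them onto request.app.state, so callers now own both the assignment and the failure policy. A minimal caller-side sketch, condensed from the route-handler hunks below (the app, log, and config names are taken from this diff; everything else is illustrative):

# Sketch only: mirrors how the handlers below consume get_ef / get_rf.
app.state.ef = get_ef(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
)

try:
    app.state.rf = get_rf(app.state.config.RAG_RERANKING_MODEL)
except Exception as e:
    # get_rf now raises instead of silently disabling hybrid search itself,
    # so the caller decides the fallback.
    log.error(f"Error loading reranking model: {e}")
    app.state.config.ENABLE_RAG_HYBRID_SEARCH = False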
@@ -261,12 +257,15 @@ async def update_embedding_config(
                 form_data.embedding_batch_size
             )
 
-        update_embedding_model(request.app.state.config.RAG_EMBEDDING_MODEL)
+        request.app.state.ef = get_ef(
+            request.app.state.config.RAG_EMBEDDING_ENGINE,
+            request.app.state.config.RAG_EMBEDDING_MODEL,
+        )
 
         request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
             request.app.state.config.RAG_EMBEDDING_ENGINE,
             request.app.state.config.RAG_EMBEDDING_MODEL,
-            request.app.state.sentence_transformer_ef,
+            request.app.state.ef,
             (
                 request.app.state.config.OPENAI_API_BASE_URL
                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
@@ -316,7 +315,14 @@ async def update_reranking_config(
     try:
         request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
 
-        update_reranking_model(request.app.state.config.RAG_RERANKING_MODEL, True)
+        try:
+            request.app.state.rf = get_rf(
+                request.app.state.config.RAG_RERANKING_MODEL,
+                True,
+            )
+        except Exception as e:
+            log.error(f"Error loading reranking model: {e}")
+            request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
 
         return {
             "status": True,
@@ -739,7 +745,7 @@ def save_docs_to_vector_db(
         embedding_function = get_embedding_function(
             request.app.state.config.RAG_EMBEDDING_ENGINE,
             request.app.state.config.RAG_EMBEDDING_MODEL,
-            request.app.state.sentence_transformer_ef,
+            request.app.state.ef,
             (
                 request.app.state.config.OPENAI_API_BASE_URL
                 if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
@@ -1286,7 +1292,7 @@ def query_doc_handler(
                 query=form_data.query,
                 embedding_function=request.app.state.EMBEDDING_FUNCTION,
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
-                reranking_function=request.app.state.sentence_transformer_rf,
+                reranking_function=request.app.state.rf,
                 r=(
                     form_data.r
                     if form_data.r
@@ -1328,7 +1334,7 @@ def query_collection_handler(
                 queries=[form_data.query],
                 embedding_function=request.app.state.EMBEDDING_FUNCTION,
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
-                reranking_function=request.app.state.sentence_transformer_rf,
+                reranking_function=request.app.state.rf,
                 r=(
                     form_data.r
                     if form_data.r
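For completeness, a hypothetical downstream consumer of the renamed attribute. This helper is not part of the commit; it only illustrates reading the reranker from request.app.state.rf and assumes the CrossEncoder-style predict(list_of_pairs) interface from sentence-transformers:

def rerank(request, query: str, documents: list[str]) -> list[str]:
    # Illustrative helper (invented name): read the reranker loaded by get_rf.
    rf = request.app.state.rf
    if rf is None:
        # Reranker unavailable; keep the retrieval order unchanged.
        return documents
    scores = rf.predict([(query, doc) for doc in documents])
    ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in ranked]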