feat: save UI config changes to config.json

This commit is contained in:
Jun Siang Cheah
2024-05-10 13:36:10 +08:00
parent 9a95767062
commit 058eb76568
11 changed files with 611 additions and 336 deletions

View File

@@ -93,6 +93,8 @@ from config import (
RAG_TEMPLATE,
ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE,
config_set,
config_get,
)
from constants import ERROR_MESSAGES
@@ -133,7 +135,7 @@ def update_embedding_model(
embedding_model: str,
update_model: bool = False,
):
if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
if embedding_model and config_get(app.state.RAG_EMBEDDING_ENGINE) == "":
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model),
device=DEVICE_TYPE,
@@ -158,22 +160,22 @@ def update_reranking_model(
update_embedding_model(
app.state.RAG_EMBEDDING_MODEL,
config_get(app.state.RAG_EMBEDDING_MODEL),
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
app.state.RAG_RERANKING_MODEL,
config_get(app.state.RAG_RERANKING_MODEL),
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
config_get(app.state.RAG_EMBEDDING_ENGINE),
config_get(app.state.RAG_EMBEDDING_MODEL),
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
config_get(app.state.OPENAI_API_KEY),
config_get(app.state.OPENAI_API_BASE_URL),
)
origins = ["*"]
@@ -200,12 +202,12 @@ class UrlForm(CollectionNameForm):
async def get_status():
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.RAG_RERANKING_MODEL,
"chunk_size": config_get(app.state.CHUNK_SIZE),
"chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
"template": config_get(app.state.RAG_TEMPLATE),
"embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
"embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
"reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
}
@@ -213,18 +215,21 @@ async def get_status():
async def get_embedding_config(user=Depends(get_admin_user)):
return {
"status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
"embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
"openai_config": {
"url": app.state.OPENAI_API_BASE_URL,
"key": app.state.OPENAI_API_KEY,
"url": config_get(app.state.OPENAI_API_BASE_URL),
"key": config_get(app.state.OPENAI_API_KEY),
},
}
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
return {
"status": True,
"reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
}
class OpenAIConfigForm(BaseModel):
@@ -246,31 +251,31 @@ async def update_embedding_config(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
try:
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
config_set(app.state.RAG_EMBEDDING_ENGINE, form_data.embedding_engine)
config_set(app.state.RAG_EMBEDDING_MODEL, form_data.embedding_model)
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if config_get(app.state.RAG_EMBEDDING_ENGINE) in ["ollama", "openai"]:
if form_data.openai_config != None:
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.OPENAI_API_KEY = form_data.openai_config.key
config_set(app.state.OPENAI_API_BASE_URL, form_data.openai_config.url)
config_set(app.state.OPENAI_API_KEY, form_data.openai_config.key)
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
update_embedding_model(config_get(app.state.RAG_EMBEDDING_MODEL), True)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
config_get(app.state.RAG_EMBEDDING_ENGINE),
config_get(app.state.RAG_EMBEDDING_MODEL),
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
config_get(app.state.OPENAI_API_KEY),
config_get(app.state.OPENAI_API_BASE_URL),
)
return {
"status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"embedding_engine": config_get(app.state.RAG_EMBEDDING_ENGINE),
"embedding_model": config_get(app.state.RAG_EMBEDDING_MODEL),
"openai_config": {
"url": app.state.OPENAI_API_BASE_URL,
"key": app.state.OPENAI_API_KEY,
"url": config_get(app.state.OPENAI_API_BASE_URL),
"key": config_get(app.state.OPENAI_API_KEY),
},
}
except Exception as e:
@@ -293,13 +298,13 @@ async def update_reranking_config(
f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
)
try:
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
config_set(app.state.RAG_RERANKING_MODEL, form_data.reranking_model)
update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
update_reranking_model(config_get(app.state.RAG_RERANKING_MODEL), True)
return {
"status": True,
"reranking_model": app.state.RAG_RERANKING_MODEL,
"reranking_model": config_get(app.state.RAG_RERANKING_MODEL),
}
except Exception as e:
log.exception(f"Problem updating reranking model: {e}")
@@ -313,14 +318,16 @@ async def update_reranking_config(
async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
"pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES),
"chunk": {
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"chunk_size": config_get(app.state.CHUNK_SIZE),
"chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
},
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"web_loader_ssl_verification": config_get(
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
),
"youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE,
"language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
}
@@ -345,50 +352,69 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.PDF_EXTRACT_IMAGES = (
form_data.pdf_extract_images
if form_data.pdf_extract_images != None
else app.state.PDF_EXTRACT_IMAGES
config_set(
app.state.PDF_EXTRACT_IMAGES,
(
form_data.pdf_extract_images
if form_data.pdf_extract_images is not None
else config_get(app.state.PDF_EXTRACT_IMAGES)
),
)
app.state.CHUNK_SIZE = (
form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE
config_set(
app.state.CHUNK_SIZE,
(
form_data.chunk.chunk_size
if form_data.chunk is not None
else config_get(app.state.CHUNK_SIZE)
),
)
app.state.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap
if form_data.chunk != None
else app.state.CHUNK_OVERLAP
config_set(
app.state.CHUNK_OVERLAP,
(
form_data.chunk.chunk_overlap
if form_data.chunk is not None
else config_get(app.state.CHUNK_OVERLAP)
),
)
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
config_set(
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
(
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION)
),
)
app.state.YOUTUBE_LOADER_LANGUAGE = (
form_data.youtube.language
if form_data.youtube != None
else app.state.YOUTUBE_LOADER_LANGUAGE
config_set(
app.state.YOUTUBE_LOADER_LANGUAGE,
(
form_data.youtube.language
if form_data.youtube is not None
else config_get(app.state.YOUTUBE_LOADER_LANGUAGE)
),
)
app.state.YOUTUBE_LOADER_TRANSLATION = (
form_data.youtube.translation
if form_data.youtube != None
if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_TRANSLATION
)
return {
"status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
"pdf_extract_images": config_get(app.state.PDF_EXTRACT_IMAGES),
"chunk": {
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"chunk_size": config_get(app.state.CHUNK_SIZE),
"chunk_overlap": config_get(app.state.CHUNK_OVERLAP),
},
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"web_loader_ssl_verification": config_get(
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
),
"youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE,
"language": config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
}
@@ -398,7 +424,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
async def get_rag_template(user=Depends(get_current_user)):
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"template": config_get(app.state.RAG_TEMPLATE),
}
@@ -406,10 +432,10 @@ async def get_rag_template(user=Depends(get_current_user)):
async def get_query_settings(user=Depends(get_admin_user)):
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
"template": config_get(app.state.RAG_TEMPLATE),
"k": config_get(app.state.TOP_K),
"r": config_get(app.state.RELEVANCE_THRESHOLD),
"hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH),
}
@@ -424,16 +450,22 @@ class QuerySettingsForm(BaseModel):
async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
app.state.TOP_K = form_data.k if form_data.k else 4
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
config_set(
app.state.RAG_TEMPLATE,
form_data.template if form_data.template else RAG_TEMPLATE,
)
config_set(app.state.TOP_K, form_data.k if form_data.k else 4)
config_set(app.state.RELEVANCE_THRESHOLD, form_data.r if form_data.r else 0.0)
config_set(
app.state.ENABLE_RAG_HYBRID_SEARCH,
form_data.hybrid if form_data.hybrid else False,
)
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
"template": config_get(app.state.RAG_TEMPLATE),
"k": config_get(app.state.TOP_K),
"r": config_get(app.state.RELEVANCE_THRESHOLD),
"hybrid": config_get(app.state.ENABLE_RAG_HYBRID_SEARCH),
}
@@ -451,21 +483,25 @@ def query_doc_handler(
user=Depends(get_current_user),
):
try:
if app.state.ENABLE_RAG_HYBRID_SEARCH:
if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH):
return query_doc_with_hybrid_search(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
k=form_data.k if form_data.k else config_get(app.state.TOP_K),
reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
r=(
form_data.r
if form_data.r
else config_get(app.state.RELEVANCE_THRESHOLD)
),
)
else:
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
k=form_data.k if form_data.k else config_get(app.state.TOP_K),
)
except Exception as e:
log.exception(e)
@@ -489,21 +525,25 @@ def query_collection_handler(
user=Depends(get_current_user),
):
try:
if app.state.ENABLE_RAG_HYBRID_SEARCH:
if config_get(app.state.ENABLE_RAG_HYBRID_SEARCH):
return query_collection_with_hybrid_search(
collection_names=form_data.collection_names,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
k=form_data.k if form_data.k else config_get(app.state.TOP_K),
reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
r=(
form_data.r
if form_data.r
else config_get(app.state.RELEVANCE_THRESHOLD)
),
)
else:
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
k=form_data.k if form_data.k else config_get(app.state.TOP_K),
)
except Exception as e:
@@ -520,8 +560,8 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
loader = YoutubeLoader.from_youtube_url(
form_data.url,
add_video_info=True,
language=app.state.YOUTUBE_LOADER_LANGUAGE,
translation=app.state.YOUTUBE_LOADER_TRANSLATION,
language=config_get(app.state.YOUTUBE_LOADER_LANGUAGE),
translation=config_get(app.state.YOUTUBE_LOADER_TRANSLATION),
)
data = loader.load()
@@ -548,7 +588,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = get_web_loader(
form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
form_data.url,
verify_ssl=config_get(app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION),
)
data = loader.load()
@@ -604,8 +645,8 @@ def resolve_hostname(hostname):
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
chunk_size=config_get(app.state.CHUNK_SIZE),
chunk_overlap=config_get(app.state.CHUNK_OVERLAP),
add_start_index=True,
)
@@ -622,8 +663,8 @@ def store_text_in_vector_db(
text, metadata, collection_name, overwrite: bool = False
) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
chunk_size=config_get(app.state.CHUNK_SIZE),
chunk_overlap=config_get(app.state.CHUNK_OVERLAP),
add_start_index=True,
)
docs = text_splitter.create_documents([text], metadatas=[metadata])
@@ -646,11 +687,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
config_get(app.state.RAG_EMBEDDING_ENGINE),
config_get(app.state.RAG_EMBEDDING_MODEL),
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
config_get(app.state.OPENAI_API_KEY),
config_get(app.state.OPENAI_API_BASE_URL),
)
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
@@ -724,7 +765,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
]
if file_ext == "pdf":
loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
loader = PyPDFLoader(
file_path, extract_images=config_get(app.state.PDF_EXTRACT_IMAGES)
)
elif file_ext == "csv":
loader = CSVLoader(file_path)
elif file_ext == "rst":