chore: format

Timothy Jaeryang Baek
2025-04-12 16:35:11 -07:00
parent 77b25ae36a
commit 91a455a284
61 changed files with 1157 additions and 203 deletions

View File

@@ -76,11 +76,11 @@ def serve(
     from open_webui.env import UVICORN_WORKERS  # Import the workers setting

     uvicorn.run(
-        open_webui.main.app,
-        host=host,
-        port=port,
+        open_webui.main.app,
+        host=host,
+        port=port,
         forwarded_allow_ips="*",
-        workers=UVICORN_WORKERS
+        workers=UVICORN_WORKERS,
     )
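
The hunk above is formatting-only; for orientation, a minimal sketch of the same pattern (worker count read from the environment and passed to uvicorn.run) could look like the following. The app import string, host, and port are illustrative placeholders, not values taken from this commit.

# sketch: run an ASGI app with a worker count taken from the environment
import os

import uvicorn

workers = int(os.environ.get("UVICORN_WORKERS", "1"))  # same env var name as above

if __name__ == "__main__":
    uvicorn.run(
        "myapp.main:app",  # placeholder import string; this form is needed when workers > 1
        host="0.0.0.0",
        port=8080,
        forwarded_allow_ips="*",
        workers=workers,
    )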

View File

@@ -495,4 +495,4 @@ PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()
 # PROGRESSIVE WEB APP OPTIONS
 ####################################

-EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")
+EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")

View File

@@ -297,7 +297,9 @@ def query_collection_with_hybrid_search(
     collection_results = {}
     for collection_name in collection_names:
         try:
-            log.debug(f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}")
+            log.debug(
+                f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
+            )
             collection_results[collection_name] = VECTOR_DB_CLIENT.get(
                 collection_name=collection_name
             )
@@ -619,7 +621,9 @@ def generate_openai_batch_embeddings(
     user: UserModel = None,
 ) -> Optional[list[list[float]]]:
     try:
-        log.debug(f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}")
+        log.debug(
+            f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
+        )
         json_data = {"input": texts, "model": model}
         if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
             json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
@@ -662,7 +666,9 @@ def generate_ollama_batch_embeddings(
     user: UserModel = None,
 ) -> Optional[list[list[float]]]:
     try:
-        log.debug(f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}")
+        log.debug(
+            f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
+        )
         json_data = {"input": texts, "model": model}
         if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
             json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
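
The two embedding hunks above only re-wrap log.debug calls; the surrounding pattern (batch a list of texts into one embeddings request and optionally attach a prefix field) can be sketched as below. The endpoint URL, API key, and prefix field name are placeholders, not values from the repository.

# sketch: OpenAI-compatible batch embeddings call with an optional prefix field
from typing import Optional

import requests


def batch_embeddings(
    texts: list[str],
    model: str,
    url: str = "https://api.openai.com/v1/embeddings",  # placeholder endpoint
    api_key: str = "sk-placeholder",
    prefix: Optional[str] = None,
    prefix_field_name: Optional[str] = None,
) -> list[list[float]]:
    json_data = {"input": texts, "model": model}
    # mirror the guard in the diff: only attach the prefix when both values are strings
    if isinstance(prefix_field_name, str) and isinstance(prefix, str):
        json_data[prefix_field_name] = prefix
    r = requests.post(
        url,
        headers={"Authorization": f"Bearer {api_key}"},
        json=json_data,
        timeout=30,
    )
    r.raise_for_status()
    return [item["embedding"] for item in r.json()["data"]]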

View File

@@ -624,10 +624,7 @@ def transcribe(request: Request, file_path):
     elif request.app.state.config.STT_ENGINE == "azure":
         # Check file exists and size
         if not os.path.exists(file_path):
-            raise HTTPException(
-                status_code=400,
-                detail="Audio file not found"
-            )
+            raise HTTPException(status_code=400, detail="Audio file not found")

         # Check file size (Azure has a larger limit of 200MB)
         file_size = os.path.getsize(file_path)
@@ -643,11 +640,22 @@ def transcribe(request: Request, file_path):

         # IF NO LOCALES, USE DEFAULTS
         if len(locales) < 2:
-            locales = ['en-US', 'es-ES', 'es-MX', 'fr-FR', 'hi-IN',
-                       'it-IT','de-DE', 'en-GB', 'en-IN', 'ja-JP',
-                       'ko-KR', 'pt-BR', 'zh-CN']
-            locales = ','.join(locales)
+            locales = [
+                "en-US",
+                "es-ES",
+                "es-MX",
+                "fr-FR",
+                "hi-IN",
+                "it-IT",
+                "de-DE",
+                "en-GB",
+                "en-IN",
+                "ja-JP",
+                "ko-KR",
+                "pt-BR",
+                "zh-CN",
+            ]
+            locales = ",".join(locales)

         if not api_key or not region:
             raise HTTPException(
@@ -658,22 +666,26 @@ def transcribe(request: Request, file_path):
r = None
try:
# Prepare the request
data = {'definition': json.dumps({
"locales": locales.split(','),
"diarization": {"maxSpeakers": 3,"enabled": True}
} if locales else {}
)
data = {
"definition": json.dumps(
{
"locales": locales.split(","),
"diarization": {"maxSpeakers": 3, "enabled": True},
}
if locales
else {}
)
}
url = f"https://{region}.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
# Use context manager to ensure file is properly closed
with open(file_path, 'rb') as audio_file:
with open(file_path, "rb") as audio_file:
r = requests.post(
url=url,
files={'audio': audio_file},
files={"audio": audio_file},
data=data,
headers={
'Ocp-Apim-Subscription-Key': api_key,
"Ocp-Apim-Subscription-Key": api_key,
},
)
@@ -681,11 +693,11 @@ def transcribe(request: Request, file_path):

             response = r.json()

             # Extract transcript from response
-            if not response.get('combinedPhrases'):
+            if not response.get("combinedPhrases"):
                 raise ValueError("No transcription found in response")

             # Get the full transcript from combinedPhrases
-            transcript = response['combinedPhrases'][0].get('text', '').strip()
+            transcript = response["combinedPhrases"][0].get("text", "").strip()

             if not transcript:
                 raise ValueError("Empty transcript in response")
@@ -718,7 +730,7 @@ def transcribe(request: Request, file_path):
                 detail = f"External: {e}"

             raise HTTPException(
-                status_code=getattr(r, 'status_code', 500) if r else 500,
+                status_code=getattr(r, "status_code", 500) if r else 500,
                 detail=detail if detail else "Open WebUI: Server Connection Error",
             )
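
Taken together, the audio hunks above reformat an Azure "fast transcription" call; a self-contained sketch of that request/response flow is shown below. The region, API key, and audio path are placeholders, while the URL, payload shape, and combinedPhrases field come from the diff itself.

# sketch: Azure Speech fast transcription request and transcript extraction
import json

import requests

region = "eastus"                 # placeholder
api_key = "<azure-speech-key>"    # placeholder
file_path = "sample.wav"          # placeholder
locales = "en-US,de-DE"

url = f"https://{region}.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
data = {
    "definition": json.dumps(
        {
            "locales": locales.split(","),
            "diarization": {"maxSpeakers": 3, "enabled": True},
        }
        if locales
        else {}
    )
}

with open(file_path, "rb") as audio_file:
    r = requests.post(
        url=url,
        files={"audio": audio_file},
        data=data,
        headers={"Ocp-Apim-Subscription-Key": api_key},
    )
r.raise_for_status()
response = r.json()
if not response.get("combinedPhrases"):
    raise ValueError("No transcription found in response")
print(response["combinedPhrases"][0].get("text", "").strip())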

View File

@@ -159,7 +159,6 @@ async def create_new_knowledge(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.FILE_EXISTS,
         )
-

 ############################
@@ -168,20 +167,17 @@ async def create_new_knowledge(


 @router.post("/reindex", response_model=bool)
-async def reindex_knowledge_files(
-    request: Request,
-    user=Depends(get_verified_user)
-):
+async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)):
     if user.role != "admin":
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.UNAUTHORIZED,
         )

     knowledge_bases = Knowledges.get_knowledge_bases()
     log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")

     for knowledge_base in knowledge_bases:
         try:
             files = Files.get_files_by_ids(knowledge_base.data.get("file_ids", []))
@@ -195,34 +191,40 @@ async def reindex_knowledge_files(
                 log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")
                 raise HTTPException(
                     status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                    detail=f"Error deleting vector DB collection"
+                    detail=f"Error deleting vector DB collection",
                 )

             failed_files = []
             for file in files:
                 try:
                     process_file(
                         request,
-                        ProcessFileForm(file_id=file.id, collection_name=knowledge_base.id),
+                        ProcessFileForm(
+                            file_id=file.id, collection_name=knowledge_base.id
+                        ),
                         user=user,
                     )
                 except Exception as e:
-                    log.error(f"Error processing file {file.filename} (ID: {file.id}): {str(e)}")
+                    log.error(
+                        f"Error processing file {file.filename} (ID: {file.id}): {str(e)}"
+                    )
                     failed_files.append({"file_id": file.id, "error": str(e)})
                     continue

         except Exception as e:
             log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")
             raise HTTPException(
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                detail=f"Error processing knowledge base"
+                detail=f"Error processing knowledge base",
             )

     if failed_files:
-        log.warning(f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}")
+        log.warning(
+            f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}"
+        )
         for failed in failed_files:
             log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")

     log.info("Reindexing completed successfully")
     return True
@@ -742,6 +744,3 @@ def add_files_to_knowledge_batch(
     return KnowledgeFilesResponse(
         **knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
     )
-
-
-
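
The reindex hunks above keep their per-file error handling; the underlying pattern (log each failure, record it, continue, then summarize at the end) can be sketched independently of Open WebUI. process_item and the item list below are hypothetical stand-ins, not repository code.

# sketch: collect per-item failures instead of aborting the whole loop
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)


def process_item(item: str) -> None:
    if item == "bad":
        raise ValueError("cannot process")


failed = []
for item in ["a", "bad", "b"]:
    try:
        process_item(item)
    except Exception as e:
        log.error(f"Error processing {item}: {e}")
        failed.append({"item": item, "error": str(e)})
        continue

if failed:
    log.warning(f"Failed to process {len(failed)} items")
    for entry in failed:
        log.warning(f"Item: {entry['item']}, Error: {entry['error']}")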

View File

@@ -37,7 +37,11 @@ log.setLevel(SRC_LOG_LEVELS["SOCKET"])

 if WEBSOCKET_MANAGER == "redis":
     if WEBSOCKET_SENTINEL_HOSTS:
-        mgr = socketio.AsyncRedisManager(get_sentinel_url_from_env(WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT))
+        mgr = socketio.AsyncRedisManager(
+            get_sentinel_url_from_env(
+                WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
+            )
+        )
     else:
         mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
     sio = socketio.AsyncServer(

View File

@@ -52,5 +52,7 @@ def get_sentinel_url_from_env(redis_url, sentinel_hosts_env, sentinel_port_env):
     auth_part = ""
     if username or password:
         auth_part = f"{username}:{password}@"
-    hosts_part = ",".join(f"{host}:{sentinel_port_env}" for host in sentinel_hosts_env.split(","))
+    hosts_part = ",".join(
+        f"{host}:{sentinel_port_env}" for host in sentinel_hosts_env.split(",")
+    )
     return f"redis+sentinel://{auth_part}{hosts_part}/{redis_config['db']}/{redis_config['service']}"