Add batching

This commit is contained in:
Gabriel Ecegi
2024-12-13 15:29:43 +01:00
parent bfdbb2df69
commit f2e2b59c18
2 changed files with 182 additions and 14 deletions

View File

@@ -2,16 +2,14 @@
import json
import logging
import mimetypes
import os
import shutil
import uuid
from datetime import datetime
from pathlib import Path
from typing import Iterator, Optional, Sequence, Union
from typing import List, Optional
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import tiktoken
@@ -52,7 +50,7 @@ from open_webui.apps.retrieval.utils import (
query_doc_with_hybrid_search,
)
from open_webui.apps.webui.models.files import Files
from open_webui.apps.webui.models.files import FileModel, Files
from open_webui.config import (
BRAVE_SEARCH_API_KEY,
KAGI_SEARCH_API_KEY,
@@ -64,7 +62,6 @@ from open_webui.config import (
CONTENT_EXTRACTION_ENGINE,
CORS_ALLOW_ORIGIN,
ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_LOCAL_WEB_FETCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
ENABLE_RAG_WEB_SEARCH,
ENV,
@@ -86,7 +83,6 @@ from open_webui.config import (
RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
DEFAULT_RAG_TEMPLATE,
RAG_TEMPLATE,
RAG_TOP_K,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
@@ -118,10 +114,7 @@ from open_webui.env import (
DOCKER,
)
from open_webui.utils.misc import (
calculate_sha256,
calculate_sha256_string,
extract_folders_after_data_docs,
sanitize_filename,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
@@ -1047,6 +1040,106 @@ def process_file(
)
class BatchProcessFilesForm(BaseModel):
files: List[FileModel]
collection_name: str
class BatchProcessFilesResult(BaseModel):
file_id: str
status: str
error: Optional[str] = None
class BatchProcessFilesResponse(BaseModel):
results: List[BatchProcessFilesResult]
errors: List[BatchProcessFilesResult]
@app.post("/process/files/batch")
def process_files_batch(
form_data: BatchProcessFilesForm,
user=Depends(get_verified_user),
) -> BatchProcessFilesResponse:
"""
Process a batch of files and save them to the vector database.
"""
results: List[BatchProcessFilesResult] = []
errors: List[BatchProcessFilesResult] = []
collection_name = form_data.collection_name
# Prepare all documents first
all_docs: List[Document] = []
for file_request in form_data.files:
try:
file = Files.get_file_by_id(file_request.file_id)
if not file:
log.error(f"process_files_batch: File {file_request.file_id} not found")
raise ValueError(f"File {file_request.file_id} not found")
text_content = file_request.content
docs: List[Document] = [
Document(
page_content=text_content.replace("<br/>", "\n"),
metadata={
**file.meta,
"name": file_request.filename,
"created_by": file.user_id,
"file_id": file.id,
"source": file_request.filename,
},
)
]
hash = calculate_sha256_string(text_content)
Files.update_file_hash_by_id(file.id, hash)
Files.update_file_data_by_id(file.id, {"content": text_content})
all_docs.extend(docs)
results.append(BatchProcessFilesResult(
file_id=file.id,
status="prepared"
))
except Exception as e:
log.error(f"process_files_batch: Error processing file {file_request.file_id}: {str(e)}")
errors.append(BatchProcessFilesResult(
file_id=file_request.file_id,
status="failed",
error=str(e)
))
# Save all documents in one batch
if all_docs:
try:
save_docs_to_vector_db(
docs=all_docs,
collection_name=collection_name,
add=True
)
# Update all files with collection name
for result in results:
Files.update_file_metadata_by_id(
result.file_id,
{"collection_name": collection_name}
)
result.status = "completed"
except Exception as e:
log.error(f"process_files_batch: Error saving documents to vector DB: {str(e)}")
for result in results:
result.status = "failed"
errors.append(BatchProcessFilesResult(
file_id=result.file_id,
error=str(e)
))
return BatchProcessFilesResponse(
results=results,
errors=errors
)
class ProcessTextForm(BaseModel):
name: str
content: str
@@ -1509,3 +1602,4 @@ if ENV == "dev":
@app.get("/ef/{text}")
async def get_embeddings_text(text: str):
return {"result": app.state.EMBEDDING_FUNCTION(text)}