mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
Merge pull request #5861 from open-webui/projects
feat: knowledge/projects
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
# TODO: Merge this with the webui_app and make it a single app
|
||||
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
@@ -634,9 +636,23 @@ def save_docs_to_vector_db(
|
||||
metadata: Optional[dict] = None,
|
||||
overwrite: bool = False,
|
||||
split: bool = True,
|
||||
add: bool = False,
|
||||
) -> bool:
|
||||
log.info(f"save_docs_to_vector_db {docs} {collection_name}")
|
||||
|
||||
# Check if entries with the same hash (metadata.hash) already exist
|
||||
if metadata and "hash" in metadata:
|
||||
result = VECTOR_DB_CLIENT.query(
|
||||
collection_name=collection_name,
|
||||
filter={"hash": metadata["hash"]},
|
||||
)
|
||||
|
||||
if result:
|
||||
existing_doc_ids = result.ids[0]
|
||||
if existing_doc_ids:
|
||||
log.info(f"Document with hash {metadata['hash']} already exists")
|
||||
raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
|
||||
|
||||
if split:
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=app.state.config.CHUNK_SIZE,
|
||||
@@ -659,42 +675,46 @@ def save_docs_to_vector_db(
|
||||
metadata[key] = str(value)
|
||||
|
||||
try:
|
||||
if overwrite:
|
||||
if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
|
||||
log.info(f"deleting existing collection {collection_name}")
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
|
||||
|
||||
if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
|
||||
log.info(f"collection {collection_name} already exists")
|
||||
return True
|
||||
else:
|
||||
embedding_function = get_embedding_function(
|
||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
app.state.config.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
)
|
||||
|
||||
embeddings = embedding_function(
|
||||
list(map(lambda x: x.replace("\n", " "), texts))
|
||||
)
|
||||
if overwrite:
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
|
||||
log.info(f"deleting existing collection {collection_name}")
|
||||
|
||||
VECTOR_DB_CLIENT.insert(
|
||||
collection_name=collection_name,
|
||||
items=[
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"text": text,
|
||||
"vector": embeddings[idx],
|
||||
"metadata": metadatas[idx],
|
||||
}
|
||||
for idx, text in enumerate(texts)
|
||||
],
|
||||
)
|
||||
if add is False:
|
||||
return True
|
||||
|
||||
return True
|
||||
log.info(f"adding to collection {collection_name}")
|
||||
embedding_function = get_embedding_function(
|
||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
app.state.config.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.config.OPENAI_API_KEY,
|
||||
app.state.config.OPENAI_API_BASE_URL,
|
||||
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
)
|
||||
|
||||
embeddings = embedding_function(
|
||||
list(map(lambda x: x.replace("\n", " "), texts))
|
||||
)
|
||||
|
||||
items = [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"text": text,
|
||||
"vector": embeddings[idx],
|
||||
"metadata": metadatas[idx],
|
||||
}
|
||||
for idx, text in enumerate(texts)
|
||||
]
|
||||
|
||||
VECTOR_DB_CLIENT.insert(
|
||||
collection_name=collection_name,
|
||||
items=items,
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return False
|
||||
@@ -702,6 +722,7 @@ def save_docs_to_vector_db(
|
||||
|
||||
class ProcessFileForm(BaseModel):
|
||||
file_id: str
|
||||
content: Optional[str] = None
|
||||
collection_name: Optional[str] = None
|
||||
|
||||
|
||||
@@ -712,42 +733,91 @@ def process_file(
|
||||
):
|
||||
try:
|
||||
file = Files.get_file_by_id(form_data.file_id)
|
||||
file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")
|
||||
|
||||
collection_name = form_data.collection_name
|
||||
if collection_name is None:
|
||||
with open(file_path, "rb") as f:
|
||||
collection_name = calculate_sha256(f)[:63]
|
||||
collection_name = f"file-{file.id}"
|
||||
|
||||
loader = Loader(
|
||||
engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
|
||||
PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
|
||||
)
|
||||
docs = loader.load(file.filename, file.meta.get("content_type"), file_path)
|
||||
text_content = " ".join([doc.page_content for doc in docs])
|
||||
log.debug(f"text_content: {text_content}")
|
||||
|
||||
Files.update_files_metadata_by_id(
|
||||
form_data.file_id,
|
||||
{
|
||||
"content": {
|
||||
"text": text_content,
|
||||
}
|
||||
},
|
||||
if form_data.content:
|
||||
docs = [
|
||||
Document(
|
||||
page_content=form_data.content,
|
||||
metadata={
|
||||
"name": file.meta.get("name", file.filename),
|
||||
"created_by": file.user_id,
|
||||
**file.meta,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
text_content = form_data.content
|
||||
elif file.data.get("content", None):
|
||||
docs = [
|
||||
Document(
|
||||
page_content=file.data.get("content", ""),
|
||||
metadata={
|
||||
"name": file.meta.get("name", file.filename),
|
||||
"created_by": file.user_id,
|
||||
**file.meta,
|
||||
},
|
||||
)
|
||||
]
|
||||
text_content = file.data.get("content", "")
|
||||
else:
|
||||
file_path = file.meta.get("path", None)
|
||||
if file_path:
|
||||
docs = loader.load(
|
||||
file.filename, file.meta.get("content_type"), file_path
|
||||
)
|
||||
else:
|
||||
docs = [
|
||||
Document(
|
||||
page_content=file.data.get("content", ""),
|
||||
metadata={
|
||||
"name": file.filename,
|
||||
"created_by": file.user_id,
|
||||
**file.meta,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
text_content = " ".join([doc.page_content for doc in docs])
|
||||
|
||||
log.debug(f"text_content: {text_content}")
|
||||
Files.update_file_data_by_id(
|
||||
file.id,
|
||||
{"content": text_content},
|
||||
)
|
||||
|
||||
hash = calculate_sha256_string(text_content)
|
||||
Files.update_file_hash_by_id(file.id, hash)
|
||||
|
||||
try:
|
||||
result = save_docs_to_vector_db(
|
||||
docs,
|
||||
collection_name,
|
||||
{
|
||||
"file_id": form_data.file_id,
|
||||
docs=docs,
|
||||
collection_name=collection_name,
|
||||
metadata={
|
||||
"file_id": file.id,
|
||||
"name": file.meta.get("name", file.filename),
|
||||
"hash": hash,
|
||||
},
|
||||
add=(True if form_data.collection_name else False),
|
||||
)
|
||||
|
||||
if result:
|
||||
Files.update_file_metadata_by_id(
|
||||
file.id,
|
||||
{
|
||||
"collection_name": collection_name,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
@@ -755,10 +825,7 @@ def process_file(
|
||||
"content": text_content,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=e,
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
if "No pandoc was found" in str(e):
|
||||
@@ -769,7 +836,7 @@ def process_file(
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
|
||||
@@ -1183,6 +1250,30 @@ def query_collection_handler(
|
||||
####################################
|
||||
|
||||
|
||||
class DeleteForm(BaseModel):
|
||||
collection_name: str
|
||||
file_id: str
|
||||
|
||||
|
||||
@app.post("/delete")
|
||||
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
|
||||
try:
|
||||
if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
|
||||
file = Files.get_file_by_id(form_data.file_id)
|
||||
hash = file.hash
|
||||
|
||||
VECTOR_DB_CLIENT.delete(
|
||||
collection_name=form_data.collection_name,
|
||||
metadata={"hash": hash},
|
||||
)
|
||||
return {"status": True}
|
||||
else:
|
||||
return {"status": False}
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return {"status": False}
|
||||
|
||||
|
||||
@app.post("/reset/db")
|
||||
def reset_vector_db(user=Depends(get_admin_user)):
|
||||
VECTOR_DB_CLIENT.reset()
|
||||
|
||||
@@ -319,17 +319,25 @@ def get_rag_context(
|
||||
for file in files:
|
||||
if file.get("context") == "full":
|
||||
context = {
|
||||
"documents": [[file.get("file").get("content")]],
|
||||
"documents": [[file.get("file").get("data", {}).get("content")]],
|
||||
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
|
||||
}
|
||||
else:
|
||||
context = None
|
||||
|
||||
collection_names = (
|
||||
file["collection_names"]
|
||||
if file["type"] == "collection"
|
||||
else [file["collection_name"]] if file["collection_name"] else []
|
||||
)
|
||||
collection_names = []
|
||||
if file.get("type") == "collection":
|
||||
if file.get("legacy"):
|
||||
collection_names = file.get("collection_names", [])
|
||||
else:
|
||||
collection_names.append(file["id"])
|
||||
elif file.get("collection_name"):
|
||||
collection_names.append(file["collection_name"])
|
||||
elif file.get("id"):
|
||||
if file.get("legacy"):
|
||||
collection_names.append(f"{file['id']}")
|
||||
else:
|
||||
collection_names.append(f"file-{file['id']}")
|
||||
|
||||
collection_names = set(collection_names).difference(extracted_collections)
|
||||
if not collection_names:
|
||||
|
||||
@@ -49,22 +49,52 @@ class ChromaClient:
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
result = collection.query(
|
||||
query_embeddings=vectors,
|
||||
n_results=limit,
|
||||
)
|
||||
try:
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
result = collection.query(
|
||||
query_embeddings=vectors,
|
||||
n_results=limit,
|
||||
)
|
||||
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": result["ids"],
|
||||
"distances": result["distances"],
|
||||
"documents": result["documents"],
|
||||
"metadatas": result["metadatas"],
|
||||
}
|
||||
)
|
||||
return None
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": result["ids"],
|
||||
"distances": result["distances"],
|
||||
"documents": result["documents"],
|
||||
"metadatas": result["metadatas"],
|
||||
}
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: int = 2
|
||||
) -> Optional[GetResult]:
|
||||
# Query the items from the collection based on the filter.
|
||||
|
||||
try:
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
result = collection.get(
|
||||
where=filter,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(result)
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [result["ids"]],
|
||||
"documents": [result["documents"]],
|
||||
"metadatas": [result["metadatas"]],
|
||||
}
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
@@ -111,11 +141,19 @@ class ChromaClient:
|
||||
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
|
||||
)
|
||||
|
||||
def delete(self, collection_name: str, ids: list[str]):
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids.
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
collection.delete(ids=ids)
|
||||
if ids:
|
||||
collection.delete(ids=ids)
|
||||
elif filter:
|
||||
collection.delete(where=filter)
|
||||
|
||||
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries.
|
||||
|
||||
@@ -135,6 +135,25 @@ class MilvusClient:
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: int = 1
|
||||
) -> Optional[GetResult]:
|
||||
# Query the items from the collection based on the filter.
|
||||
filter_string = " && ".join(
|
||||
[
|
||||
f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')"
|
||||
for key, value in filter.items()
|
||||
]
|
||||
)
|
||||
|
||||
result = self.client.query(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
filter=filter_string,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return self._result_to_get_result([result])
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
result = self.client.query(
|
||||
@@ -187,13 +206,32 @@ class MilvusClient:
|
||||
],
|
||||
)
|
||||
|
||||
def delete(self, collection_name: str, ids: list[str]):
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids.
|
||||
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
ids=ids,
|
||||
)
|
||||
if ids:
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
ids=ids,
|
||||
)
|
||||
elif filter:
|
||||
# Convert the filter dictionary to a string using JSON_CONTAINS.
|
||||
filter_string = " && ".join(
|
||||
[
|
||||
f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')"
|
||||
for key, value in filter.items()
|
||||
]
|
||||
)
|
||||
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
filter=filter_string,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries.
|
||||
|
||||
@@ -10,11 +10,11 @@ from open_webui.apps.webui.routers import (
|
||||
auths,
|
||||
chats,
|
||||
configs,
|
||||
documents,
|
||||
files,
|
||||
functions,
|
||||
memories,
|
||||
models,
|
||||
knowledge,
|
||||
prompts,
|
||||
tools,
|
||||
users,
|
||||
@@ -111,15 +111,15 @@ app.include_router(auths.router, prefix="/auths", tags=["auths"])
|
||||
app.include_router(users.router, prefix="/users", tags=["users"])
|
||||
app.include_router(chats.router, prefix="/chats", tags=["chats"])
|
||||
|
||||
app.include_router(documents.router, prefix="/documents", tags=["documents"])
|
||||
app.include_router(models.router, prefix="/models", tags=["models"])
|
||||
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
|
||||
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
|
||||
|
||||
app.include_router(memories.router, prefix="/memories", tags=["memories"])
|
||||
app.include_router(files.router, prefix="/files", tags=["files"])
|
||||
app.include_router(tools.router, prefix="/tools", tags=["tools"])
|
||||
app.include_router(functions.router, prefix="/functions", tags=["functions"])
|
||||
|
||||
app.include_router(memories.router, prefix="/memories", tags=["memories"])
|
||||
app.include_router(utils.router, prefix="/utils", tags=["utils"])
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
@@ -20,19 +20,29 @@ class File(Base):
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String)
|
||||
hash = Column(Text, nullable=True)
|
||||
|
||||
filename = Column(Text)
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSONField)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
|
||||
class FileModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
filename: str
|
||||
meta: dict
|
||||
created_at: int # timestamp in epoch
|
||||
hash: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
filename: str
|
||||
data: Optional[dict] = None
|
||||
meta: dict
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
@@ -43,14 +53,21 @@ class FileModel(BaseModel):
|
||||
class FileModelResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
hash: Optional[str] = None
|
||||
|
||||
filename: str
|
||||
data: Optional[dict] = None
|
||||
meta: dict
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class FileForm(BaseModel):
|
||||
id: str
|
||||
hash: Optional[str] = None
|
||||
filename: str
|
||||
data: dict = {}
|
||||
meta: dict = {}
|
||||
|
||||
|
||||
@@ -62,6 +79,7 @@ class FilesTable:
|
||||
**form_data.model_dump(),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -90,6 +108,13 @@ class FilesTable:
|
||||
with get_db() as db:
|
||||
return [FileModel.model_validate(file) for file in db.query(File).all()]
|
||||
|
||||
def get_files_by_ids(self, ids: list[str]) -> list[FileModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FileModel.model_validate(file)
|
||||
for file in db.query(File).filter(File.id.in_(ids)).all()
|
||||
]
|
||||
|
||||
def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
@@ -97,17 +122,38 @@ class FilesTable:
|
||||
for file in db.query(File).filter_by(user_id=user_id).all()
|
||||
]
|
||||
|
||||
def update_files_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]:
|
||||
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
file.meta = {**file.meta, **meta}
|
||||
file.hash = hash
|
||||
db.commit()
|
||||
|
||||
return FileModel.model_validate(file)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
file.data = {**(file.data if file.data else {}), **data}
|
||||
db.commit()
|
||||
return FileModel.model_validate(file)
|
||||
except Exception as e:
|
||||
|
||||
return None
|
||||
|
||||
def update_file_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
file.meta = {**(file.meta if file.meta else {}), **meta}
|
||||
db.commit()
|
||||
return FileModel.model_validate(file)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_file_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
|
||||
152
backend/open_webui/apps/webui/models/knowledge.py
Normal file
152
backend/open_webui/apps/webui/models/knowledge.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from open_webui.apps.webui.internal.db import Base, get_db
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# Knowledge DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Knowledge(Base):
|
||||
__tablename__ = "knowledge"
|
||||
|
||||
id = Column(Text, unique=True, primary_key=True)
|
||||
user_id = Column(Text)
|
||||
|
||||
name = Column(Text)
|
||||
description = Column(Text)
|
||||
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
|
||||
class KnowledgeModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
|
||||
name: str
|
||||
description: str
|
||||
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class KnowledgeResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class KnowledgeForm(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
data: Optional[dict] = None
|
||||
|
||||
|
||||
class KnowledgeUpdateForm(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
data: Optional[dict] = None
|
||||
|
||||
|
||||
class KnowledgeTable:
|
||||
def insert_new_knowledge(
|
||||
self, user_id: str, form_data: KnowledgeForm
|
||||
) -> Optional[KnowledgeModel]:
|
||||
with get_db() as db:
|
||||
knowledge = KnowledgeModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = Knowledge(**knowledge.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return KnowledgeModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_knowledge_items(self) -> list[KnowledgeModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
KnowledgeModel.model_validate(knowledge)
|
||||
for knowledge in db.query(Knowledge)
|
||||
.order_by(Knowledge.updated_at.desc())
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
knowledge = db.query(Knowledge).filter_by(id=id).first()
|
||||
return KnowledgeModel.model_validate(knowledge) if knowledge else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_knowledge_by_id(
|
||||
self, id: str, form_data: KnowledgeUpdateForm, overwrite: bool = False
|
||||
) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
knowledge = self.get_knowledge_by_id(id=id)
|
||||
db.query(Knowledge).filter_by(id=id).update(
|
||||
{
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_knowledge_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def delete_knowledge_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Knowledge).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Knowledges = KnowledgeTable()
|
||||
@@ -4,13 +4,18 @@ import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from open_webui.apps.webui.models.files import FileForm, FileModel, Files
|
||||
from open_webui.apps.retrieval.main import process_file, ProcessFileForm
|
||||
|
||||
from open_webui.config import UPLOAD_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -58,6 +63,13 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
process_file(ProcessFileForm(file_id=id))
|
||||
file = Files.get_file_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error processing file: {file.id}")
|
||||
|
||||
if file:
|
||||
return file
|
||||
else:
|
||||
@@ -143,6 +155,55 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Get File Data Content By Id
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}/data/content")
|
||||
async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
return {"content": file.data.get("content", "")}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Update File Data Content By Id
|
||||
############################
|
||||
|
||||
|
||||
class ContentForm(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
@router.post("/{id}/data/content/update")
|
||||
async def update_file_data_content_by_id(
|
||||
id: str, form_data: ContentForm, user=Depends(get_verified_user)
|
||||
):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
try:
|
||||
process_file(ProcessFileForm(file_id=id, content=form_data.content))
|
||||
file = Files.get_file_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error processing file: {file.id}")
|
||||
|
||||
return {"content": file.data.get("content", "")}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Get File Content By Id
|
||||
############################
|
||||
@@ -171,34 +232,37 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{id}/content/text")
|
||||
async def get_file_text_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
return {"text": file.meta.get("content", {}).get("text", None)}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
|
||||
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
file_path = Path(file.meta["path"])
|
||||
file_path = file.meta.get("path")
|
||||
if file_path:
|
||||
file_path = Path(file_path)
|
||||
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
print(f"file_path: {file_path}")
|
||||
return FileResponse(file_path)
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
print(f"file_path: {file_path}")
|
||||
return FileResponse(file_path)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
# File path doesn’t exist, return the content as .txt if possible
|
||||
file_content = file.content.get("content", "")
|
||||
file_name = file.filename
|
||||
|
||||
# Create a generator that encodes the file content
|
||||
def generator():
|
||||
yield file_content.encode("utf-8")
|
||||
|
||||
return StreamingResponse(
|
||||
generator(),
|
||||
media_type="text/plain",
|
||||
headers={"Content-Disposition": f"attachment; filename={file_name}"},
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
||||
320
backend/open_webui/apps/webui/routers/knowledge.py
Normal file
320
backend/open_webui/apps/webui/routers/knowledge.py
Normal file
@@ -0,0 +1,320 @@
|
||||
import json
|
||||
from typing import Optional, Union
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
|
||||
from open_webui.apps.webui.models.knowledge import (
|
||||
Knowledges,
|
||||
KnowledgeUpdateForm,
|
||||
KnowledgeForm,
|
||||
KnowledgeResponse,
|
||||
)
|
||||
from open_webui.apps.webui.models.files import Files, FileModel
|
||||
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.apps.retrieval.main import process_file, ProcessFileForm
|
||||
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
# GetKnowledgeItems
|
||||
############################
|
||||
|
||||
|
||||
@router.get(
|
||||
"/", response_model=Optional[Union[list[KnowledgeResponse], KnowledgeResponse]]
|
||||
)
|
||||
async def get_knowledge_items(
|
||||
id: Optional[str] = None, user=Depends(get_verified_user)
|
||||
):
|
||||
if id:
|
||||
knowledge = Knowledges.get_knowledge_by_id(id=id)
|
||||
|
||||
if knowledge:
|
||||
return knowledge
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
else:
|
||||
return [
|
||||
KnowledgeResponse(**knowledge.model_dump())
|
||||
for knowledge in Knowledges.get_knowledge_items()
|
||||
]
|
||||
|
||||
|
||||
############################
|
||||
# CreateNewKnowledge
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[KnowledgeResponse])
|
||||
async def create_new_knowledge(form_data: KnowledgeForm, user=Depends(get_admin_user)):
|
||||
knowledge = Knowledges.insert_new_knowledge(user.id, form_data)
|
||||
|
||||
if knowledge:
|
||||
return knowledge
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.FILE_EXISTS,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetKnowledgeById
|
||||
############################
|
||||
|
||||
|
||||
class KnowledgeFilesResponse(KnowledgeResponse):
|
||||
files: list[FileModel]
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
    """Fetch a knowledge base together with its attached file records.

    Raises 401 with NOT_FOUND detail for an unknown id (status kept as-is
    for API compatibility).
    """
    knowledge = Knowledges.get_knowledge_by_id(id=id)
    if knowledge is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # knowledge.data may be None for a freshly created knowledge base.
    file_ids = (knowledge.data or {}).get("file_ids", [])
    return KnowledgeFilesResponse(
        **knowledge.model_dump(),
        files=Files.get_files_by_ids(file_ids),
    )
|
||||
|
||||
|
||||
############################
|
||||
# UpdateKnowledgeById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse])
async def update_knowledge_by_id(
    id: str,
    form_data: KnowledgeUpdateForm,
    user=Depends(get_admin_user),
):
    """Update a knowledge base's fields and return it with its file records.

    Raises 400 with ID_TAKEN detail when the update fails.
    """
    knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
    if knowledge is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ID_TAKEN,
        )

    # knowledge.data may be None; default to no attached files.
    file_ids = (knowledge.data or {}).get("file_ids", [])
    return KnowledgeFilesResponse(
        **knowledge.model_dump(),
        files=Files.get_files_by_ids(file_ids),
    )
|
||||
|
||||
|
||||
############################
|
||||
# AddFileToKnowledge
|
||||
############################
|
||||
|
||||
|
||||
class KnowledgeFileIdForm(BaseModel):
    # Request body carrying the id of the file to attach/update/remove.
    file_id: str
|
||||
|
||||
|
||||
@router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse])
def add_file_to_knowledge_by_id(
    id: str,
    form_data: KnowledgeFileIdForm,
    user=Depends(get_admin_user),
):
    """Attach an already-uploaded file to a knowledge base and index its content.

    Raises 400 when the knowledge base or file does not exist, when the file
    has no extracted content yet, when the file is already attached, or when
    vector indexing fails.
    """
    knowledge = Knowledges.get_knowledge_by_id(id=id)
    if not knowledge:
        # BUG FIX: validate the knowledge base BEFORE indexing. The original
        # ran process_file() first, so an unknown id still wrote vectors into
        # a collection named after it, leaving orphaned vector data.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    file = Files.get_file_by_id(form_data.file_id)
    if not file:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )
    if not file.data:
        # File exists but extraction has not produced content yet.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_PROCESSED,
        )

    # Index the file's content into this knowledge base's vector collection.
    try:
        process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id))
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        )

    data = knowledge.data or {}
    file_ids = data.get("file_ids", [])

    if form_data.file_id in file_ids:
        # Already attached — same error the original raised for this case.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("file_id"),
        )

    file_ids.append(form_data.file_id)
    data["file_ids"] = file_ids

    knowledge = Knowledges.update_knowledge_by_id(
        id=id, form_data=KnowledgeUpdateForm(data=data)
    )
    if not knowledge:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("knowledge"),
        )

    return KnowledgeFilesResponse(
        **knowledge.model_dump(),
        files=Files.get_files_by_ids(file_ids),
    )
|
||||
|
||||
|
||||
@router.post("/{id}/file/update", response_model=Optional[KnowledgeFilesResponse])
def update_file_from_knowledge_by_id(
    id: str,
    form_data: KnowledgeFileIdForm,
    user=Depends(get_admin_user),
):
    """Re-index a file's content in the knowledge base's vector collection.

    Deletes the file's existing vectors, re-runs processing, and returns the
    knowledge base with its file records. Raises 400 when the knowledge base
    or file is missing, or when re-indexing fails.
    """
    knowledge = Knowledges.get_knowledge_by_id(id=id)
    if not knowledge:
        # BUG FIX: the original dereferenced knowledge.id (vector delete)
        # before this None check, so an unknown knowledge id crashed with
        # AttributeError (HTTP 500) instead of returning this 400.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    file = Files.get_file_by_id(form_data.file_id)
    if not file:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Drop the file's stale vectors, then re-index its current content.
    VECTOR_DB_CLIENT.delete(
        collection_name=knowledge.id, filter={"file_id": form_data.file_id}
    )

    try:
        process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id))
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        )

    data = knowledge.data or {}
    file_ids = data.get("file_ids", [])

    return KnowledgeFilesResponse(
        **knowledge.model_dump(),
        files=Files.get_files_by_ids(file_ids),
    )
|
||||
|
||||
|
||||
############################
|
||||
# RemoveFileFromKnowledge
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/file/remove", response_model=Optional[KnowledgeFilesResponse])
def remove_file_from_knowledge_by_id(
    id: str,
    form_data: KnowledgeFileIdForm,
    user=Depends(get_admin_user),
):
    """Detach a file from a knowledge base: delete its vectors and file record.

    Raises 400 when the knowledge base or file is missing, when the file is
    not attached, or when the knowledge update fails.
    """
    knowledge = Knowledges.get_knowledge_by_id(id=id)
    if not knowledge:
        # BUG FIX: the original dereferenced knowledge.id (vector delete)
        # before this None check, so an unknown knowledge id crashed with
        # AttributeError (HTTP 500) instead of returning this 400.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    file = Files.get_file_by_id(form_data.file_id)
    if not file:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Remove the file's vectors from this knowledge base's collection.
    # (The original also ran a follow-up VECTOR_DB_CLIENT.query whose result
    # was never used; that dead call is removed.)
    VECTOR_DB_CLIENT.delete(
        collection_name=knowledge.id, filter={"file_id": form_data.file_id}
    )

    # Delete the stored file record itself.
    Files.delete_file_by_id(form_data.file_id)

    data = knowledge.data or {}
    file_ids = data.get("file_ids", [])

    if form_data.file_id not in file_ids:
        # Not attached — same error the original raised for this case.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("file_id"),
        )

    file_ids.remove(form_data.file_id)
    data["file_ids"] = file_ids

    knowledge = Knowledges.update_knowledge_by_id(
        id=id, form_data=KnowledgeUpdateForm(data=data)
    )
    if not knowledge:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("knowledge"),
        )

    return KnowledgeFilesResponse(
        **knowledge.model_dump(),
        files=Files.get_files_by_ids(file_ids),
    )
|
||||
|
||||
|
||||
############################
|
||||
# DeleteKnowledgeById
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{id}/delete", response_model=bool)
async def delete_knowledge_by_id(id: str, user=Depends(get_admin_user)):
    """Delete a knowledge base and its whole vector collection.

    Returns True when the database record was deleted.
    """
    # Drop the vector collection first, then the database row.
    VECTOR_DB_CLIENT.delete_collection(collection_name=id)
    return Knowledges.delete_knowledge_by_id(id=id)
|
||||
@@ -56,9 +56,6 @@ def run_migrations():
|
||||
print(f"Error: {e}")
|
||||
|
||||
|
||||
run_migrations()
|
||||
|
||||
|
||||
class Config(Base):
|
||||
__tablename__ = "config"
|
||||
|
||||
|
||||
@@ -34,8 +34,8 @@ class ERROR_MESSAGES(str, Enum):
|
||||
|
||||
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
|
||||
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
|
||||
|
||||
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
|
||||
|
||||
INVALID_TOKEN = (
|
||||
"Your session has expired or the token is invalid. Please sign in again."
|
||||
)
|
||||
@@ -94,6 +94,11 @@ class ERROR_MESSAGES(str, Enum):
|
||||
lambda size="": f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}."
|
||||
)
|
||||
|
||||
DUPLICATE_CONTENT = (
|
||||
"Duplicate content detected. Please provide unique content to proceed."
|
||||
)
|
||||
FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
|
||||
|
||||
|
||||
class TASKS(str, Enum):
|
||||
def __str__(self) -> str:
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
|
||||
from open_webui.env import OPEN_WEBUI_DIR
|
||||
|
||||
# Alembic configuration loaded from the installed package directory.
alembic_cfg = Config(OPEN_WEBUI_DIR / "alembic.ini")

# Set the script location dynamically so Alembic finds the bundled
# migration scripts regardless of the path recorded in alembic.ini.
migrations_path = OPEN_WEBUI_DIR / "migrations"
alembic_cfg.set_main_option("script_location", str(migrations_path))
|
||||
|
||||
|
||||
def revision(message: str) -> None:
    """Create an empty (non-autogenerated) Alembic revision with *message*."""
    command.revision(alembic_cfg, message=message, autogenerate=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Interactive helper: prompt for a message, then create the revision file.
    input_message = input("Enter the revision message: ")
    revision(input_message)
|
||||
@@ -7,3 +7,9 @@ def get_existing_tables():
|
||||
inspector = Inspector.from_engine(con)
|
||||
tables = set(inspector.get_table_names())
|
||||
return tables
|
||||
|
||||
|
||||
def get_revision_id():
    """Return a random 12-character hex string usable as an Alembic revision id."""
    import uuid

    # uuid4().hex is the UUID's hex digits with no dashes — identical to
    # str(uuid4()).replace("-", "") — truncated to 12 characters.
    return uuid.uuid4().hex[:12]
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Add knowledge table
|
||||
|
||||
Revision ID: 6a39f3d8e55c
|
||||
Revises: c0fbf31ca0db
|
||||
Create Date: 2024-10-01 14:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column, select
|
||||
import json
|
||||
|
||||
|
||||
revision = "6a39f3d8e55c"
|
||||
down_revision = "c0fbf31ca0db"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
    """Create the 'knowledge' table and migrate legacy 'document' rows into it.

    Each old document becomes a knowledge entry keyed by its collection_name,
    flagged in meta as a legacy document.
    """
    # Creating the 'knowledge' table
    print("Creating knowledge table")
    knowledge_table = op.create_table(
        "knowledge",
        sa.Column("id", sa.Text(), primary_key=True),
        sa.Column("user_id", sa.Text(), nullable=False),
        sa.Column("name", sa.Text(), nullable=False),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("data", sa.JSON(), nullable=True),
        sa.Column("meta", sa.JSON(), nullable=True),
        sa.Column("created_at", sa.BigInteger(), nullable=False),
        sa.Column("updated_at", sa.BigInteger(), nullable=True),
    )

    print("Migrating data from document table to knowledge table")
    # Representation of the existing 'document' table (lightweight table()
    # construct — no reflection needed for the SELECT below).
    document_table = table(
        "document",
        column("collection_name", sa.String()),
        column("user_id", sa.String()),
        column("name", sa.String()),
        column("title", sa.Text()),
        column("content", sa.Text()),
        column("timestamp", sa.BigInteger()),
    )

    # Select all from existing document table
    documents = op.get_bind().execute(
        select(
            document_table.c.collection_name,
            document_table.c.user_id,
            document_table.c.name,
            document_table.c.title,
            document_table.c.content,
            document_table.c.timestamp,
        )
    )

    # Insert data into knowledge table from document table.
    # NOTE(review): assumes doc.content holds a JSON object that may carry a
    # "tags" list — confirm against the legacy document model; a non-JSON
    # content value would make json.loads raise here.
    for doc in documents:
        op.get_bind().execute(
            knowledge_table.insert().values(
                id=doc.collection_name,
                user_id=doc.user_id,
                description=doc.name,
                meta={
                    "legacy": True,
                    "document": True,
                    "tags": json.loads(doc.content or "{}").get("tags", []),
                },
                name=doc.title,
                created_at=doc.timestamp,
                updated_at=doc.timestamp,  # using created_at for both created_at and updated_at in project
            )
        )
|
||||
|
||||
|
||||
def downgrade():
    """Drop the 'knowledge' table (migrated document rows are not restored)."""
    op.drop_table("knowledge")
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Update file table
|
||||
|
||||
Revision ID: c0fbf31ca0db
|
||||
Revises: ca81bd47c050
|
||||
Create Date: 2024-09-20 15:26:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c0fbf31ca0db"
|
||||
down_revision: Union[str, None] = "ca81bd47c050"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade():
    """Add hash, data and updated_at columns to the 'file' table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # All three columns are nullable so existing rows need no backfill.
    op.add_column("file", sa.Column("hash", sa.Text(), nullable=True))
    op.add_column("file", sa.Column("data", sa.JSON(), nullable=True))
    op.add_column("file", sa.Column("updated_at", sa.BigInteger(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
    """Remove the columns added in upgrade(), in reverse order."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("file", "updated_at")
    op.drop_column("file", "data")
    op.drop_column("file", "hash")
|
||||
Reference in New Issue
Block a user