mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-15 19:37:47 +01:00
Merge branch 'dev' into dev
This commit is contained in:
@@ -709,8 +709,10 @@ def save_docs_to_vector_db(
|
||||
if overwrite:
|
||||
VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
|
||||
log.info(f"deleting existing collection {collection_name}")
|
||||
|
||||
if add is False:
|
||||
elif add is False:
|
||||
log.info(
|
||||
f"collection {collection_name} already exists, overwrite is False and add is False"
|
||||
)
|
||||
return True
|
||||
|
||||
log.info(f"adding to collection {collection_name}")
|
||||
@@ -823,7 +825,7 @@ def process_file(
|
||||
# Process the file and save the content
|
||||
# Usage: /files/
|
||||
|
||||
file_path = file.meta.get("path", None)
|
||||
file_path = file.path
|
||||
if file_path:
|
||||
loader = Loader(
|
||||
engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
|
||||
@@ -385,6 +385,8 @@ def get_rag_context(
|
||||
extracted_collections.extend(collection_names)
|
||||
|
||||
if context:
|
||||
if "data" in file:
|
||||
del file["data"]
|
||||
relevant_contexts.append({**context, "file": file})
|
||||
|
||||
contexts = []
|
||||
@@ -401,7 +403,6 @@ def get_rag_context(
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
contexts.append(
|
||||
((", ".join(file_names) + ":\n\n") if file_names else "")
|
||||
+ "\n\n".join(
|
||||
@@ -410,13 +411,14 @@ def get_rag_context(
|
||||
)
|
||||
|
||||
if "metadatas" in context:
|
||||
citations.append(
|
||||
{
|
||||
"source": context["file"],
|
||||
"document": context["documents"][0],
|
||||
"metadata": context["metadatas"][0],
|
||||
}
|
||||
)
|
||||
citation = {
|
||||
"source": context["file"],
|
||||
"document": context["documents"][0],
|
||||
"metadata": context["metadatas"][0],
|
||||
}
|
||||
if "distances" in context and context["distances"]:
|
||||
citation["distances"] = context["distances"][0]
|
||||
citations.append(citation)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
|
||||
@@ -109,7 +109,10 @@ class ChromaClient:
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
collection = self.client.get_or_create_collection(name=collection_name)
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
documents = [item["text"] for item in items]
|
||||
@@ -127,7 +130,10 @@ class ChromaClient:
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
collection = self.client.get_or_create_collection(name=collection_name)
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
documents = [item["text"] for item in items]
|
||||
|
||||
@@ -9,6 +9,7 @@ from open_webui.apps.webui.models.models import Models
|
||||
from open_webui.apps.webui.routers import (
|
||||
auths,
|
||||
chats,
|
||||
folders,
|
||||
configs,
|
||||
files,
|
||||
functions,
|
||||
@@ -119,6 +120,7 @@ app.include_router(configs.router, prefix="/configs", tags=["configs"])
|
||||
app.include_router(auths.router, prefix="/auths", tags=["auths"])
|
||||
app.include_router(users.router, prefix="/users", tags=["users"])
|
||||
app.include_router(chats.router, prefix="/chats", tags=["chats"])
|
||||
app.include_router(folders.router, prefix="/folders", tags=["folders"])
|
||||
|
||||
app.include_router(models.router, prefix="/models", tags=["models"])
|
||||
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
|
||||
@@ -344,7 +346,7 @@ async def generate_function_chat_completion(form_data, user):
|
||||
pipe = function_module.pipe
|
||||
params = get_function_params(function_module, form_data, user, extra_params)
|
||||
|
||||
if form_data["stream"]:
|
||||
if form_data.get("stream", False):
|
||||
|
||||
async def stream_content():
|
||||
try:
|
||||
|
||||
@@ -33,6 +33,7 @@ class Chat(Base):
|
||||
pinned = Column(Boolean, default=False, nullable=True)
|
||||
|
||||
meta = Column(JSON, server_default="{}")
|
||||
folder_id = Column(Text, nullable=True)
|
||||
|
||||
|
||||
class ChatModel(BaseModel):
|
||||
@@ -51,6 +52,7 @@ class ChatModel(BaseModel):
|
||||
pinned: Optional[bool] = False
|
||||
|
||||
meta: dict = {}
|
||||
folder_id: Optional[str] = None
|
||||
|
||||
|
||||
####################
|
||||
@@ -62,6 +64,12 @@ class ChatForm(BaseModel):
|
||||
chat: dict
|
||||
|
||||
|
||||
class ChatImportForm(ChatForm):
|
||||
meta: Optional[dict] = {}
|
||||
pinned: Optional[bool] = False
|
||||
folder_id: Optional[str] = None
|
||||
|
||||
|
||||
class ChatTitleMessagesForm(BaseModel):
|
||||
title: str
|
||||
messages: list[dict]
|
||||
@@ -82,6 +90,7 @@ class ChatResponse(BaseModel):
|
||||
archived: bool
|
||||
pinned: Optional[bool] = False
|
||||
meta: dict = {}
|
||||
folder_id: Optional[str] = None
|
||||
|
||||
|
||||
class ChatTitleIdResponse(BaseModel):
|
||||
@@ -116,6 +125,35 @@ class ChatTable:
|
||||
db.refresh(result)
|
||||
return ChatModel.model_validate(result) if result else None
|
||||
|
||||
def import_chat(
|
||||
self, user_id: str, form_data: ChatImportForm
|
||||
) -> Optional[ChatModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
chat = ChatModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"title": (
|
||||
form_data.chat["title"]
|
||||
if "title" in form_data.chat
|
||||
else "New Chat"
|
||||
),
|
||||
"chat": form_data.chat,
|
||||
"meta": form_data.meta,
|
||||
"pinned": form_data.pinned,
|
||||
"folder_id": form_data.folder_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
result = Chat(**chat.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
return ChatModel.model_validate(result) if result else None
|
||||
|
||||
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
@@ -254,7 +292,7 @@ class ChatTable:
|
||||
limit: int = 50,
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
|
||||
if not include_archived:
|
||||
query = query.filter_by(archived=False)
|
||||
|
||||
@@ -276,7 +314,7 @@ class ChatTable:
|
||||
limit: Optional[int] = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
|
||||
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
|
||||
|
||||
if not include_archived:
|
||||
@@ -444,7 +482,18 @@ class ChatTable:
|
||||
)
|
||||
|
||||
# Check if there are any tags to filter, it should have all the tags
|
||||
if tag_ids:
|
||||
if "none" in tag_ids:
|
||||
query = query.filter(
|
||||
text(
|
||||
"""
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM json_each(Chat.meta, '$.tags') AS tag
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
elif tag_ids:
|
||||
query = query.filter(
|
||||
and_(
|
||||
*[
|
||||
@@ -482,7 +531,18 @@ class ChatTable:
|
||||
)
|
||||
|
||||
# Check if there are any tags to filter, it should have all the tags
|
||||
if tag_ids:
|
||||
if "none" in tag_ids:
|
||||
query = query.filter(
|
||||
text(
|
||||
"""
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM json_array_elements_text(Chat.meta->'tags') AS tag
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
elif tag_ids:
|
||||
query = query.filter(
|
||||
and_(
|
||||
*[
|
||||
@@ -512,6 +572,49 @@ class ChatTable:
|
||||
# Validate and return chats
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chats_by_folder_id_and_user_id(
|
||||
self, folder_id: str, user_id: str
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id)
|
||||
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
|
||||
query = query.filter_by(archived=False)
|
||||
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
|
||||
all_chats = query.all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chats_by_folder_ids_and_user_id(
|
||||
self, folder_ids: list[str], user_id: str
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter(
|
||||
Chat.folder_id.in_(folder_ids), Chat.user_id == user_id
|
||||
)
|
||||
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
|
||||
query = query.filter_by(archived=False)
|
||||
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
|
||||
all_chats = query.all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def update_chat_folder_id_by_id_and_user_id(
|
||||
self, id: str, user_id: str, folder_id: str
|
||||
) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.folder_id = folder_id
|
||||
chat.updated_at = int(time.time())
|
||||
chat.pinned = False
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
@@ -673,6 +776,18 @@ class ChatTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_chats_by_user_id_and_folder_id(
|
||||
self, user_id: str, folder_id: str
|
||||
) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
@@ -17,14 +17,15 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
class File(Base):
|
||||
__tablename__ = "file"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String)
|
||||
hash = Column(Text, nullable=True)
|
||||
|
||||
filename = Column(Text)
|
||||
path = Column(Text, nullable=True)
|
||||
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSONField)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
@@ -38,8 +39,10 @@ class FileModel(BaseModel):
|
||||
hash: Optional[str] = None
|
||||
|
||||
filename: str
|
||||
path: Optional[str] = None
|
||||
|
||||
data: Optional[dict] = None
|
||||
meta: dict
|
||||
meta: Optional[dict] = None
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
@@ -82,6 +85,7 @@ class FileForm(BaseModel):
|
||||
id: str
|
||||
hash: Optional[str] = None
|
||||
filename: str
|
||||
path: str
|
||||
data: dict = {}
|
||||
meta: dict = {}
|
||||
|
||||
|
||||
271
backend/open_webui/apps/webui/models/folders.py
Normal file
271
backend/open_webui/apps/webui/models/folders.py
Normal file
@@ -0,0 +1,271 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.apps.webui.internal.db import Base, get_db
|
||||
from open_webui.apps.webui.models.chats import Chats
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
####################
|
||||
# Folder DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Folder(Base):
|
||||
__tablename__ = "folder"
|
||||
id = Column(Text, primary_key=True)
|
||||
parent_id = Column(Text, nullable=True)
|
||||
user_id = Column(Text)
|
||||
name = Column(Text)
|
||||
items = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
is_expanded = Column(Boolean, default=False)
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
|
||||
class FolderModel(BaseModel):
|
||||
id: str
|
||||
parent_id: Optional[str] = None
|
||||
user_id: str
|
||||
name: str
|
||||
items: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
is_expanded: bool = False
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class FolderForm(BaseModel):
|
||||
name: str
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class FolderTable:
|
||||
def insert_new_folder(
|
||||
self, user_id: str, name: str, parent_id: Optional[str] = None
|
||||
) -> Optional[FolderModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
folder = FolderModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"name": name,
|
||||
"parent_id": parent_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
try:
|
||||
result = Folder(**folder.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return FolderModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def get_folder_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_children_folders_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folders = []
|
||||
|
||||
def get_children(folder):
|
||||
children = self.get_folders_by_parent_id_and_user_id(
|
||||
folder.id, user_id
|
||||
)
|
||||
for child in children:
|
||||
get_children(child)
|
||||
folders.append(child)
|
||||
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
get_children(folder)
|
||||
return folders
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FolderModel.model_validate(folder)
|
||||
for folder in db.query(Folder).filter_by(user_id=user_id).all()
|
||||
]
|
||||
|
||||
def get_folder_by_parent_id_and_user_id_and_name(
|
||||
self, parent_id: Optional[str], user_id: str, name: str
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
# Check if folder exists
|
||||
folder = (
|
||||
db.query(Folder)
|
||||
.filter_by(parent_id=parent_id, user_id=user_id)
|
||||
.filter(Folder.name.ilike(name))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}")
|
||||
return None
|
||||
|
||||
def get_folders_by_parent_id_and_user_id(
|
||||
self, parent_id: Optional[str], user_id: str
|
||||
) -> list[FolderModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FolderModel.model_validate(folder)
|
||||
for folder in db.query(Folder)
|
||||
.filter_by(parent_id=parent_id, user_id=user_id)
|
||||
.all()
|
||||
]
|
||||
|
||||
def update_folder_parent_id_by_id_and_user_id(
|
||||
self,
|
||||
id: str,
|
||||
user_id: str,
|
||||
parent_id: str,
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
folder.parent_id = parent_id
|
||||
folder.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"update_folder: {e}")
|
||||
return
|
||||
|
||||
def update_folder_name_by_id_and_user_id(
|
||||
self, id: str, user_id: str, name: str
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
existing_folder = (
|
||||
db.query(Folder)
|
||||
.filter_by(name=name, parent_id=folder.parent_id, user_id=user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_folder:
|
||||
return None
|
||||
|
||||
folder.name = name
|
||||
folder.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"update_folder: {e}")
|
||||
return
|
||||
|
||||
def update_folder_is_expanded_by_id_and_user_id(
|
||||
self, id: str, user_id: str, is_expanded: bool
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
folder.is_expanded = is_expanded
|
||||
folder.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"update_folder: {e}")
|
||||
return
|
||||
|
||||
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
if not folder:
|
||||
return False
|
||||
|
||||
# Delete all chats in the folder
|
||||
Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
|
||||
|
||||
# Delete all children folders
|
||||
def delete_children(folder):
|
||||
folder_children = self.get_folders_by_parent_id_and_user_id(
|
||||
folder.id, user_id
|
||||
)
|
||||
for folder_child in folder_children:
|
||||
Chats.delete_chats_by_user_id_and_folder_id(
|
||||
user_id, folder_child.id
|
||||
)
|
||||
delete_children(folder_child)
|
||||
|
||||
folder = db.query(Folder).filter_by(id=folder_child.id).first()
|
||||
db.delete(folder)
|
||||
db.commit()
|
||||
|
||||
delete_children(folder)
|
||||
db.delete(folder)
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"delete_folder: {e}")
|
||||
return False
|
||||
|
||||
|
||||
Folders = FolderTable()
|
||||
@@ -4,11 +4,13 @@ from typing import Optional
|
||||
|
||||
from open_webui.apps.webui.models.chats import (
|
||||
ChatForm,
|
||||
ChatImportForm,
|
||||
ChatResponse,
|
||||
Chats,
|
||||
ChatTitleIdResponse,
|
||||
)
|
||||
from open_webui.apps.webui.models.tags import TagModel, Tags
|
||||
from open_webui.apps.webui.models.folders import Folders
|
||||
|
||||
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
@@ -99,6 +101,34 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# ImportChat
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/import", response_model=Optional[ChatResponse])
|
||||
async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)):
|
||||
try:
|
||||
chat = Chats.import_chat(user.id, form_data)
|
||||
if chat:
|
||||
tags = chat.meta.get("tags", [])
|
||||
for tag_id in tags:
|
||||
tag_id = tag_id.replace(" ", "_").lower()
|
||||
tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
|
||||
if (
|
||||
tag_id != "none"
|
||||
and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
|
||||
):
|
||||
Tags.insert_new_tag(tag_name, user.id)
|
||||
|
||||
return ChatResponse(**chat.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetChats
|
||||
############################
|
||||
@@ -133,6 +163,26 @@ async def search_user_chats(
|
||||
return chat_list
|
||||
|
||||
|
||||
############################
|
||||
# GetChatsByFolderId
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/folder/{folder_id}", response_model=list[ChatResponse])
|
||||
async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)):
|
||||
folder_ids = [folder_id]
|
||||
children_folders = Folders.get_children_folders_by_id_and_user_id(
|
||||
folder_id, user.id
|
||||
)
|
||||
if children_folders:
|
||||
folder_ids.extend([folder.id for folder in children_folders])
|
||||
|
||||
return [
|
||||
ChatResponse(**chat.model_dump())
|
||||
for chat in Chats.get_chats_by_folder_ids_and_user_id(folder_ids, user.id)
|
||||
]
|
||||
|
||||
|
||||
############################
|
||||
# GetPinnedChats
|
||||
############################
|
||||
@@ -491,6 +541,31 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateChatFolderIdById
|
||||
############################
|
||||
|
||||
|
||||
class ChatFolderIdForm(BaseModel):
|
||||
folder_id: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/{id}/folder", response_model=Optional[ChatResponse])
|
||||
async def update_chat_folder_id_by_id(
|
||||
id: str, form_data: ChatFolderIdForm, user=Depends(get_verified_user)
|
||||
):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
chat = Chats.update_chat_folder_id_by_id_and_user_id(
|
||||
id, user.id, form_data.folder_id
|
||||
)
|
||||
return ChatResponse(**chat.model_dump())
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetChatTagsById
|
||||
############################
|
||||
@@ -522,6 +597,12 @@ async def add_tag_by_id_and_tag_name(
|
||||
tags = chat.meta.get("tags", [])
|
||||
tag_id = form_data.name.replace(" ", "_").lower()
|
||||
|
||||
if tag_id == "none":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
|
||||
)
|
||||
|
||||
print(tags, tag_id)
|
||||
if tag_id not in tags:
|
||||
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||
|
||||
@@ -57,11 +57,11 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
|
||||
**{
|
||||
"id": id,
|
||||
"filename": filename,
|
||||
"path": file_path,
|
||||
"meta": {
|
||||
"name": name,
|
||||
"content_type": file.content_type,
|
||||
"size": len(contents),
|
||||
"path": file_path,
|
||||
},
|
||||
}
|
||||
),
|
||||
@@ -218,7 +218,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
file_path = Path(file.meta["path"])
|
||||
file_path = Path(file.path)
|
||||
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
@@ -244,7 +244,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file and (file.user_id == user.id or user.role == "admin"):
|
||||
file_path = file.meta.get("path")
|
||||
file_path = file.path
|
||||
if file_path:
|
||||
file_path = Path(file_path)
|
||||
|
||||
|
||||
251
backend/open_webui/apps/webui/routers/folders.py
Normal file
251
backend/open_webui/apps/webui/routers/folders.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
import mimetypes
|
||||
|
||||
|
||||
from open_webui.apps.webui.models.folders import (
|
||||
FolderForm,
|
||||
FolderModel,
|
||||
Folders,
|
||||
)
|
||||
from open_webui.apps.webui.models.chats import Chats
|
||||
|
||||
from open_webui.config import UPLOAD_DIR
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
############################
|
||||
# Get Folders
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=list[FolderModel])
|
||||
async def get_folders(user=Depends(get_verified_user)):
|
||||
folders = Folders.get_folders_by_user_id(user.id)
|
||||
|
||||
return [
|
||||
{
|
||||
**folder.model_dump(),
|
||||
"items": {
|
||||
"chats": [
|
||||
{"title": chat.title, "id": chat.id}
|
||||
for chat in Chats.get_chats_by_folder_id_and_user_id(
|
||||
folder.id, user.id
|
||||
)
|
||||
]
|
||||
},
|
||||
}
|
||||
for folder in folders
|
||||
]
|
||||
|
||||
|
||||
############################
|
||||
# Create Folder
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/")
|
||||
def create_folder(form_data: FolderForm, user=Depends(get_verified_user)):
|
||||
folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||
None, user.id, form_data.name
|
||||
)
|
||||
|
||||
if folder:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
|
||||
)
|
||||
|
||||
try:
|
||||
folder = Folders.insert_new_folder(user.id, form_data.name)
|
||||
return folder
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error("Error creating folder")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error creating folder"),
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Get Folders By Id
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[FolderModel])
|
||||
async def get_folder_by_id(id: str, user=Depends(get_verified_user)):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
if folder:
|
||||
return folder
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Update Folder Name By Id
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/update")
|
||||
async def update_folder_name_by_id(
|
||||
id: str, form_data: FolderForm, user=Depends(get_verified_user)
|
||||
):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
if folder:
|
||||
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||
folder.parent_id, user.id, form_data.name
|
||||
)
|
||||
if existing_folder:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
|
||||
)
|
||||
|
||||
try:
|
||||
folder = Folders.update_folder_name_by_id_and_user_id(
|
||||
id, user.id, form_data.name
|
||||
)
|
||||
|
||||
return folder
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error updating folder: {id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Update Folder Parent Id By Id
|
||||
############################
|
||||
|
||||
|
||||
class FolderParentIdForm(BaseModel):
|
||||
parent_id: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/{id}/update/parent")
|
||||
async def update_folder_parent_id_by_id(
|
||||
id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user)
|
||||
):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
if folder:
|
||||
existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||
form_data.parent_id, user.id, folder.name
|
||||
)
|
||||
|
||||
if existing_folder:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Folder already exists"),
|
||||
)
|
||||
|
||||
try:
|
||||
folder = Folders.update_folder_parent_id_by_id_and_user_id(
|
||||
id, user.id, form_data.parent_id
|
||||
)
|
||||
return folder
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error updating folder: {id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Update Folder Is Expanded By Id
|
||||
############################
|
||||
|
||||
|
||||
class FolderIsExpandedForm(BaseModel):
|
||||
is_expanded: bool
|
||||
|
||||
|
||||
@router.post("/{id}/update/expanded")
|
||||
async def update_folder_is_expanded_by_id(
|
||||
id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user)
|
||||
):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
if folder:
|
||||
try:
|
||||
folder = Folders.update_folder_is_expanded_by_id_and_user_id(
|
||||
id, user.id, form_data.is_expanded
|
||||
)
|
||||
return folder
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error updating folder: {id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating folder"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Delete Folder By Id
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{id}")
|
||||
async def delete_folder_by_id(id: str, user=Depends(get_verified_user)):
|
||||
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
|
||||
if folder:
|
||||
try:
|
||||
result = Folders.delete_folder_by_id_and_user_id(id, user.id)
|
||||
if result:
|
||||
return result
|
||||
else:
|
||||
raise Exception("Error deleting folder")
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
log.error(f"Error deleting folder: {id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting folder"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
@@ -876,6 +876,12 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
||||
os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
|
||||
)
|
||||
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE",
|
||||
"task.tags.prompt_template",
|
||||
os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
|
||||
)
|
||||
|
||||
ENABLE_SEARCH_QUERY = PersistentConfig(
|
||||
"ENABLE_SEARCH_QUERY",
|
||||
"task.search.enable",
|
||||
|
||||
@@ -20,7 +20,9 @@ class ERROR_MESSAGES(str, Enum):
|
||||
def __str__(self) -> str:
|
||||
return super().__str__()
|
||||
|
||||
DEFAULT = lambda err="": f"Something went wrong :/\n[ERROR: {err if err else ''}]"
|
||||
DEFAULT = (
|
||||
lambda err="": f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}'
|
||||
)
|
||||
ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now."
|
||||
CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance."
|
||||
DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot."
|
||||
@@ -106,6 +108,7 @@ class TASKS(str, Enum):
|
||||
|
||||
DEFAULT = lambda task="": f"{task if task else 'generation'}"
|
||||
TITLE_GENERATION = "title_generation"
|
||||
TAGS_GENERATION = "tags_generation"
|
||||
EMOJI_GENERATION = "emoji_generation"
|
||||
QUERY_GENERATION = "query_generation"
|
||||
FUNCTION_CALLING = "function_calling"
|
||||
|
||||
@@ -82,6 +82,7 @@ from open_webui.config import (
|
||||
TASK_MODEL,
|
||||
TASK_MODEL_EXTERNAL,
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
WEBHOOK_URL,
|
||||
WEBUI_AUTH,
|
||||
@@ -118,6 +119,7 @@ from open_webui.utils.response import (
|
||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
from open_webui.utils.task import (
|
||||
moa_response_generation_template,
|
||||
tags_generation_template,
|
||||
search_query_generation_template,
|
||||
title_generation_template,
|
||||
tools_function_calling_generation_template,
|
||||
@@ -194,6 +196,7 @@ app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||
app.state.config.TASK_MODEL = TASK_MODEL
|
||||
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
|
||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
@@ -1403,6 +1406,7 @@ async def get_task_config(user=Depends(get_verified_user)):
|
||||
"TASK_MODEL": app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
|
||||
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
@@ -1413,6 +1417,7 @@ class TaskConfigForm(BaseModel):
|
||||
TASK_MODEL: Optional[str]
|
||||
TASK_MODEL_EXTERNAL: Optional[str]
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE: str
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE: str
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
|
||||
ENABLE_SEARCH_QUERY: bool
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
|
||||
@@ -1425,6 +1430,10 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
|
||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
@@ -1437,6 +1446,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
|
||||
"TASK_MODEL": app.state.config.TASK_MODEL,
|
||||
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
|
||||
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
|
||||
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
@@ -1521,6 +1531,75 @@ Prompt: {{prompt:middletruncate:8000}}"""
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/tags/completions")
|
||||
async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_chat_tags")
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(model_id)
|
||||
print(task_model_id)
|
||||
|
||||
if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = """### Task:
|
||||
Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags.
|
||||
|
||||
### Guidelines:
|
||||
- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
|
||||
- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation
|
||||
- If content is too short (less than 3 messages) or too diverse, use only ["General"]
|
||||
- Use the chat's primary language; default to English if multilingual
|
||||
- Prioritize accuracy over specificity
|
||||
|
||||
### Output:
|
||||
JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
||||
|
||||
### Chat History:
|
||||
<chat_history>
|
||||
{{MESSAGES:END:6}}
|
||||
</chat_history>"""
|
||||
|
||||
content = tags_generation_template(
|
||||
template, form_data["messages"], {"name": user.name}
|
||||
)
|
||||
|
||||
print("content", content)
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data},
|
||||
}
|
||||
log.debug(payload)
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
status_code=e.args[0],
|
||||
content={"detail": e.args[1]},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
return await generate_chat_completions(form_data=payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/task/query/completions")
|
||||
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_search_query")
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
"""Update file table path
|
||||
|
||||
Revision ID: c29facfe716b
|
||||
Revises: c69f45358db4
|
||||
Create Date: 2024-10-20 17:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import json
|
||||
from sqlalchemy.sql import table, column
|
||||
from sqlalchemy import String, Text, JSON, and_
|
||||
|
||||
|
||||
revision = "c29facfe716b"
|
||||
down_revision = "c69f45358db4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# 1. Add the `path` column to the "file" table.
|
||||
op.add_column("file", sa.Column("path", sa.Text(), nullable=True))
|
||||
|
||||
# 2. Convert the `meta` column from Text/JSONField to `JSON()`
|
||||
# Use Alembic's default batch_op for dialect compatibility.
|
||||
with op.batch_alter_table("file", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"meta",
|
||||
type_=sa.JSON(),
|
||||
existing_type=sa.Text(),
|
||||
existing_nullable=True,
|
||||
nullable=True,
|
||||
postgresql_using="meta::json",
|
||||
)
|
||||
|
||||
# 3. Migrate legacy data from `meta` JSONField
|
||||
# Fetch and process `meta` data from the table, add values to the new `path` column as necessary.
|
||||
# We will use SQLAlchemy core bindings to ensure safety across different databases.
|
||||
|
||||
file_table = table(
|
||||
"file", column("id", String), column("meta", JSON), column("path", Text)
|
||||
)
|
||||
|
||||
# Create connection to the database
|
||||
connection = op.get_bind()
|
||||
|
||||
# Get the rows where `meta` has a path and `path` column is null (new column)
|
||||
# Loop through each row in the result set to update the path
|
||||
results = connection.execute(
|
||||
sa.select(file_table.c.id, file_table.c.meta).where(
|
||||
and_(file_table.c.path.is_(None), file_table.c.meta.isnot(None))
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
# Iterate over each row to extract and update the `path` from `meta` column
|
||||
for row in results:
|
||||
if "path" in row.meta:
|
||||
# Extract the `path` field from the `meta` JSON
|
||||
path = row.meta.get("path")
|
||||
|
||||
# Update the `file` table with the new `path` value
|
||||
connection.execute(
|
||||
file_table.update()
|
||||
.where(file_table.c.id == row.id)
|
||||
.values({"path": path})
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# 1. Remove the `path` column
|
||||
op.drop_column("file", "path")
|
||||
|
||||
# 2. Revert the `meta` column back to Text/JSONField
|
||||
with op.batch_alter_table("file", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"meta", type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True
|
||||
)
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Add folder table
|
||||
|
||||
Revision ID: c69f45358db4
|
||||
Revises: 3ab32c4b8f59
|
||||
Create Date: 2024-10-16 02:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "c69f45358db4"
|
||||
down_revision = "3ab32c4b8f59"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"folder",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("items", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", "user_id"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"chat",
|
||||
sa.Column("folder_id", sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("chat", "folder_id")
|
||||
|
||||
op.drop_table("folder")
|
||||
@@ -123,6 +123,24 @@ def replace_messages_variable(template: str, messages: list[str]) -> str:
|
||||
return template
|
||||
|
||||
|
||||
def tags_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
return template
|
||||
|
||||
|
||||
def search_query_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
) -> str:
|
||||
|
||||
Reference in New Issue
Block a user