Merge remote-tracking branch 'origin/dev' into feat/backend-web-search

This commit is contained in:
Jun Siang Cheah
2024-05-20 18:01:29 +01:00
103 changed files with 4496 additions and 694 deletions

View File

@@ -397,7 +397,7 @@ def generate_image(
user=Depends(get_current_user),
):
width, height = tuple(map(int, app.state.config.IMAGE_SIZE).split("x"))
width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
r = None
try:

View File

@@ -75,6 +75,10 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
litellm_config = yaml.safe_load(file)
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config
@@ -151,10 +155,6 @@ async def shutdown_litellm_background():
background_process = None
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.get("/")
async def get_status():
return {"status": True}

View File

@@ -64,8 +64,8 @@ app.add_middleware(
app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
@@ -124,8 +124,9 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user))
async def fetch_url(url):
timeout = aiohttp.ClientTimeout(total=5)
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url) as response:
return await response.json()
except Exception as e:
@@ -177,11 +178,12 @@ async def get_ollama_tags(
if url_idx == None:
models = await get_all_models()
if app.state.ENABLE_MODEL_FILTER:
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["models"] = list(
filter(
lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
lambda model: model["name"]
in app.state.config.MODEL_FILTER_LIST,
models["models"],
)
)
@@ -1045,11 +1047,12 @@ async def get_openai_models(
if url_idx == None:
models = await get_all_models()
if app.state.ENABLE_MODEL_FILTER:
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["models"] = list(
filter(
lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
lambda model: model["name"]
in app.state.config.MODEL_FILTER_LIST,
models["models"],
)
)

View File

@@ -21,6 +21,7 @@ from utils.utils import (
)
from config import (
SRC_LOG_LEVELS,
ENABLE_OPENAI_API,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
CACHE_DIR,
@@ -46,11 +47,14 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
@@ -68,6 +72,21 @@ async def check_url(request: Request, call_next):
return response
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
class OpenAIConfigForm(BaseModel):
enable_openai_api: Optional[bool] = None
@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
class UrlsUpdateForm(BaseModel):
urls: List[str]
@@ -164,11 +183,15 @@ async def speech(request: Request, user=Depends(get_verified_user)):
async def fetch_url(url, key):
timeout = aiohttp.ClientTimeout(total=5)
try:
headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
return await response.json()
if key != "":
headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response:
return await response.json()
else:
return None
except Exception as e:
# Handle connection error here
log.error(f"Connection error: {e}")
@@ -200,7 +223,7 @@ async def get_all_models():
if (
len(app.state.config.OPENAI_API_KEYS) == 1
and app.state.config.OPENAI_API_KEYS[0] == ""
):
) or not app.state.config.ENABLE_OPENAI_API:
models = {"data": []}
else:
tasks = [
@@ -237,11 +260,11 @@ async def get_all_models():
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
if url_idx == None:
models = await get_all_models()
if app.state.ENABLE_MODEL_FILTER:
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
models["data"],
)
)

View File

@@ -70,6 +70,7 @@ from utils.misc import (
from utils.utils import get_current_user, get_admin_user
from config import (
ENV,
SRC_LOG_LEVELS,
UPLOAD_DIR,
DOCS_DIR,
@@ -266,7 +267,7 @@ async def update_embedding_config(
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL), True
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
@@ -439,12 +440,12 @@ async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
app.state.config.RAG_TEMPLATE = (
form_data.template if form_data.template else RAG_TEMPLATE,
form_data.template if form_data.template else RAG_TEMPLATE
)
app.state.config.TOP_K = form_data.k if form_data.k else 4
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
form_data.hybrid if form_data.hybrid else False,
form_data.hybrid if form_data.hybrid else False
)
return {
"status": True,
@@ -1006,3 +1007,14 @@ def reset(user=Depends(get_admin_user)) -> bool:
log.exception(e)
return True
if ENV == "dev":
@app.get("/ef")
async def get_embeddings():
return {"result": app.state.EMBEDDING_FUNCTION("hello world")}
@app.get("/ef/{text}")
async def get_embeddings_text(text: str):
return {"result": app.state.EMBEDDING_FUNCTION(text)}

View File

@@ -0,0 +1,53 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model
class Memory(pw.Model):
id = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
content = pw.TextField(null=False)
updated_at = pw.BigIntegerField(null=False)
created_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "memory"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("memory")

View File

@@ -9,6 +9,7 @@ from apps.web.routers import (
modelfiles,
prompts,
configs,
memories,
utils,
)
from config import (
@@ -41,6 +42,7 @@ app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
@@ -52,9 +54,12 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])

View File

@@ -0,0 +1,118 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from apps.web.internal.db import DB
from apps.web.models.chats import Chats
import time
import uuid
####################
# Memory DB Schema
####################
class Memory(Model):
id = CharField(unique=True)
user_id = CharField()
content = TextField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
class MemoryModel(BaseModel):
id: str
user_id: str
content: str
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class MemoriesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Memory])
def insert_new_memory(
self,
user_id: str,
content: str,
) -> Optional[MemoryModel]:
id = str(uuid.uuid4())
memory = MemoryModel(
**{
"id": id,
"user_id": user_id,
"content": content,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
result = Memory.create(**memory.model_dump())
if result:
return memory
else:
return None
def get_memories(self) -> List[MemoryModel]:
try:
memories = Memory.select()
return [MemoryModel(**model_to_dict(memory)) for memory in memories]
except:
return None
def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
try:
memories = Memory.select().where(Memory.user_id == user_id)
return [MemoryModel(**model_to_dict(memory)) for memory in memories]
except:
return None
def get_memory_by_id(self, id) -> Optional[MemoryModel]:
try:
memory = Memory.get(Memory.id == id)
return MemoryModel(**model_to_dict(memory))
except:
return None
def delete_memory_by_id(self, id: str) -> bool:
try:
query = Memory.delete().where(Memory.id == id)
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
def delete_memories_by_user_id(self, user_id: str) -> bool:
try:
query = Memory.delete().where(Memory.user_id == user_id)
query.execute()
return True
except:
return False
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id)
query.execute()
return True
except:
return False
Memories = MemoriesTable(DB)

View File

@@ -0,0 +1,145 @@
from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import logging
from apps.web.models.memories import Memories, MemoryModel
from utils.utils import get_verified_user
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
@router.get("/ef")
async def get_embeddings(request: Request):
return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
############################
# GetMemories
############################
@router.get("/", response_model=List[MemoryModel])
async def get_memories(user=Depends(get_verified_user)):
return Memories.get_memories_by_user_id(user.id)
############################
# AddMemory
############################
class AddMemoryForm(BaseModel):
content: str
@router.post("/add", response_model=Optional[MemoryModel])
async def add_memory(
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
):
memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
collection.upsert(
documents=[memory.content],
ids=[memory.id],
embeddings=[memory_embedding],
metadatas=[{"created_at": memory.created_at}],
)
return memory
############################
# QueryMemory
############################
class QueryMemoryForm(BaseModel):
content: str
@router.post("/query")
async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
):
query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
results = collection.query(
query_embeddings=[query_embedding],
n_results=1, # how many results to return
)
return results
############################
# ResetMemoryFromVectorDB
############################
@router.get("/reset", response_model=bool)
async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user)
):
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(user.id)
for memory in memories:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection.upsert(
documents=[memory.content],
ids=[memory.id],
embeddings=[memory_embedding],
)
return True
############################
# DeleteMemoriesByUserId
############################
@router.delete("/user", response_model=bool)
async def delete_memory_by_user_id(user=Depends(get_verified_user)):
result = Memories.delete_memories_by_user_id(user.id)
if result:
try:
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
except Exception as e:
log.error(e)
return True
return False
############################
# DeleteMemoryById
############################
@router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
if result:
collection = CHROMA_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.delete(ids=[memory_id])
return True
return False

View File

@@ -11,8 +11,9 @@ import logging
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths
from apps.web.models.chats import Chats
from utils.utils import get_current_user, get_password_hash, get_admin_user
from utils.utils import get_verified_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS
@@ -67,6 +68,41 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin
)
############################
# GetUserById
############################
class UserResponse(BaseModel):
name: str
profile_image_url: str
@router.get("/{user_id}", response_model=UserResponse)
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
if user_id.startswith("shared-"):
chat_id = user_id.replace("shared-", "")
chat = Chats.get_chat_by_id(chat_id)
if chat:
user_id = chat.user_id
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
user = Users.get_user_by_id(user_id)
if user:
return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# UpdateUserById
############################

View File

@@ -417,6 +417,14 @@ OLLAMA_BASE_URLS = PersistentConfig(
# OPENAI_API
####################################
ENABLE_OPENAI_API = PersistentConfig(
"ENABLE_OPENAI_API",
"openai.enable",
os.environ.get("ENABLE_OPENAI_API", "True").lower() == "true",
)
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")

View File

@@ -118,6 +118,18 @@ app.state.config.WEBHOOK_URL = WEBHOOK_URL
origins = ["*"]
# Custom middleware to add security headers
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
# async def dispatch(self, request: Request, call_next):
# response: Response = await call_next(request)
# response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
# return response
# app.add_middleware(SecurityHeadersMiddleware)
class RAGMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
return_citations = False
@@ -227,9 +239,15 @@ async def check_url(request: Request, call_next):
return response
app.mount("/api/v1", webui_app)
app.mount("/litellm/api", litellm_app)
@app.middleware("http")
async def update_embedding_function(request: Request, call_next):
response = await call_next(request)
if "/embedding/update" in request.url.path:
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
return response
app.mount("/litellm/api", litellm_app)
app.mount("/ollama", ollama_app)
app.mount("/openai/api", openai_app)
@@ -237,6 +255,10 @@ app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)
app.mount("/api/v1", webui_app)
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
@app.get("/api/config")
async def get_app_config():
@@ -279,14 +301,14 @@ class ModelFilterConfigForm(BaseModel):
async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
app.state.config.ENABLE_MODEL_FILTER, form_data.enabled
app.state.config.MODEL_FILTER_LIST, form_data.models
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
app.state.config.MODEL_FILTER_LIST = form_data.models
ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
ollama_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
ollama_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
openai_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
openai_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST