Merge branch 'websearch' into feat/backend-web-search

This commit is contained in:
Timothy Jaeryang Baek
2024-05-26 23:40:05 -07:00
committed by GitHub
77 changed files with 1338 additions and 345 deletions

View File

@@ -198,7 +198,7 @@ async def fetch_url(url, key):
def merge_models_lists(model_lists):
log.info(f"merge_models_lists {model_lists}")
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
for idx, models in enumerate(model_lists):
@@ -237,7 +237,7 @@ async def get_all_models():
]
responses = await asyncio.gather(*tasks)
log.info(f"get_all_models:responses() {responses}")
log.debug(f"get_all_models:responses() {responses}")
models = {
"data": merge_models_lists(
@@ -254,7 +254,7 @@ async def get_all_models():
)
}
log.info(f"models: {models}")
log.debug(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
return models

View File

@@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding fields settings to the 'user' table
migrator.add_fields("user", settings=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Remove the settings field
migrator.remove_fields("user", "settings")

View File

@@ -13,7 +13,7 @@ from apps.webui.routers import (
utils,
)
from config import (
WEBUI_VERSION,
WEBUI_BUILD_HASH,
WEBUI_AUTH,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
@@ -23,7 +23,9 @@ from config import (
WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
WEBUI_BANNERS,
AppConfig,
ENABLE_COMMUNITY_SHARING,
)
app = FastAPI()
@@ -40,7 +42,9 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER

View File

@@ -1,11 +1,11 @@
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB
from apps.webui.internal.db import DB, JSONField
from apps.webui.models.chats import Chats
####################
@@ -25,11 +25,18 @@ class User(Model):
created_at = BigIntegerField()
api_key = CharField(null=True, unique=True)
settings = JSONField(null=True)
class Meta:
database = DB
class UserSettings(BaseModel):
ui: Optional[dict] = {}
model_config = ConfigDict(extra="allow")
pass
class UserModel(BaseModel):
id: str
name: str
@@ -42,6 +49,7 @@ class UserModel(BaseModel):
created_at: int # timestamp in epoch
api_key: Optional[str] = None
settings: Optional[UserSettings] = None
####################

View File

@@ -8,6 +8,8 @@ from pydantic import BaseModel
import time
import uuid
from config import BannerModel
from apps.webui.models.users import Users
from utils.utils import (
@@ -57,3 +59,31 @@ async def set_global_default_suggestions(
data = form_data.model_dump()
request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
############################
# SetBanners
############################
class SetBannersForm(BaseModel):
banners: List[BannerModel]
@router.post("/banners", response_model=List[BannerModel])
async def set_banners(
request: Request,
form_data: SetBannersForm,
user=Depends(get_admin_user),
):
data = form_data.model_dump()
request.app.state.config.BANNERS = data["banners"]
return request.app.state.config.BANNERS
@router.get("/banners", response_model=List[BannerModel])
async def get_banners(
request: Request,
user=Depends(get_current_user),
):
return request.app.state.config.BANNERS

View File

@@ -9,7 +9,13 @@ import time
import uuid
import logging
from apps.webui.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.webui.models.users import (
UserModel,
UserUpdateForm,
UserRoleUpdateForm,
UserSettings,
Users,
)
from apps.webui.models.auths import Auths
from apps.webui.models.chats import Chats
@@ -68,6 +74,42 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin
)
############################
# GetUserSettingsBySessionUser
############################
@router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
if user:
return user.settings
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# UpdateUserSettingsBySessionUser
############################
@router.post("/user/settings/update", response_model=UserSettings)
async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user)
):
user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
if user:
return user.settings
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# GetUserById
############################
@@ -81,6 +123,8 @@ class UserResponse(BaseModel):
@router.get("/{user_id}", response_model=UserResponse)
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
# Check if user_id is a shared chat
# If it is, get the user_id from the chat
if user_id.startswith("shared-"):
chat_id = user_id.replace("shared-", "")
chat = Chats.get_chat_by_id(chat_id)