mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
Merge branch 'websearch' into feat/backend-web-search
This commit is contained in:
@@ -198,7 +198,7 @@ async def fetch_url(url, key):
|
||||
|
||||
|
||||
def merge_models_lists(model_lists):
|
||||
log.info(f"merge_models_lists {model_lists}")
|
||||
log.debug(f"merge_models_lists {model_lists}")
|
||||
merged_list = []
|
||||
|
||||
for idx, models in enumerate(model_lists):
|
||||
@@ -237,7 +237,7 @@ async def get_all_models():
|
||||
]
|
||||
|
||||
responses = await asyncio.gather(*tasks)
|
||||
log.info(f"get_all_models:responses() {responses}")
|
||||
log.debug(f"get_all_models:responses() {responses}")
|
||||
|
||||
models = {
|
||||
"data": merge_models_lists(
|
||||
@@ -254,7 +254,7 @@ async def get_all_models():
|
||||
)
|
||||
}
|
||||
|
||||
log.info(f"models: {models}")
|
||||
log.debug(f"models: {models}")
|
||||
app.state.MODELS = {model["id"]: model for model in models["data"]}
|
||||
|
||||
return models
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Peewee migrations -- 002_add_local_sharing.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
# Adding fields settings to the 'user' table
|
||||
migrator.add_fields("user", settings=pw.TextField(null=True))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
# Remove the settings field
|
||||
migrator.remove_fields("user", "settings")
|
||||
@@ -13,7 +13,7 @@ from apps.webui.routers import (
|
||||
utils,
|
||||
)
|
||||
from config import (
|
||||
WEBUI_VERSION,
|
||||
WEBUI_BUILD_HASH,
|
||||
WEBUI_AUTH,
|
||||
DEFAULT_MODELS,
|
||||
DEFAULT_PROMPT_SUGGESTIONS,
|
||||
@@ -23,7 +23,9 @@ from config import (
|
||||
WEBHOOK_URL,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
JWT_EXPIRES_IN,
|
||||
WEBUI_BANNERS,
|
||||
AppConfig,
|
||||
ENABLE_COMMUNITY_SHARING,
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
@@ -40,7 +42,9 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
|
||||
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
||||
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
|
||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||
app.state.config.BANNERS = WEBUI_BANNERS
|
||||
|
||||
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
|
||||
|
||||
app.state.MODELS = {}
|
||||
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from peewee import *
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
from typing import List, Union, Optional
|
||||
import time
|
||||
from utils.misc import get_gravatar_url
|
||||
|
||||
from apps.webui.internal.db import DB
|
||||
from apps.webui.internal.db import DB, JSONField
|
||||
from apps.webui.models.chats import Chats
|
||||
|
||||
####################
|
||||
@@ -25,11 +25,18 @@ class User(Model):
|
||||
created_at = BigIntegerField()
|
||||
|
||||
api_key = CharField(null=True, unique=True)
|
||||
settings = JSONField(null=True)
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
||||
|
||||
class UserSettings(BaseModel):
|
||||
ui: Optional[dict] = {}
|
||||
model_config = ConfigDict(extra="allow")
|
||||
pass
|
||||
|
||||
|
||||
class UserModel(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
@@ -42,6 +49,7 @@ class UserModel(BaseModel):
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
api_key: Optional[str] = None
|
||||
settings: Optional[UserSettings] = None
|
||||
|
||||
|
||||
####################
|
||||
|
||||
@@ -8,6 +8,8 @@ from pydantic import BaseModel
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from config import BannerModel
|
||||
|
||||
from apps.webui.models.users import Users
|
||||
|
||||
from utils.utils import (
|
||||
@@ -57,3 +59,31 @@ async def set_global_default_suggestions(
|
||||
data = form_data.model_dump()
|
||||
request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
|
||||
return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
|
||||
|
||||
|
||||
############################
|
||||
# SetBanners
|
||||
############################
|
||||
|
||||
|
||||
class SetBannersForm(BaseModel):
|
||||
banners: List[BannerModel]
|
||||
|
||||
|
||||
@router.post("/banners", response_model=List[BannerModel])
|
||||
async def set_banners(
|
||||
request: Request,
|
||||
form_data: SetBannersForm,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
data = form_data.model_dump()
|
||||
request.app.state.config.BANNERS = data["banners"]
|
||||
return request.app.state.config.BANNERS
|
||||
|
||||
|
||||
@router.get("/banners", response_model=List[BannerModel])
|
||||
async def get_banners(
|
||||
request: Request,
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
return request.app.state.config.BANNERS
|
||||
|
||||
@@ -9,7 +9,13 @@ import time
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
from apps.webui.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
|
||||
from apps.webui.models.users import (
|
||||
UserModel,
|
||||
UserUpdateForm,
|
||||
UserRoleUpdateForm,
|
||||
UserSettings,
|
||||
Users,
|
||||
)
|
||||
from apps.webui.models.auths import Auths
|
||||
from apps.webui.models.chats import Chats
|
||||
|
||||
@@ -68,6 +74,42 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserSettingsBySessionUser
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/user/settings", response_model=Optional[UserSettings])
|
||||
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
return user.settings
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateUserSettingsBySessionUser
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/user/settings/update", response_model=UserSettings)
|
||||
async def update_user_settings_by_session_user(
|
||||
form_data: UserSettings, user=Depends(get_verified_user)
|
||||
):
|
||||
user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
|
||||
if user:
|
||||
return user.settings
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserById
|
||||
############################
|
||||
@@ -81,6 +123,8 @@ class UserResponse(BaseModel):
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||
|
||||
# Check if user_id is a shared chat
|
||||
# If it is, get the user_id from the chat
|
||||
if user_id.startswith("shared-"):
|
||||
chat_id = user_id.replace("shared-", "")
|
||||
chat = Chats.get_chat_by_id(chat_id)
|
||||
|
||||
Reference in New Issue
Block a user