mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 03:47:49 +01:00
Merge branch 'dev' into feat/model-config
This commit is contained in:
@@ -397,7 +397,7 @@ def generate_image(
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
|
||||
width, height = tuple(map(int, app.state.config.IMAGE_SIZE).split("x"))
|
||||
width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
|
||||
|
||||
r = None
|
||||
try:
|
||||
|
||||
@@ -75,6 +75,10 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
|
||||
litellm_config = yaml.safe_load(file)
|
||||
|
||||
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
|
||||
|
||||
|
||||
app.state.ENABLE = ENABLE_LITELLM
|
||||
app.state.CONFIG = litellm_config
|
||||
app.state.MODEL_CONFIG = MODEL_CONFIG.value.get("litellm", [])
|
||||
@@ -152,10 +156,6 @@ async def shutdown_litellm_background():
|
||||
background_process = None
|
||||
|
||||
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def get_status():
|
||||
return {"status": True}
|
||||
|
||||
@@ -65,8 +65,8 @@ app.add_middleware(
|
||||
|
||||
app.state.config = AppConfig()
|
||||
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
app.state.MODEL_CONFIG = MODEL_CONFIG.value.get("ollama", [])
|
||||
|
||||
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
||||
@@ -126,8 +126,9 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user))
|
||||
|
||||
|
||||
async def fetch_url(url):
|
||||
timeout = aiohttp.ClientTimeout(total=5)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url) as response:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
@@ -190,11 +191,12 @@ async def get_ollama_tags(
|
||||
if url_idx == None:
|
||||
models = await get_all_models()
|
||||
|
||||
if app.state.ENABLE_MODEL_FILTER:
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models["models"] = list(
|
||||
filter(
|
||||
lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
|
||||
lambda model: model["name"]
|
||||
in app.state.config.MODEL_FILTER_LIST,
|
||||
models["models"],
|
||||
)
|
||||
)
|
||||
@@ -1058,11 +1060,12 @@ async def get_openai_models(
|
||||
if url_idx == None:
|
||||
models = await get_all_models()
|
||||
|
||||
if app.state.ENABLE_MODEL_FILTER:
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models["models"] = list(
|
||||
filter(
|
||||
lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
|
||||
lambda model: model["name"]
|
||||
in app.state.config.MODEL_FILTER_LIST,
|
||||
models["models"],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from utils.utils import (
|
||||
)
|
||||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
ENABLE_OPENAI_API,
|
||||
OPENAI_API_BASE_URLS,
|
||||
OPENAI_API_KEYS,
|
||||
CACHE_DIR,
|
||||
@@ -47,12 +48,15 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
app.state.config = AppConfig()
|
||||
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
app.state.MODEL_CONFIG = MODEL_CONFIG.value.get("openai", [])
|
||||
|
||||
|
||||
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
|
||||
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
||||
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
|
||||
|
||||
@@ -70,6 +74,21 @@ async def check_url(request: Request, call_next):
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
async def get_config(user=Depends(get_admin_user)):
|
||||
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
|
||||
|
||||
|
||||
class OpenAIConfigForm(BaseModel):
|
||||
enable_openai_api: Optional[bool] = None
|
||||
|
||||
|
||||
@app.post("/config/update")
|
||||
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
|
||||
app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
|
||||
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
|
||||
|
||||
|
||||
class UrlsUpdateForm(BaseModel):
|
||||
urls: List[str]
|
||||
|
||||
@@ -166,11 +185,15 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
async def fetch_url(url, key):
|
||||
timeout = aiohttp.ClientTimeout(total=5)
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
return await response.json()
|
||||
if key != "":
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
return await response.json()
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
log.error(f"Connection error: {e}")
|
||||
@@ -202,7 +225,7 @@ async def get_all_models():
|
||||
if (
|
||||
len(app.state.config.OPENAI_API_KEYS) == 1
|
||||
and app.state.config.OPENAI_API_KEYS[0] == ""
|
||||
):
|
||||
) or not app.state.config.ENABLE_OPENAI_API:
|
||||
models = {"data": []}
|
||||
else:
|
||||
tasks = [
|
||||
@@ -248,11 +271,11 @@ def add_custom_info_to_model(model: dict):
|
||||
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
|
||||
if url_idx == None:
|
||||
models = await get_all_models()
|
||||
if app.state.ENABLE_MODEL_FILTER:
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models["data"] = list(
|
||||
filter(
|
||||
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
|
||||
lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
|
||||
models["data"],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -433,12 +433,12 @@ async def update_query_settings(
|
||||
form_data: QuerySettingsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
app.state.config.RAG_TEMPLATE = (
|
||||
form_data.template if form_data.template else RAG_TEMPLATE,
|
||||
form_data.template if form_data.template else RAG_TEMPLATE
|
||||
)
|
||||
app.state.config.TOP_K = form_data.k if form_data.k else 4
|
||||
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
|
||||
form_data.hybrid if form_data.hybrid else False,
|
||||
form_data.hybrid if form_data.hybrid else False
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
|
||||
@@ -11,8 +11,9 @@ import logging
|
||||
|
||||
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
|
||||
from apps.web.models.auths import Auths
|
||||
from apps.web.models.chats import Chats
|
||||
|
||||
from utils.utils import get_current_user, get_password_hash, get_admin_user
|
||||
from utils.utils import get_verified_user, get_password_hash, get_admin_user
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
@@ -67,6 +68,41 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserById
|
||||
############################
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
name: str
|
||||
profile_image_url: str
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||
|
||||
if user_id.startswith("shared-"):
|
||||
chat_id = user_id.replace("shared-", "")
|
||||
chat = Chats.get_chat_by_id(chat_id)
|
||||
if chat:
|
||||
user_id = chat.user_id
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
user = Users.get_user_by_id(user_id)
|
||||
|
||||
if user:
|
||||
return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateUserById
|
||||
############################
|
||||
|
||||
@@ -417,6 +417,14 @@ OLLAMA_BASE_URLS = PersistentConfig(
|
||||
# OPENAI_API
|
||||
####################################
|
||||
|
||||
|
||||
ENABLE_OPENAI_API = PersistentConfig(
|
||||
"ENABLE_OPENAI_API",
|
||||
"openai.enable",
|
||||
os.environ.get("ENABLE_OPENAI_API", "True").lower() == "true",
|
||||
)
|
||||
|
||||
|
||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
||||
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
|
||||
|
||||
|
||||
@@ -120,6 +120,18 @@ app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||
origins = ["*"]
|
||||
|
||||
|
||||
# Custom middleware to add security headers
|
||||
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
# async def dispatch(self, request: Request, call_next):
|
||||
# response: Response = await call_next(request)
|
||||
# response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
|
||||
# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
|
||||
# return response
|
||||
|
||||
|
||||
# app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
class RAGMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
return_citations = False
|
||||
@@ -280,14 +292,14 @@ class ModelFilterConfigForm(BaseModel):
|
||||
async def update_model_filter_config(
|
||||
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
app.state.config.ENABLE_MODEL_FILTER, form_data.enabled
|
||||
app.state.config.MODEL_FILTER_LIST, form_data.models
|
||||
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
|
||||
app.state.config.MODEL_FILTER_LIST = form_data.models
|
||||
|
||||
ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
|
||||
ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
|
||||
ollama_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
|
||||
ollama_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
|
||||
|
||||
openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
|
||||
openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
|
||||
openai_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
|
||||
openai_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
|
||||
|
||||
litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
|
||||
litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
|
||||
|
||||
Reference in New Issue
Block a user