wip: user groups frontend

This commit is contained in:
Timothy Jaeryang Baek
2024-11-13 03:09:46 -08:00
parent 6caf838964
commit baea26d9ca
9 changed files with 572 additions and 289 deletions

View File

@@ -13,9 +13,7 @@ import requests
from open_webui.apps.webui.models.models import Models
from open_webui.config import (
CORS_ALLOW_ORIGIN,
ENABLE_MODEL_FILTER,
ENABLE_OLLAMA_API,
MODEL_FILTER_LIST,
OLLAMA_BASE_URLS,
OLLAMA_API_CONFIGS,
UPLOAD_DIR,
@@ -66,9 +64,6 @@ app.add_middleware(
app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
@@ -339,16 +334,18 @@ async def get_ollama_tags(
if url_idx is None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["models"] = list(
filter(
lambda model: model["name"]
in app.state.config.MODEL_FILTER_LIST,
models["models"],
)
)
return models
# TODO: Check User Group and Filter Models
# if app.state.config.ENABLE_MODEL_FILTER:
# if user.role == "user":
# models["models"] = list(
# filter(
# lambda model: model["name"]
# in app.state.config.MODEL_FILTER_LIST,
# models["models"],
# )
# )
# return models
return models
else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -922,12 +919,14 @@ async def generate_chat_completion(
model_id = form_data.model
if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
raise HTTPException(
status_code=403,
detail="Model not found",
)
# TODO: Check User Group and Filter Models
# if not bypass_filter:
# if app.state.config.ENABLE_MODEL_FILTER:
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
# raise HTTPException(
# status_code=403,
# detail="Model not found",
# )
model_info = Models.get_model_by_id(model_id)
@@ -1008,12 +1007,13 @@ async def generate_openai_chat_completion(
model_id = completion_form.model
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
raise HTTPException(
status_code=403,
detail="Model not found",
)
# TODO: Check User Group and Filter Models
# if app.state.config.ENABLE_MODEL_FILTER:
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
# raise HTTPException(
# status_code=403,
# detail="Model not found",
# )
model_info = Models.get_model_by_id(model_id)
@@ -1054,15 +1054,16 @@ async def get_openai_models(
if url_idx is None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["models"] = list(
filter(
lambda model: model["name"]
in app.state.config.MODEL_FILTER_LIST,
models["models"],
)
)
# TODO: Check User Group and Filter Models
# if app.state.config.ENABLE_MODEL_FILTER:
# if user.role == "user":
# models["models"] = list(
# filter(
# lambda model: model["name"]
# in app.state.config.MODEL_FILTER_LIST,
# models["models"],
# )
# )
return {
"data": [

View File

@@ -11,9 +11,7 @@ from open_webui.apps.webui.models.models import Models
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
ENABLE_MODEL_FILTER,
ENABLE_OPENAI_API,
MODEL_FILTER_LIST,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
OPENAI_API_CONFIGS,
@@ -61,9 +59,6 @@ app.add_middleware(
app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
@@ -372,15 +367,18 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
if url_idx is None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models["data"] = list(
filter(
lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
models["data"],
)
)
return models
# TODO: Check User Group and Filter Models
# if app.state.config.ENABLE_MODEL_FILTER:
# if user.role == "user":
# models["data"] = list(
# filter(
# lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
# models["data"],
# )
# )
# return models
return models
else:
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
@@ -492,11 +490,10 @@ async def verify_connection(
@app.post("/chat/completions")
@app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
bypass_filter: Optional[bool] = False,
):
idx = 0
payload = {**form_data}
@@ -505,6 +502,16 @@ async def generate_chat_completion(
del payload["metadata"]
model_id = form_data.get("model")
# TODO: Check User Group and Filter Models
# if not bypass_filter:
# if app.state.config.ENABLE_MODEL_FILTER:
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
# raise HTTPException(
# status_code=403,
# detail="Model not found",
# )
model_info = Models.get_model_by_id(model_id)
if model_info: