mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-15 11:27:46 +01:00
wip
This commit is contained in:
@@ -1009,9 +1009,12 @@ async def get_body_and_model_and_user(request, models):
|
||||
|
||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if not request.method == "POST" and any(
|
||||
endpoint in request.url.path
|
||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||
if not (
|
||||
request.method == "POST"
|
||||
and any(
|
||||
endpoint in request.url.path
|
||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||
)
|
||||
):
|
||||
return await call_next(request)
|
||||
log.debug(f"request.url.path: {request.url.path}")
|
||||
@@ -1214,9 +1217,12 @@ app.add_middleware(ChatCompletionMiddleware)
|
||||
|
||||
class PipelineMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if not request.method == "POST" and any(
|
||||
endpoint in request.url.path
|
||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||
if not (
|
||||
request.method == "POST"
|
||||
and any(
|
||||
endpoint in request.url.path
|
||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||
)
|
||||
):
|
||||
return await call_next(request)
|
||||
|
||||
@@ -1664,17 +1670,17 @@ async def generate_function_chat_completion(form_data, user, models: dict = {}):
|
||||
return openai_chat_completion_message_template(form_data["model"], message)
|
||||
|
||||
|
||||
async def get_all_base_models():
|
||||
async def get_all_base_models(request):
|
||||
function_models = []
|
||||
openai_models = []
|
||||
ollama_models = []
|
||||
|
||||
if app.state.config.ENABLE_OPENAI_API:
|
||||
openai_models = await openai.get_all_models()
|
||||
openai_models = await openai.get_all_models(request)
|
||||
openai_models = openai_models["data"]
|
||||
|
||||
if app.state.config.ENABLE_OLLAMA_API:
|
||||
ollama_models = await ollama.get_all_models()
|
||||
ollama_models = await ollama.get_all_models(request)
|
||||
ollama_models = [
|
||||
{
|
||||
"id": model["model"],
|
||||
@@ -1729,8 +1735,8 @@ async def get_all_base_models():
|
||||
|
||||
|
||||
@cached(ttl=3)
|
||||
async def get_all_models():
|
||||
models = await get_all_base_models()
|
||||
async def get_all_models(request):
|
||||
models = await get_all_base_models(request)
|
||||
|
||||
# If there are no models, return an empty list
|
||||
if len([model for model in models if not model.get("arena", False)]) == 0:
|
||||
@@ -1859,8 +1865,8 @@ async def get_all_models():
|
||||
|
||||
|
||||
@app.get("/api/models")
|
||||
async def get_models(user=Depends(get_verified_user)):
|
||||
models = await get_all_models()
|
||||
async def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
models = await get_all_models(request)
|
||||
|
||||
# Filter out filter pipelines
|
||||
models = [
|
||||
@@ -2042,7 +2048,7 @@ async def generate_chat_completions(
|
||||
async def chat_completed(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
model_list = await get_all_models()
|
||||
model_list = await get_all_models(request)
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
data = form_data
|
||||
|
||||
Reference in New Issue
Block a user