This commit is contained in:
Timothy Jaeryang Baek
2024-12-11 20:15:23 -08:00
parent fe5519e0a2
commit a07ff56c50
2 changed files with 55 additions and 50 deletions

View File

@@ -1009,9 +1009,12 @@ async def get_body_and_model_and_user(request, models):
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if not request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
if not (
request.method == "POST"
and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
)
):
return await call_next(request)
log.debug(f"request.url.path: {request.url.path}")
@@ -1214,9 +1217,12 @@ app.add_middleware(ChatCompletionMiddleware)
class PipelineMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if not request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
if not (
request.method == "POST"
and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
)
):
return await call_next(request)
@@ -1664,17 +1670,17 @@ async def generate_function_chat_completion(form_data, user, models: dict = {}):
return openai_chat_completion_message_template(form_data["model"], message)
async def get_all_base_models():
async def get_all_base_models(request):
function_models = []
openai_models = []
ollama_models = []
if app.state.config.ENABLE_OPENAI_API:
openai_models = await openai.get_all_models()
openai_models = await openai.get_all_models(request)
openai_models = openai_models["data"]
if app.state.config.ENABLE_OLLAMA_API:
ollama_models = await ollama.get_all_models()
ollama_models = await ollama.get_all_models(request)
ollama_models = [
{
"id": model["model"],
@@ -1729,8 +1735,8 @@ async def get_all_base_models():
@cached(ttl=3)
async def get_all_models():
models = await get_all_base_models()
async def get_all_models(request):
models = await get_all_base_models(request)
# If there are no models, return an empty list
if len([model for model in models if not model.get("arena", False)]) == 0:
@@ -1859,8 +1865,8 @@ async def get_all_models():
@app.get("/api/models")
async def get_models(user=Depends(get_verified_user)):
models = await get_all_models()
async def get_models(request: Request, user=Depends(get_verified_user)):
models = await get_all_models(request)
# Filter out filter pipelines
models = [
@@ -2042,7 +2048,7 @@ async def generate_chat_completions(
async def chat_completed(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
model_list = await get_all_models()
model_list = await get_all_models(request)
models = {model["id"]: model for model in model_list}
data = form_data