mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
feat: upload pipeline
This commit is contained in:
@@ -9,8 +9,11 @@ import logging
|
||||
import aiohttp
|
||||
import requests
|
||||
import mimetypes
|
||||
import shutil
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
from fastapi import FastAPI, Request, Depends, status
|
||||
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import HTTPException
|
||||
@@ -30,7 +33,7 @@ from apps.images.main import app as images_app
|
||||
from apps.rag.main import app as rag_app
|
||||
from apps.webui.main import app as webui_app
|
||||
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -574,6 +577,63 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
|
||||
}
|
||||
|
||||
|
||||
@app.post("/api/pipelines/upload")
|
||||
async def upload_pipeline(
|
||||
urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
|
||||
):
|
||||
print("upload_pipeline", urlIdx, file.filename)
|
||||
# Check if the uploaded file is a python file
|
||||
if not file.filename.endswith(".py"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Only Python (.py) files are allowed.",
|
||||
)
|
||||
|
||||
upload_folder = f"{CACHE_DIR}/pipelines"
|
||||
os.makedirs(upload_folder, exist_ok=True)
|
||||
file_path = os.path.join(upload_folder, file.filename)
|
||||
|
||||
try:
|
||||
# Save the uploaded file
|
||||
with open(file_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
files = {"file": f}
|
||||
r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
return {**data}
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
detail = "Pipeline not found"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
|
||||
detail=detail,
|
||||
)
|
||||
finally:
|
||||
# Ensure the file is deleted after the upload is completed or on failure
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
|
||||
class AddPipelineForm(BaseModel):
|
||||
url: str
|
||||
urlIdx: int
|
||||
|
||||
Reference in New Issue
Block a user