mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
feat: switch to config proxy, remove config_get/set
This commit is contained in:
@@ -42,8 +42,7 @@ from config import (
|
||||
IMAGE_GENERATION_MODEL,
|
||||
IMAGE_SIZE,
|
||||
IMAGE_STEPS,
|
||||
config_get,
|
||||
config_set,
|
||||
AppConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -62,28 +61,30 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.state.ENGINE = IMAGE_GENERATION_ENGINE
|
||||
app.state.ENABLED = ENABLE_IMAGE_GENERATION
|
||||
app.state.config = AppConfig()
|
||||
|
||||
app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
|
||||
app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
|
||||
app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
|
||||
app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
|
||||
|
||||
app.state.MODEL = IMAGE_GENERATION_MODEL
|
||||
app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
|
||||
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
|
||||
|
||||
app.state.config.MODEL = IMAGE_GENERATION_MODEL
|
||||
|
||||
|
||||
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
||||
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
||||
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
||||
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
||||
|
||||
|
||||
app.state.IMAGE_SIZE = IMAGE_SIZE
|
||||
app.state.IMAGE_STEPS = IMAGE_STEPS
|
||||
app.state.config.IMAGE_SIZE = IMAGE_SIZE
|
||||
app.state.config.IMAGE_STEPS = IMAGE_STEPS
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"engine": config_get(app.state.ENGINE),
|
||||
"enabled": config_get(app.state.ENABLED),
|
||||
"engine": app.state.config.ENGINE,
|
||||
"enabled": app.state.config.ENABLED,
|
||||
}
|
||||
|
||||
|
||||
@@ -94,11 +95,11 @@ class ConfigUpdateForm(BaseModel):
|
||||
|
||||
@app.post("/config/update")
|
||||
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
|
||||
config_set(app.state.ENGINE, form_data.engine)
|
||||
config_set(app.state.ENABLED, form_data.enabled)
|
||||
app.state.config.ENGINE = form_data.engine
|
||||
app.state.config.ENABLED = form_data.enabled
|
||||
return {
|
||||
"engine": config_get(app.state.ENGINE),
|
||||
"enabled": config_get(app.state.ENABLED),
|
||||
"engine": app.state.config.ENGINE,
|
||||
"enabled": app.state.config.ENABLED,
|
||||
}
|
||||
|
||||
|
||||
@@ -110,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel):
|
||||
@app.get("/url")
|
||||
async def get_engine_url(user=Depends(get_admin_user)):
|
||||
return {
|
||||
"AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL),
|
||||
"COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL),
|
||||
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
|
||||
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
|
||||
}
|
||||
|
||||
|
||||
@@ -121,29 +122,29 @@ async def update_engine_url(
|
||||
):
|
||||
|
||||
if form_data.AUTOMATIC1111_BASE_URL == None:
|
||||
config_set(app.state.AUTOMATIC1111_BASE_URL, config_get(AUTOMATIC1111_BASE_URL))
|
||||
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
|
||||
else:
|
||||
url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
|
||||
try:
|
||||
r = requests.head(url)
|
||||
config_set(app.state.AUTOMATIC1111_BASE_URL, url)
|
||||
app.state.config.AUTOMATIC1111_BASE_URL = url
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||
|
||||
if form_data.COMFYUI_BASE_URL == None:
|
||||
config_set(app.state.COMFYUI_BASE_URL, COMFYUI_BASE_URL)
|
||||
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
||||
else:
|
||||
url = form_data.COMFYUI_BASE_URL.strip("/")
|
||||
|
||||
try:
|
||||
r = requests.head(url)
|
||||
config_set(app.state.COMFYUI_BASE_URL, url)
|
||||
app.state.config.COMFYUI_BASE_URL = url
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||
|
||||
return {
|
||||
"AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL),
|
||||
"COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL),
|
||||
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
|
||||
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
|
||||
"status": True,
|
||||
}
|
||||
|
||||
@@ -156,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel):
|
||||
@app.get("/openai/config")
|
||||
async def get_openai_config(user=Depends(get_admin_user)):
|
||||
return {
|
||||
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
|
||||
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
|
||||
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
|
||||
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
|
||||
}
|
||||
|
||||
|
||||
@@ -168,13 +169,13 @@ async def update_openai_config(
|
||||
if form_data.key == "":
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
||||
|
||||
config_set(app.state.OPENAI_API_BASE_URL, form_data.url)
|
||||
config_set(app.state.OPENAI_API_KEY, form_data.key)
|
||||
app.state.config.OPENAI_API_BASE_URL = form_data.url
|
||||
app.state.config.OPENAI_API_KEY = form_data.key
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
|
||||
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
|
||||
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
|
||||
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
|
||||
}
|
||||
|
||||
|
||||
@@ -184,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel):
|
||||
|
||||
@app.get("/size")
|
||||
async def get_image_size(user=Depends(get_admin_user)):
|
||||
return {"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE)}
|
||||
return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
|
||||
|
||||
|
||||
@app.post("/size/update")
|
||||
@@ -193,9 +194,9 @@ async def update_image_size(
|
||||
):
|
||||
pattern = r"^\d+x\d+$" # Regular expression pattern
|
||||
if re.match(pattern, form_data.size):
|
||||
config_set(app.state.IMAGE_SIZE, form_data.size)
|
||||
app.state.config.IMAGE_SIZE = form_data.size
|
||||
return {
|
||||
"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE),
|
||||
"IMAGE_SIZE": app.state.config.IMAGE_SIZE,
|
||||
"status": True,
|
||||
}
|
||||
else:
|
||||
@@ -211,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel):
|
||||
|
||||
@app.get("/steps")
|
||||
async def get_image_size(user=Depends(get_admin_user)):
|
||||
return {"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS)}
|
||||
return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
|
||||
|
||||
|
||||
@app.post("/steps/update")
|
||||
@@ -219,9 +220,9 @@ async def update_image_size(
|
||||
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
if form_data.steps >= 0:
|
||||
config_set(app.state.IMAGE_STEPS, form_data.steps)
|
||||
app.state.config.IMAGE_STEPS = form_data.steps
|
||||
return {
|
||||
"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS),
|
||||
"IMAGE_STEPS": app.state.config.IMAGE_STEPS,
|
||||
"status": True,
|
||||
}
|
||||
else:
|
||||
@@ -234,14 +235,14 @@ async def update_image_size(
|
||||
@app.get("/models")
|
||||
def get_models(user=Depends(get_current_user)):
|
||||
try:
|
||||
if app.state.ENGINE == "openai":
|
||||
if app.state.config.ENGINE == "openai":
|
||||
return [
|
||||
{"id": "dall-e-2", "name": "DALL·E 2"},
|
||||
{"id": "dall-e-3", "name": "DALL·E 3"},
|
||||
]
|
||||
elif app.state.ENGINE == "comfyui":
|
||||
elif app.state.config.ENGINE == "comfyui":
|
||||
|
||||
r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
|
||||
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
|
||||
info = r.json()
|
||||
|
||||
return list(
|
||||
@@ -253,7 +254,7 @@ def get_models(user=Depends(get_current_user)):
|
||||
|
||||
else:
|
||||
r = requests.get(
|
||||
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
|
||||
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
|
||||
)
|
||||
models = r.json()
|
||||
return list(
|
||||
@@ -263,33 +264,29 @@ def get_models(user=Depends(get_current_user)):
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
app.state.ENABLED = False
|
||||
app.state.config.ENABLED = False
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||
|
||||
|
||||
@app.get("/models/default")
|
||||
async def get_default_model(user=Depends(get_admin_user)):
|
||||
try:
|
||||
if app.state.ENGINE == "openai":
|
||||
if app.state.config.ENGINE == "openai":
|
||||
return {
|
||||
"model": (
|
||||
config_get(app.state.MODEL)
|
||||
if config_get(app.state.MODEL)
|
||||
else "dall-e-2"
|
||||
)
|
||||
}
|
||||
elif app.state.ENGINE == "comfyui":
|
||||
return {
|
||||
"model": (
|
||||
config_get(app.state.MODEL) if config_get(app.state.MODEL) else ""
|
||||
app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
|
||||
)
|
||||
}
|
||||
elif app.state.config.ENGINE == "comfyui":
|
||||
return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
|
||||
else:
|
||||
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
|
||||
r = requests.get(
|
||||
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
|
||||
)
|
||||
options = r.json()
|
||||
return {"model": options["sd_model_checkpoint"]}
|
||||
except Exception as e:
|
||||
config_set(app.state.ENABLED, False)
|
||||
app.state.config.ENABLED = False
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||
|
||||
|
||||
@@ -298,17 +295,20 @@ class UpdateModelForm(BaseModel):
|
||||
|
||||
|
||||
def set_model_handler(model: str):
|
||||
if app.state.ENGINE in ["openai", "comfyui"]:
|
||||
config_set(app.state.MODEL, model)
|
||||
return config_get(app.state.MODEL)
|
||||
if app.state.config.ENGINE in ["openai", "comfyui"]:
|
||||
app.state.config.MODEL = model
|
||||
return app.state.config.MODEL
|
||||
else:
|
||||
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
|
||||
r = requests.get(
|
||||
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
|
||||
)
|
||||
options = r.json()
|
||||
|
||||
if model != options["sd_model_checkpoint"]:
|
||||
options["sd_model_checkpoint"] = model
|
||||
r = requests.post(
|
||||
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
|
||||
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||
json=options,
|
||||
)
|
||||
|
||||
return options
|
||||
@@ -397,30 +397,32 @@ def generate_image(
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
|
||||
width, height = tuple(map(int, config_get(app.state.IMAGE_SIZE).split("x")))
|
||||
width, height = tuple(map(int, app.state.config.IMAGE_SIZE).split("x"))
|
||||
|
||||
r = None
|
||||
try:
|
||||
if app.state.ENGINE == "openai":
|
||||
if app.state.config.ENGINE == "openai":
|
||||
|
||||
headers = {}
|
||||
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
|
||||
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
data = {
|
||||
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
|
||||
"model": (
|
||||
app.state.config.MODEL
|
||||
if app.state.config.MODEL != ""
|
||||
else "dall-e-2"
|
||||
),
|
||||
"prompt": form_data.prompt,
|
||||
"n": form_data.n,
|
||||
"size": (
|
||||
form_data.size
|
||||
if form_data.size
|
||||
else config_get(app.state.IMAGE_SIZE)
|
||||
form_data.size if form_data.size else app.state.config.IMAGE_SIZE
|
||||
),
|
||||
"response_format": "b64_json",
|
||||
}
|
||||
|
||||
r = requests.post(
|
||||
url=f"{app.state.OPENAI_API_BASE_URL}/images/generations",
|
||||
url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
|
||||
json=data,
|
||||
headers=headers,
|
||||
)
|
||||
@@ -440,7 +442,7 @@ def generate_image(
|
||||
|
||||
return images
|
||||
|
||||
elif app.state.ENGINE == "comfyui":
|
||||
elif app.state.config.ENGINE == "comfyui":
|
||||
|
||||
data = {
|
||||
"prompt": form_data.prompt,
|
||||
@@ -449,8 +451,8 @@ def generate_image(
|
||||
"n": form_data.n,
|
||||
}
|
||||
|
||||
if config_get(app.state.IMAGE_STEPS) is not None:
|
||||
data["steps"] = config_get(app.state.IMAGE_STEPS)
|
||||
if app.state.config.IMAGE_STEPS is not None:
|
||||
data["steps"] = app.state.config.IMAGE_STEPS
|
||||
|
||||
if form_data.negative_prompt is not None:
|
||||
data["negative_prompt"] = form_data.negative_prompt
|
||||
@@ -458,10 +460,10 @@ def generate_image(
|
||||
data = ImageGenerationPayload(**data)
|
||||
|
||||
res = comfyui_generate_image(
|
||||
config_get(app.state.MODEL),
|
||||
app.state.config.MODEL,
|
||||
data,
|
||||
user.id,
|
||||
config_get(app.state.COMFYUI_BASE_URL),
|
||||
app.state.config.COMFYUI_BASE_URL,
|
||||
)
|
||||
log.debug(f"res: {res}")
|
||||
|
||||
@@ -488,14 +490,14 @@ def generate_image(
|
||||
"height": height,
|
||||
}
|
||||
|
||||
if config_get(app.state.IMAGE_STEPS) is not None:
|
||||
data["steps"] = config_get(app.state.IMAGE_STEPS)
|
||||
if app.state.config.IMAGE_STEPS is not None:
|
||||
data["steps"] = app.state.config.IMAGE_STEPS
|
||||
|
||||
if form_data.negative_prompt is not None:
|
||||
data["negative_prompt"] = form_data.negative_prompt
|
||||
|
||||
r = requests.post(
|
||||
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
|
||||
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
|
||||
json=data,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user