feat: save UI config changes to config.json

This commit is contained in:
Jun Siang Cheah
2024-05-10 13:36:10 +08:00
parent 9a95767062
commit 058eb76568
11 changed files with 611 additions and 336 deletions

View File

@@ -42,6 +42,8 @@ from config import (
IMAGE_GENERATION_MODEL,
IMAGE_SIZE,
IMAGE_STEPS,
config_get,
config_set,
)
@@ -79,7 +81,10 @@ app.state.IMAGE_STEPS = IMAGE_STEPS
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
return {
"engine": config_get(app.state.ENGINE),
"enabled": config_get(app.state.ENABLED),
}
class ConfigUpdateForm(BaseModel):
@@ -89,9 +94,12 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.ENGINE = form_data.engine
app.state.ENABLED = form_data.enabled
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
config_set(app.state.ENGINE, form_data.engine)
config_set(app.state.ENABLED, form_data.enabled)
return {
"engine": config_get(app.state.ENGINE),
"enabled": config_get(app.state.ENABLED),
}
class EngineUrlUpdateForm(BaseModel):
@@ -102,8 +110,8 @@ class EngineUrlUpdateForm(BaseModel):
@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
"AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL),
"COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL),
}
@@ -113,29 +121,29 @@ async def update_engine_url(
):
if form_data.AUTOMATIC1111_BASE_URL == None:
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
config_set(app.state.AUTOMATIC1111_BASE_URL, config_get(AUTOMATIC1111_BASE_URL))
else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
try:
r = requests.head(url)
app.state.AUTOMATIC1111_BASE_URL = url
config_set(app.state.AUTOMATIC1111_BASE_URL, url)
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
if form_data.COMFYUI_BASE_URL == None:
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
config_set(app.state.COMFYUI_BASE_URL, COMFYUI_BASE_URL)
else:
url = form_data.COMFYUI_BASE_URL.strip("/")
try:
r = requests.head(url)
app.state.COMFYUI_BASE_URL = url
config_set(app.state.COMFYUI_BASE_URL, url)
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
"AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL),
"COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL),
"status": True,
}
@@ -148,8 +156,8 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
}
@@ -160,13 +168,13 @@ async def update_openai_config(
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key
config_set(app.state.OPENAI_API_BASE_URL, form_data.url)
config_set(app.state.OPENAI_API_KEY, form_data.key)
return {
"status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
}
@@ -176,7 +184,7 @@ class ImageSizeUpdateForm(BaseModel):
@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
return {"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE)}
@app.post("/size/update")
@@ -185,9 +193,9 @@ async def update_image_size(
):
pattern = r"^\d+x\d+$" # Regular expression pattern
if re.match(pattern, form_data.size):
app.state.IMAGE_SIZE = form_data.size
config_set(app.state.IMAGE_SIZE, form_data.size)
return {
"IMAGE_SIZE": app.state.IMAGE_SIZE,
"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE),
"status": True,
}
else:
@@ -203,7 +211,7 @@ class ImageStepsUpdateForm(BaseModel):
@app.get("/steps")
async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
return {"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS)}
@app.post("/steps/update")
@@ -211,9 +219,9 @@ async def update_image_size(
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
if form_data.steps >= 0:
app.state.IMAGE_STEPS = form_data.steps
config_set(app.state.IMAGE_STEPS, form_data.steps)
return {
"IMAGE_STEPS": app.state.IMAGE_STEPS,
"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS),
"status": True,
}
else:
@@ -263,15 +271,25 @@ def get_models(user=Depends(get_current_user)):
async def get_default_model(user=Depends(get_admin_user)):
try:
if app.state.ENGINE == "openai":
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
return {
"model": (
config_get(app.state.MODEL)
if config_get(app.state.MODEL)
else "dall-e-2"
)
}
elif app.state.ENGINE == "comfyui":
return {"model": app.state.MODEL if app.state.MODEL else ""}
return {
"model": (
config_get(app.state.MODEL) if config_get(app.state.MODEL) else ""
)
}
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
return {"model": options["sd_model_checkpoint"]}
except Exception as e:
app.state.ENABLED = False
config_set(app.state.ENABLED, False)
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@@ -280,12 +298,9 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str):
if app.state.ENGINE == "openai":
app.state.MODEL = model
return app.state.MODEL
if app.state.ENGINE == "comfyui":
app.state.MODEL = model
return app.state.MODEL
if app.state.ENGINE in ["openai", "comfyui"]:
config_set(app.state.MODEL, model)
return config_get(app.state.MODEL)
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
@@ -382,7 +397,7 @@ def generate_image(
user=Depends(get_current_user),
):
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
width, height = tuple(map(int, config_get(app.state.IMAGE_SIZE).split("x")))
r = None
try:
@@ -396,7 +411,11 @@ def generate_image(
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
"prompt": form_data.prompt,
"n": form_data.n,
"size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
"size": (
form_data.size
if form_data.size
else config_get(app.state.IMAGE_SIZE)
),
"response_format": "b64_json",
}
@@ -430,19 +449,19 @@ def generate_image(
"n": form_data.n,
}
if app.state.IMAGE_STEPS != None:
data["steps"] = app.state.IMAGE_STEPS
if config_get(app.state.IMAGE_STEPS) is not None:
data["steps"] = config_get(app.state.IMAGE_STEPS)
if form_data.negative_prompt != None:
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
data = ImageGenerationPayload(**data)
res = comfyui_generate_image(
app.state.MODEL,
config_get(app.state.MODEL),
data,
user.id,
app.state.COMFYUI_BASE_URL,
config_get(app.state.COMFYUI_BASE_URL),
)
log.debug(f"res: {res}")
@@ -469,10 +488,10 @@ def generate_image(
"height": height,
}
if app.state.IMAGE_STEPS != None:
data["steps"] = app.state.IMAGE_STEPS
if config_get(app.state.IMAGE_STEPS) is not None:
data["steps"] = config_get(app.state.IMAGE_STEPS)
if form_data.negative_prompt != None:
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
r = requests.post(