feat: switch to config proxy, remove config_get/set

This commit is contained in:
Jun Siang Cheah
2024-05-10 15:03:24 +08:00
parent f712c90019
commit 298e6848b3
11 changed files with 340 additions and 379 deletions

View File

@@ -42,8 +42,7 @@ from config import (
IMAGE_GENERATION_MODEL,
IMAGE_SIZE,
IMAGE_STEPS,
config_get,
config_set,
AppConfig,
)
@@ -62,28 +61,30 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.ENGINE = IMAGE_GENERATION_ENGINE
app.state.ENABLED = ENABLE_IMAGE_GENERATION
app.state.config = AppConfig()
app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
app.state.MODEL = IMAGE_GENERATION_MODEL
app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.MODEL = IMAGE_GENERATION_MODEL
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.IMAGE_SIZE = IMAGE_SIZE
app.state.IMAGE_STEPS = IMAGE_STEPS
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {
"engine": config_get(app.state.ENGINE),
"enabled": config_get(app.state.ENABLED),
"engine": app.state.config.ENGINE,
"enabled": app.state.config.ENABLED,
}
@@ -94,11 +95,11 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
config_set(app.state.ENGINE, form_data.engine)
config_set(app.state.ENABLED, form_data.enabled)
app.state.config.ENGINE = form_data.engine
app.state.config.ENABLED = form_data.enabled
return {
"engine": config_get(app.state.ENGINE),
"enabled": config_get(app.state.ENABLED),
"engine": app.state.config.ENGINE,
"enabled": app.state.config.ENABLED,
}
@@ -110,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel):
@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
return {
"AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL),
"COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL),
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
}
@@ -121,29 +122,29 @@ async def update_engine_url(
):
if form_data.AUTOMATIC1111_BASE_URL == None:
config_set(app.state.AUTOMATIC1111_BASE_URL, config_get(AUTOMATIC1111_BASE_URL))
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
try:
r = requests.head(url)
config_set(app.state.AUTOMATIC1111_BASE_URL, url)
app.state.config.AUTOMATIC1111_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
if form_data.COMFYUI_BASE_URL == None:
config_set(app.state.COMFYUI_BASE_URL, COMFYUI_BASE_URL)
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else:
url = form_data.COMFYUI_BASE_URL.strip("/")
try:
r = requests.head(url)
config_set(app.state.COMFYUI_BASE_URL, url)
app.state.config.COMFYUI_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return {
"AUTOMATIC1111_BASE_URL": config_get(app.state.AUTOMATIC1111_BASE_URL),
"COMFYUI_BASE_URL": config_get(app.state.COMFYUI_BASE_URL),
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
"status": True,
}
@@ -156,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
return {
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
}
@@ -168,13 +169,13 @@ async def update_openai_config(
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
config_set(app.state.OPENAI_API_BASE_URL, form_data.url)
config_set(app.state.OPENAI_API_KEY, form_data.key)
app.state.config.OPENAI_API_BASE_URL = form_data.url
app.state.config.OPENAI_API_KEY = form_data.key
return {
"status": True,
"OPENAI_API_BASE_URL": config_get(app.state.OPENAI_API_BASE_URL),
"OPENAI_API_KEY": config_get(app.state.OPENAI_API_KEY),
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
}
@@ -184,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel):
@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE)}
return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
@app.post("/size/update")
@@ -193,9 +194,9 @@ async def update_image_size(
):
pattern = r"^\d+x\d+$" # Regular expression pattern
if re.match(pattern, form_data.size):
config_set(app.state.IMAGE_SIZE, form_data.size)
app.state.config.IMAGE_SIZE = form_data.size
return {
"IMAGE_SIZE": config_get(app.state.IMAGE_SIZE),
"IMAGE_SIZE": app.state.config.IMAGE_SIZE,
"status": True,
}
else:
@@ -211,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel):
@app.get("/steps")
async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS)}
return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
@app.post("/steps/update")
@@ -219,9 +220,9 @@ async def update_image_size(
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
if form_data.steps >= 0:
config_set(app.state.IMAGE_STEPS, form_data.steps)
app.state.config.IMAGE_STEPS = form_data.steps
return {
"IMAGE_STEPS": config_get(app.state.IMAGE_STEPS),
"IMAGE_STEPS": app.state.config.IMAGE_STEPS,
"status": True,
}
else:
@@ -234,14 +235,14 @@ async def update_image_size(
@app.get("/models")
def get_models(user=Depends(get_current_user)):
try:
if app.state.ENGINE == "openai":
if app.state.config.ENGINE == "openai":
return [
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
]
elif app.state.ENGINE == "comfyui":
elif app.state.config.ENGINE == "comfyui":
r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
info = r.json()
return list(
@@ -253,7 +254,7 @@ def get_models(user=Depends(get_current_user)):
else:
r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
)
models = r.json()
return list(
@@ -263,33 +264,29 @@ def get_models(user=Depends(get_current_user)):
)
)
except Exception as e:
app.state.ENABLED = False
app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
try:
if app.state.ENGINE == "openai":
if app.state.config.ENGINE == "openai":
return {
"model": (
config_get(app.state.MODEL)
if config_get(app.state.MODEL)
else "dall-e-2"
)
}
elif app.state.ENGINE == "comfyui":
return {
"model": (
config_get(app.state.MODEL) if config_get(app.state.MODEL) else ""
app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
)
}
elif app.state.config.ENGINE == "comfyui":
return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
)
options = r.json()
return {"model": options["sd_model_checkpoint"]}
except Exception as e:
config_set(app.state.ENABLED, False)
app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@@ -298,17 +295,20 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str):
if app.state.ENGINE in ["openai", "comfyui"]:
config_set(app.state.MODEL, model)
return config_get(app.state.MODEL)
if app.state.config.ENGINE in ["openai", "comfyui"]:
app.state.config.MODEL = model
return app.state.config.MODEL
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
)
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
json=options,
)
return options
@@ -397,30 +397,32 @@ def generate_image(
user=Depends(get_current_user),
):
width, height = tuple(map(int, config_get(app.state.IMAGE_SIZE).split("x")))
width, height = tuple(map(int, app.state.config.IMAGE_SIZE).split("x"))
r = None
try:
if app.state.ENGINE == "openai":
if app.state.config.ENGINE == "openai":
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
data = {
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
"model": (
app.state.config.MODEL
if app.state.config.MODEL != ""
else "dall-e-2"
),
"prompt": form_data.prompt,
"n": form_data.n,
"size": (
form_data.size
if form_data.size
else config_get(app.state.IMAGE_SIZE)
form_data.size if form_data.size else app.state.config.IMAGE_SIZE
),
"response_format": "b64_json",
}
r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URL}/images/generations",
url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
json=data,
headers=headers,
)
@@ -440,7 +442,7 @@ def generate_image(
return images
elif app.state.ENGINE == "comfyui":
elif app.state.config.ENGINE == "comfyui":
data = {
"prompt": form_data.prompt,
@@ -449,8 +451,8 @@ def generate_image(
"n": form_data.n,
}
if config_get(app.state.IMAGE_STEPS) is not None:
data["steps"] = config_get(app.state.IMAGE_STEPS)
if app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.config.IMAGE_STEPS
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
@@ -458,10 +460,10 @@ def generate_image(
data = ImageGenerationPayload(**data)
res = comfyui_generate_image(
config_get(app.state.MODEL),
app.state.config.MODEL,
data,
user.id,
config_get(app.state.COMFYUI_BASE_URL),
app.state.config.COMFYUI_BASE_URL,
)
log.debug(f"res: {res}")
@@ -488,14 +490,14 @@ def generate_image(
"height": height,
}
if config_get(app.state.IMAGE_STEPS) is not None:
data["steps"] = config_get(app.state.IMAGE_STEPS)
if app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.config.IMAGE_STEPS
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
)