More options for AUTOMATIC1111

This commit adds 3 new options to the AUTOMATIC1111 settings:
- CFG Scale
- Sampler
- Scheduler

These options allow users to configure these parameters directly through the admin settings, without needing to modify the source code, which was previously required to change the default values in  AUTOMATIC1111.

Signed-off-by: Balazs Toldi <balazs@toldi.eu>
This commit is contained in:
Balazs Toldi
2024-09-07 17:21:17 +02:00
parent 2544f7eaf0
commit 7f6dae41f0
4 changed files with 119 additions and 0 deletions

View File

@@ -17,6 +17,9 @@ from open_webui.apps.images.utils.comfyui import (
from open_webui.config import (
AUTOMATIC1111_API_AUTH,
AUTOMATIC1111_BASE_URL,
AUTOMATIC1111_CFG_SCALE,
AUTOMATIC1111_SAMPLER,
AUTOMATIC1111_SCHEDULER,
CACHE_DIR,
COMFYUI_BASE_URL,
COMFYUI_WORKFLOW,
@@ -65,6 +68,9 @@ app.state.config.MODEL = IMAGE_GENERATION_MODEL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
@@ -85,6 +91,9 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
"automatic1111": {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
"AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
"AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
},
"comfyui": {
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
@@ -102,6 +111,9 @@ class OpenAIConfigForm(BaseModel):
class Automatic1111ConfigForm(BaseModel):
AUTOMATIC1111_BASE_URL: str
AUTOMATIC1111_API_AUTH: str
AUTOMATIC1111_CFG_SCALE: float
AUTOMATIC1111_SAMPLER: str
AUTOMATIC1111_SCHEDULER: str
class ComfyUIConfigForm(BaseModel):
@@ -132,6 +144,12 @@ async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
app.state.config.AUTOMATIC1111_API_AUTH = (
form_data.automatic1111.AUTOMATIC1111_API_AUTH
)
app.state.config.AUTOMATIC1111_CFG_SCALE = form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
app.state.config.AUTOMATIC1111_SAMPLER = form_data.automatic1111.AUTOMATIC1111_SAMPLER
app.state.config.AUTOMATIC1111_SCHEDULER = (
form_data.automatic1111.AUTOMATIC1111_SCHEDULER
)
app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
@@ -147,6 +165,9 @@ async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
"automatic1111": {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
"AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
"AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
"AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
},
"comfyui": {
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
@@ -266,6 +287,7 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
)
return {
"MODEL": app.state.config.MODEL,
"IMAGE_SIZE": app.state.config.IMAGE_SIZE,
@@ -523,6 +545,15 @@ async def image_generations(
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
if app.state.config.AUTOMATIC1111_CFG_SCALE:
data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
if app.state.config.AUTOMATIC1111_SAMPLER:
data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
if app.state.config.AUTOMATIC1111_SCHEDULER:
data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER
# Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread(