feat: config.json db migration

This commit is contained in:
Timothy J. Baek
2024-08-25 16:52:36 +02:00
parent 072945c40b
commit 58cf1be20c
16 changed files with 432 additions and 322 deletions

View File

@@ -1,13 +1,17 @@
from sqlalchemy import create_engine, Column, Integer, DateTime, JSON, func
from contextlib import contextmanager
import os
import sys
import logging
import importlib.metadata
import pkgutil
from urllib.parse import urlparse
from datetime import datetime
import chromadb
from chromadb import Settings
from bs4 import BeautifulSoup
from typing import TypeVar, Generic
from pydantic import BaseModel
from typing import Optional
@@ -16,68 +20,39 @@ from pathlib import Path
import json
import yaml
import markdown
import requests
import shutil
from apps.webui.internal.db import Base, get_db
from constants import ERROR_MESSAGES
####################################
# Load .env file
####################################

# Directory holding this module.
BACKEND_DIR = Path(__file__).parent
# Parent of the backend directory.
BASE_DIR = Path(__file__).parent.parent

print(BASE_DIR)

# python-dotenv is optional: when it cannot be imported, only the message
# below is printed and the process environment is used unchanged.
try:
    from dotenv import find_dotenv, load_dotenv

    dotenv_location = find_dotenv(str(BASE_DIR / ".env"))
    load_dotenv(dotenv_location)
except ImportError:
    print("dotenv not installed, skipping...")
####################################
# LOGGING
####################################

log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]

# The root logger is only (re)configured when a recognised level is requested;
# otherwise "INFO" is merely remembered as the fallback for the sources below.
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL not in log_levels:
    GLOBAL_LOG_LEVEL = "INFO"
else:
    logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)

log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")

log_sources = [
    "AUDIO",
    "COMFYUI",
    "CONFIG",
    "DB",
    "IMAGES",
    "MAIN",
    "MODELS",
    "OLLAMA",
    "OPENAI",
    "RAG",
    "WEBHOOK",
]

# Each source reads <SOURCE>_LOG_LEVEL and falls back to the global level
# when the variable is unset or not one of log_levels.
SRC_LOG_LEVELS = {}
for source in log_sources:
    log_env_var = f"{source}_LOG_LEVEL"
    requested_level = os.environ.get(log_env_var, "").upper()
    SRC_LOG_LEVELS[source] = (
        requested_level if requested_level in log_levels else GLOBAL_LOG_LEVEL
    )
    log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")

log.setLevel(SRC_LOG_LEVELS["CONFIG"])
from env import (
ENV,
VERSION,
SAFE_MODE,
GLOBAL_LOG_LEVEL,
SRC_LOG_LEVELS,
BASE_DIR,
DATA_DIR,
BACKEND_DIR,
FRONTEND_BUILD_DIR,
WEBUI_NAME,
WEBUI_URL,
WEBUI_FAVICON_URL,
WEBUI_BUILD_HASH,
CONFIG_DATA,
DATABASE_URL,
CHANGELOG,
WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
log,
)
class EndpointFilter(logging.Filter):
@@ -88,135 +63,60 @@ class EndpointFilter(logging.Filter):
# Filter out /endpoint
logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
# A customised name gets the " (Open WebUI)" suffix appended; the default
# name is used as-is.
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
WEBUI_NAME = (
    WEBUI_NAME if WEBUI_NAME == "Open WebUI" else WEBUI_NAME + " (Open WebUI)"
)

WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")

WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
####################################
# ENV (dev,test,prod)
####################################

ENV = os.environ.get("ENV", "dev")

# Version lookup order: package.json under BASE_DIR, then the metadata of the
# installed "open-webui" distribution, then the "0.0.0" placeholder.
try:
    with open(BASE_DIR / "package.json") as package_file:
        PACKAGE_DATA = json.load(package_file)
except Exception:
    try:
        installed_version = importlib.metadata.version("open-webui")
    except importlib.metadata.PackageNotFoundError:
        installed_version = "0.0.0"
    PACKAGE_DATA = {"version": installed_version}

VERSION = PACKAGE_DATA["version"]
# Function to parse each section
def parse_section(section):
    """Convert the ``<li>`` entries of one changelog section into dicts.

    Args:
        section: A parsed HTML element exposing ``find_all`` (the ``<ul>``
            found after a section heading), or ``None`` when no such list
            exists.

    Returns:
        A list with one dict per ``<li>``:
        ``title`` is the text before the first ``": "`` (``""`` when the
        separator is absent), ``content`` is the remaining text (the whole
        text when the separator is absent) and ``raw`` is the item's HTML.
    """
    if section is None:
        # A heading without a list has no items; calling find_all on None
        # would raise AttributeError.
        return []

    items = []
    for li in section.find_all("li"):
        # Extract raw HTML string
        raw_html = str(li)

        # Extract text without HTML tags
        text = li.get_text(separator=" ", strip=True)

        # Split into title and content on the first ": " only
        head, separator, tail = text.partition(": ")
        if separator:
            title, content = head.strip(), tail.strip()
        else:
            title, content = "", text

        items.append({"title": title, "content": content, "raw": raw_html})
    return items
# Read the changelog from BASE_DIR; on any failure fall back to the copy
# shipped as package data of "open_webui" (empty when that is missing too).
try:
    changelog_path = BASE_DIR / "CHANGELOG.md"
    with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
        changelog_content = file.read()
except Exception:
    changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()

# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)

# Parse the HTML content
soup = BeautifulSoup(html_content, "html.parser")

# Initialize JSON structure
changelog_json = {}

# Iterate over each version
for version in soup.find_all("h2"):
    # Headings are split on " - ": "[<version>] - <date>".
    heading_parts = version.get_text().strip().split(" - ")
    version_number = heading_parts[0][1:-1]  # Remove brackets
    # A heading without a date part must not abort the import with an
    # IndexError; record an empty date instead.
    date = heading_parts[1] if len(heading_parts) > 1 else ""

    version_data = {"date": date}

    # Walk the siblings up to the next version heading; every h3 is a
    # section title.
    current = version.find_next_sibling()
    while current and current.name != "h2":
        if current.name == "h3":
            section_title = current.get_text().lower()  # e.g., "added", "fixed"
            section_list = current.find_next_sibling("ul")
            # A section heading with no list after it yields no items rather
            # than an AttributeError inside parse_section.
            section_items = (
                parse_section(section_list) if section_list is not None else []
            )
            version_data[section_title] = section_items

        # Move to the next element
        current = current.find_next_sibling()

    changelog_json[version_number] = version_data

CHANGELOG = changelog_json
####################################
# SAFE_MODE
####################################

# Enabled only by the (case-insensitive) string "true".
SAFE_MODE = os.getenv("SAFE_MODE", "false").lower() == "true"

####################################
# WEBUI_BUILD_HASH
####################################

WEBUI_BUILD_HASH = os.getenv("WEBUI_BUILD_HASH", "dev-build")

####################################
# DATA/FRONTEND BUILD DIR
####################################

# Both directories can be overridden through the environment and are always
# stored as resolved paths.
DATA_DIR = Path(os.environ.get("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(
    os.environ.get("FRONTEND_BUILD_DIR", BASE_DIR / "build")
).resolve()
RESET_CONFIG_ON_START = os.getenv("RESET_CONFIG_ON_START", "False").lower() == "true"

if RESET_CONFIG_ON_START:
    # Best effort: the file is deleted and then rewritten as an empty JSON
    # object; any failure (including the file not existing, which makes the
    # delete raise before the rewrite) is ignored.
    try:
        config_file = DATA_DIR / "config.json"
        config_file.unlink()
        config_file.write_text("{}")
    except Exception:
        pass

# An unreadable or invalid config.json is treated as an empty configuration.
try:
    with open(DATA_DIR / "config.json") as config_fp:
        CONFIG_DATA = json.load(config_fp)
except Exception:
    CONFIG_DATA = {}
####################################
# Config helpers
####################################
# Function to run the alembic migrations
def run_migrations():
    """Upgrade the database schema to the Alembic ``head`` revision.

    Best effort: any failure (including Alembic not being importable) is
    logged together with its traceback and swallowed, so the caller is never
    interrupted by a migration error.
    """
    print("Running migrations")
    try:
        # Imported lazily, and aliased so the Alembic class is not mistaken
        # for the ``Config`` ORM model declared in this module.
        from alembic import command
        from alembic.config import Config as AlembicConfig

        # NOTE(review): "alembic.ini" is a relative path, so it is resolved
        # against the current working directory — confirm how the process is
        # started.
        alembic_cfg = AlembicConfig("alembic.ini")
        command.upgrade(alembic_cfg, "head")
    except Exception:
        # log.exception keeps the traceback that a bare print of the message
        # would lose.
        log.exception("Error while running migrations")
run_migrations()
class Config(Base):
    """ORM model for the ``config`` table."""

    __tablename__ = "config"

    id = Column(Integer, primary_key=True)
    # The configuration document, stored as JSON.
    data = Column(JSON, nullable=False)
    # Integer counter; rows are created with 0 unless a value is given.
    version = Column(Integer, default=0, nullable=False)
    # Filled in by the database when the row is inserted.
    created_at = Column(DateTime, server_default=func.now(), nullable=False)
    # Set on UPDATE statements; stays NULL until the row is first updated.
    updated_at = Column(DateTime, onupdate=func.now(), nullable=True)
def load_initial_config():
    """Read ``config.json`` from ``DATA_DIR`` and return its parsed content.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 instead of the platform's locale encoding
    # (e.g. cp1252 on Windows), which would corrupt or reject non-ASCII text.
    with open(f"{DATA_DIR}/config.json", "r", encoding="utf-8") as file:
        return json.load(file)
def save_to_db(data):
    """Store ``data`` as the configuration document in the database.

    The ``data`` of the first ``Config`` row returned by the query is
    overwritten; when the table has no row, one is inserted with
    ``version=0``. The change is committed before returning.
    """
    with get_db() as db:
        row = db.query(Config).first()
        if row:
            row.data = data
        else:
            db.add(Config(data=data, version=0))
        db.commit()
# When initializing, check if config.json exists and migrate it to the database
if os.path.exists(f"{DATA_DIR}/config.json"):
    data = load_initial_config()
    save_to_db(data)
    # os.replace overwrites an old_config.json left by an earlier migration on
    # every platform. os.rename raises FileExistsError on Windows in that
    # case, which would abort startup after the data was already saved.
    os.replace(f"{DATA_DIR}/config.json", f"{DATA_DIR}/old_config.json")
def save_config():
try:
with open(f"{DATA_DIR}/config.json", "w") as f:
@@ -244,9 +144,9 @@ class PersistentConfig(Generic[T]):
self.env_name = env_name
self.config_path = config_path
self.env_value = env_value
self.config_value = get_config_value(config_path)
self.config_value = self.load_latest_config_value(config_path)
if self.config_value is not None:
log.info(f"'{env_name}' loaded from config.json")
log.info(f"'{env_name}' loaded from the latest database entry")
self.value = self.config_value
else:
self.value = env_value
@@ -254,33 +154,44 @@ class PersistentConfig(Generic[T]):
def __str__(self):
return str(self.value)
@property
def __dict__(self):
raise TypeError(
"PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)
def __getattribute__(self, item):
if item == "__dict__":
raise TypeError(
"PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)
return super().__getattribute__(item)
def load_latest_config_value(self, config_path: str):
    """Look up ``config_path`` in the ``Config`` row with the highest id.

    Args:
        config_path: Dot-separated key path into the stored JSON document,
            e.g. ``"auth.jwt_expiry"``.

    Returns:
        The value stored at that path, or ``None`` when the table has no
        row, a key along the path is missing, or a value in the middle of
        the path is not a dict.
    """
    with get_db() as db:
        config_entry = db.query(Config).order_by(Config.id.desc()).first()
        if not config_entry:
            return None

        config_value = config_entry.data
        for key in config_path.split("."):
            # A scalar, list or None in the middle of the path means the path
            # does not exist. Subscripting it with a string key would raise
            # TypeError, which the previous KeyError-only handling let through.
            if not isinstance(config_value, dict) or key not in config_value:
                return None
            config_value = config_value[key]
        return config_value
def save(self):
    """Persist the current value under ``config_path`` in the database.

    Nothing is written when the value already equals both the environment
    value and the last value loaded from or saved to the database.
    Otherwise the ``Config`` row with the highest id gets an updated
    document and an incremented ``version``; when the table is empty a
    first row with ``version=1`` is created. Only the database is written:
    ``config.json`` is a legacy file that is migrated at module import.
    """
    # Local import: keeps this method independent of the module's import block.
    import copy

    # Don't save if the value is the same as the env value and the config value
    if self.env_value == self.value and self.config_value == self.value:
        return

    log.info(f"Saving '{self.env_name}' to the database")
    path_parts = self.config_path.split(".")

    with get_db() as db:
        # Pick the same row that load_latest_config_value reads, so a saved
        # value is the one found again on the next load.
        existing_config = db.query(Config).order_by(Config.id.desc()).first()

        # Work on a copy and assign it back: a plain JSON column does not
        # detect in-place mutation of the loaded dict, so mutating
        # existing_config.data directly could leave the change out of the
        # UPDATE statement.
        new_data = copy.deepcopy(existing_config.data) if existing_config else {}

        # Walk/create the nested dicts down to the parent of the last key.
        config = new_data
        for key in path_parts[:-1]:
            if key not in config:
                config[key] = {}
            config = config[key]
        config[path_parts[-1]] = self.value

        if existing_config:
            existing_config.data = new_data
            existing_config.version += 1
        else:  # This case should not actually occur as there should always be at least one entry
            db.add(Config(data=new_data, version=1))
        db.commit()

    self.config_value = self.value
@@ -305,11 +216,6 @@ class AppConfig:
# WEBUI_AUTH (Required for security)
####################################
# Authentication is on unless WEBUI_AUTH is set to something other than
# "true" (case-insensitive).
WEBUI_AUTH = os.getenv("WEBUI_AUTH", "True").lower() == "true"

# Both header names are None when the variables are not set.
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.getenv("WEBUI_AUTH_TRUSTED_EMAIL_HEADER")
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.getenv("WEBUI_AUTH_TRUSTED_NAME_HEADER")
# Stored under "auth.jwt_expiry"; the environment value defaults to "-1".
JWT_EXPIRES_IN = PersistentConfig(
    "JWT_EXPIRES_IN",
    "auth.jwt_expiry",
    os.getenv("JWT_EXPIRES_IN", "-1"),
)
@@ -999,30 +905,6 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
)
####################################
# WEBUI_SECRET_KEY
####################################

# NOTE(review): the fallback secret is a fixed, publicly known string —
# confirm deployments are required to override it.
WEBUI_SECRET_KEY = os.environ.get(
    "WEBUI_SECRET_KEY",
    os.environ.get(
        "WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t"
    ),  # DEPRECATED: remove at next major version
)

# Defaults to "lax" when the variable is not set.
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax")

# Always a bool: True only for the (case-insensitive) string "true".
# Returning the raw environment string when the variable is set would make
# "false" truthy.
WEBUI_SESSION_COOKIE_SECURE = (
    os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true"
)
# Refuse to continue when authentication is enabled but the secret key is empty.
if WEBUI_SECRET_KEY == "" and WEBUI_AUTH:
    raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
####################################
# RAG document content extraction
####################################
@@ -1553,14 +1435,3 @@ AUDIO_TTS_VOICE = PersistentConfig(
"audio.tts.voice",
os.getenv("AUDIO_TTS_VOICE", "alloy"), # OpenAI default voice
)
####################################
# Database
####################################

# Defaults to a SQLite file named webui.db inside DATA_DIR.
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")

# Replace the postgres:// with postgresql://
# Only the leading scheme is rewritten, so the same text appearing later in
# the URL (credentials, query parameters) is left untouched.
if DATABASE_URL.startswith("postgres://"):
    DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://", 1)