mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-18 05:05:09 +02:00
refac
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import ssl as _stdlib_ssl
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
@@ -38,31 +36,17 @@ from typing_extensions import Self
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SSLParams:
|
||||
"""SSL parameters extracted from a PostgreSQL ``DATABASE_URL``.
|
||||
|
||||
Holds the connection-mode flag and optional certificate file paths
|
||||
so that each driver (asyncpg, psycopg2/libpq) can receive them in
|
||||
the format it expects.
|
||||
"""
|
||||
|
||||
mode: str | None = None
|
||||
rootcert: str | None = None
|
||||
cert: str | None = None
|
||||
key: str | None = None
|
||||
crl: str | None = None
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.mode is not None
|
||||
|
||||
@property
|
||||
def has_any(self) -> bool:
|
||||
"""True when *any* SSL-related field is set (mode or cert files)."""
|
||||
return any((self.mode, self.rootcert, self.cert, self.key, self.crl))
|
||||
|
||||
|
||||
# ── URL extraction / reattachment ────────────────────────────────────
|
||||
# ── SSL URL normalization (used by sync engine & Alembic migrations) ─
|
||||
#
|
||||
# psycopg2 (sync) needs ``sslmode=`` in the connection string (it does
|
||||
# not recognise the bare ``ssl=`` key that some ORMs emit). The helpers
|
||||
# below strip all SSL-related query params, normalise them, and
|
||||
# reattach them in the canonical libpq form.
|
||||
#
|
||||
# The **async** engine now uses psycopg (v3), which speaks libpq
|
||||
# natively, so it needs no translation at all — the DATABASE_URL is
|
||||
# passed through as-is.
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _pop_first(params: dict[str, list[str]], key: str) -> str | None:
|
||||
@@ -71,63 +55,57 @@ def _pop_first(params: dict[str, list[str]], key: str) -> str | None:
|
||||
return values[0] if values else None
|
||||
|
||||
|
||||
def extract_ssl_params_from_url(url: str) -> tuple[str, SSLParams]:
|
||||
"""Strip all SSL query-string parameters from a PostgreSQL URL.
|
||||
|
||||
asyncpg does not accept libpq-style certificate-file keys
|
||||
(``sslrootcert``, ``sslcert``, ``sslkey``, ``sslcrl``), so every
|
||||
SSL-related key is removed and returned as a structured
|
||||
:class:`SSLParams` object.
|
||||
|
||||
Returns ``(url_without_ssl, ssl_params)``. Non-PostgreSQL URLs are
|
||||
returned unchanged with an empty ``SSLParams``.
|
||||
"""
|
||||
if not url or not any(
|
||||
def _is_postgres_url(url: str) -> bool:
|
||||
"""Return True if *url* looks like a PostgreSQL connection string."""
|
||||
return bool(url) and any(
|
||||
url.startswith(p) for p in ('postgresql://', 'postgresql+', 'postgres://')
|
||||
):
|
||||
return url, SSLParams()
|
||||
)
|
||||
|
||||
|
||||
def extract_ssl_params_from_url(url: str) -> tuple[str, dict[str, str]]:
|
||||
"""Strip SSL query-string parameters from a PostgreSQL URL.
|
||||
|
||||
Returns ``(url_without_ssl, ssl_dict)`` where *ssl_dict* maps
|
||||
canonical libpq key names (``sslmode``, ``sslrootcert``, …) to
|
||||
their values. Non-PostgreSQL URLs are returned unchanged with an
|
||||
empty dict.
|
||||
"""
|
||||
if not _is_postgres_url(url):
|
||||
return url, {}
|
||||
|
||||
parsed = urlparse(url)
|
||||
qp = parse_qs(parsed.query, keep_blank_values=True)
|
||||
|
||||
# Prefer sslmode (libpq canonical) over the asyncpg-only ``ssl`` key.
|
||||
# Both must be popped unconditionally so neither leaks into the cleaned URL.
|
||||
# Prefer sslmode (libpq canonical) over the bare ``ssl`` key.
|
||||
sslmode_val = _pop_first(qp, 'sslmode')
|
||||
ssl_val = _pop_first(qp, 'ssl')
|
||||
ssl_mode = sslmode_val or ssl_val
|
||||
|
||||
params = SSLParams(
|
||||
mode=ssl_mode,
|
||||
rootcert=_pop_first(qp, 'sslrootcert'),
|
||||
cert=_pop_first(qp, 'sslcert'),
|
||||
key=_pop_first(qp, 'sslkey'),
|
||||
crl=_pop_first(qp, 'sslcrl'),
|
||||
)
|
||||
ssl_dict: dict[str, str] = {}
|
||||
if ssl_mode:
|
||||
ssl_dict['sslmode'] = ssl_mode
|
||||
for key in ('sslrootcert', 'sslcert', 'sslkey', 'sslcrl'):
|
||||
val = _pop_first(qp, key)
|
||||
if val:
|
||||
ssl_dict[key] = val
|
||||
|
||||
if not params.has_any:
|
||||
return url, params
|
||||
if not ssl_dict:
|
||||
return url, ssl_dict
|
||||
|
||||
cleaned_query = urlencode(qp, doseq=True)
|
||||
return urlunparse(parsed._replace(query=cleaned_query)), params
|
||||
return urlunparse(parsed._replace(query=cleaned_query)), ssl_dict
|
||||
|
||||
|
||||
def reattach_ssl_params_to_url(url_without_ssl: str, ssl_params: SSLParams) -> str:
|
||||
def reattach_ssl_params_to_url(url_without_ssl: str, ssl_dict: dict[str, str]) -> str:
|
||||
"""Re-append SSL query-string parameters to a cleaned PostgreSQL URL.
|
||||
|
||||
Used for psycopg2/libpq consumers that expect ``sslmode`` and the
|
||||
certificate-file keys in the connection string.
|
||||
"""
|
||||
if not ssl_params:
|
||||
if not ssl_dict:
|
||||
return url_without_ssl
|
||||
|
||||
mapping = (
|
||||
('sslmode', ssl_params.mode),
|
||||
('sslrootcert', ssl_params.rootcert),
|
||||
('sslcert', ssl_params.cert),
|
||||
('sslkey', ssl_params.key),
|
||||
('sslcrl', ssl_params.crl),
|
||||
)
|
||||
parts = [f'{k}={v}' for k, v in mapping if v]
|
||||
parts = [f'{k}={v}' for k, v in ssl_dict.items() if v]
|
||||
if not parts:
|
||||
return url_without_ssl
|
||||
|
||||
@@ -135,54 +113,6 @@ def reattach_ssl_params_to_url(url_without_ssl: str, ssl_params: SSLParams) -> s
|
||||
return f'{url_without_ssl}{sep}{"&".join(parts)}'
|
||||
|
||||
|
||||
# ── asyncpg SSLContext builder ───────────────────────────────────────
|
||||
|
||||
|
||||
def _make_ssl_context(ssl_params: SSLParams, *, verify: bool) -> _stdlib_ssl.SSLContext:
|
||||
"""Create an :class:`ssl.SSLContext` from *ssl_params*.
|
||||
|
||||
When *verify* is ``False``, hostname checking and certificate
|
||||
verification are disabled (matching libpq ``require`` semantics).
|
||||
"""
|
||||
ctx = _stdlib_ssl.create_default_context(cafile=ssl_params.rootcert)
|
||||
if not verify:
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = _stdlib_ssl.CERT_NONE
|
||||
if ssl_params.cert and ssl_params.key:
|
||||
ctx.load_cert_chain(certfile=ssl_params.cert, keyfile=ssl_params.key)
|
||||
if verify and ssl_params.crl:
|
||||
ctx.load_verify_locations(cafile=ssl_params.crl)
|
||||
ctx.verify_flags |= _stdlib_ssl.VERIFY_CRL_CHECK_LEAF
|
||||
return ctx
|
||||
|
||||
|
||||
def build_asyncpg_ssl_args(ssl_params: SSLParams) -> dict:
|
||||
"""Convert :class:`SSLParams` to asyncpg-compatible ``connect_args``.
|
||||
|
||||
Returns a dict suitable for unpacking into
|
||||
``create_async_engine(...)``.
|
||||
"""
|
||||
if not ssl_params:
|
||||
return {}
|
||||
|
||||
mode = (ssl_params.mode or 'require').lower()
|
||||
|
||||
if mode == 'disable':
|
||||
return {'connect_args': {'ssl': False}}
|
||||
if mode in ('allow', 'prefer'):
|
||||
return {}
|
||||
if mode == 'require':
|
||||
return {'connect_args': {'ssl': _make_ssl_context(ssl_params, verify=False)}}
|
||||
if mode in ('verify-ca', 'verify-full'):
|
||||
ctx = _make_ssl_context(ssl_params, verify=True)
|
||||
if mode == 'verify-ca':
|
||||
ctx.check_hostname = False
|
||||
return {'connect_args': {'ssl': ctx}}
|
||||
|
||||
# Unknown value — pass through as-is and let asyncpg decide.
|
||||
return {'connect_args': {'ssl': ssl_params.mode}}
|
||||
|
||||
|
||||
# Backwards-compatible aliases for external callers.
|
||||
extract_ssl_mode_from_url = extract_ssl_params_from_url
|
||||
reattach_ssl_mode_to_url = reattach_ssl_params_to_url
|
||||
@@ -245,32 +175,38 @@ if ENABLE_DB_MIGRATIONS:
|
||||
handle_peewee_migration(DATABASE_URL)
|
||||
|
||||
|
||||
# Normalize SSL params from the URL once; each engine branch re-injects
|
||||
# the driver-appropriate form.
|
||||
DATABASE_URL_WITHOUT_SSL, DATABASE_SSL_PARAMS = extract_ssl_params_from_url(DATABASE_URL)
|
||||
# Normalize SSL params from the URL once; the sync engine needs them
|
||||
# reattached in canonical libpq form for psycopg2.
|
||||
_url_without_ssl, _ssl_dict = extract_ssl_params_from_url(DATABASE_URL)
|
||||
|
||||
# For psycopg2 (sync engine), re-append sslmode + cert-file params.
|
||||
SQLALCHEMY_DATABASE_URL = (
|
||||
reattach_ssl_params_to_url(DATABASE_URL_WITHOUT_SSL, DATABASE_SSL_PARAMS) if DATABASE_SSL_PARAMS else DATABASE_URL
|
||||
reattach_ssl_params_to_url(_url_without_ssl, _ssl_dict) if _ssl_dict else DATABASE_URL
|
||||
)
|
||||
|
||||
|
||||
def _make_async_url(url: str) -> str:
|
||||
"""Convert a sync database URL to its async driver equivalent."""
|
||||
"""Convert a sync database URL to its async driver equivalent.
|
||||
|
||||
The async engine uses psycopg (v3) which speaks libpq natively,
|
||||
so all standard connection-string parameters (``sslmode``,
|
||||
``options``, ``target_session_attrs``, etc.) are passed through
|
||||
without any translation.
|
||||
"""
|
||||
if url.startswith('sqlite+sqlcipher://'):
|
||||
# SQLCipher has no async driver — not supported for async
|
||||
raise ValueError(
|
||||
'sqlite+sqlcipher:// URLs are not supported with async engine. '
|
||||
'Use standard sqlite:// or postgresql:// instead.'
|
||||
)
|
||||
if url.startswith('sqlite:///') or url.startswith('sqlite://'):
|
||||
return url.replace('sqlite://', 'sqlite+aiosqlite://', 1)
|
||||
# psycopg v3 — auto-selects async mode with create_async_engine
|
||||
if url.startswith('postgresql+psycopg2://'):
|
||||
return url.replace('postgresql+psycopg2://', 'postgresql+asyncpg://', 1)
|
||||
return url.replace('postgresql+psycopg2://', 'postgresql+psycopg://', 1)
|
||||
if url.startswith('postgresql://'):
|
||||
return url.replace('postgresql://', 'postgresql+asyncpg://', 1)
|
||||
return url.replace('postgresql://', 'postgresql+psycopg://', 1)
|
||||
if url.startswith('postgres://'):
|
||||
return url.replace('postgres://', 'postgresql+asyncpg://', 1)
|
||||
return url.replace('postgres://', 'postgresql+psycopg://', 1)
|
||||
# For other dialects, return as-is and let SQLAlchemy handle it
|
||||
return url
|
||||
|
||||
@@ -395,10 +331,10 @@ get_db = contextmanager(get_session)
|
||||
# ASYNC ENGINE (used for ALL runtime database operations)
|
||||
# ============================================================
|
||||
|
||||
# Use the SSL-stripped URL for asyncpg — SSL is injected via connect_args.
|
||||
ASYNC_SQLALCHEMY_DATABASE_URL = _make_async_url(
|
||||
DATABASE_URL_WITHOUT_SSL if DATABASE_SSL_PARAMS else SQLALCHEMY_DATABASE_URL
|
||||
)
|
||||
# psycopg (v3) speaks libpq natively — the full DATABASE_URL is passed
|
||||
# through as-is. SSL params, ``options``, ``target_session_attrs``, etc.
|
||||
# all work without any stripping or translation.
|
||||
ASYNC_SQLALCHEMY_DATABASE_URL = _make_async_url(SQLALCHEMY_DATABASE_URL)
|
||||
|
||||
if 'sqlite' in ASYNC_SQLALCHEMY_DATABASE_URL:
|
||||
# Generous default — async coroutines + no session sharing = high connection demand.
|
||||
@@ -416,10 +352,6 @@ if 'sqlite' in ASYNC_SQLALCHEMY_DATABASE_URL:
|
||||
def _set_sqlite_pragmas(dbapi_connection, connection_record):
|
||||
_apply_sqlite_pragmas(dbapi_connection)
|
||||
else:
|
||||
# Inject asyncpg-compatible SSL connect_args when the user specified
|
||||
# sslmode/ssl in DATABASE_URL.
|
||||
asyncpg_ssl_args = build_asyncpg_ssl_args(DATABASE_SSL_PARAMS)
|
||||
|
||||
if isinstance(DATABASE_POOL_SIZE, int):
|
||||
if DATABASE_POOL_SIZE > 0:
|
||||
async_engine = create_async_engine(
|
||||
@@ -429,20 +361,17 @@ else:
|
||||
pool_timeout=DATABASE_POOL_TIMEOUT,
|
||||
pool_recycle=DATABASE_POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
**asyncpg_ssl_args,
|
||||
)
|
||||
else:
|
||||
async_engine = create_async_engine(
|
||||
ASYNC_SQLALCHEMY_DATABASE_URL,
|
||||
pool_pre_ping=True,
|
||||
poolclass=NullPool,
|
||||
**asyncpg_ssl_args,
|
||||
)
|
||||
else:
|
||||
async_engine = create_async_engine(
|
||||
ASYNC_SQLALCHEMY_DATABASE_URL,
|
||||
pool_pre_ping=True,
|
||||
**asyncpg_ssl_args,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ target_metadata = Auth.metadata
|
||||
|
||||
DB_URL = DATABASE_URL
|
||||
|
||||
# Normalize SSL query params for psycopg2 (Alembic uses psycopg2, not asyncpg).
|
||||
# Normalize SSL query params for psycopg2 (Alembic uses psycopg2 for sync migrations).
|
||||
url_without_ssl, ssl_params = extract_ssl_params_from_url(DB_URL)
|
||||
DB_URL = reattach_ssl_params_to_url(url_without_ssl, ssl_params) if ssl_params else DB_URL
|
||||
|
||||
|
||||
@@ -391,6 +391,56 @@ class ModelsTable:
|
||||
|
||||
return ModelListResponse(items=models, total=total)
|
||||
|
||||
async def get_model_meta_by_id(
|
||||
self, id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[tuple[dict, int]]:
|
||||
"""Return (meta, updated_at) for a model, skipping access grant resolution."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(Model.meta, Model.updated_at).filter_by(id=id)
|
||||
)
|
||||
return result.first()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_all_tags(
|
||||
self,
|
||||
user_id: str,
|
||||
is_admin: bool = False,
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> set[str]:
|
||||
"""Extract unique tag names from model meta, querying only the meta column."""
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Model.meta).filter(Model.base_model_id != None)
|
||||
|
||||
if not is_admin:
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
|
||||
user_group_ids = [group.id for group in user_groups]
|
||||
|
||||
filter_dict = {'user_id': user_id}
|
||||
if user_group_ids:
|
||||
filter_dict['group_ids'] = user_group_ids
|
||||
|
||||
stmt = self._has_permission(db, stmt, filter_dict, permission='read')
|
||||
|
||||
result = await db.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
|
||||
tags_set: set[str] = set()
|
||||
for meta in rows:
|
||||
if not meta:
|
||||
continue
|
||||
for tag in meta.get('tags', []):
|
||||
try:
|
||||
name = tag.get('name') if isinstance(tag, dict) else str(tag)
|
||||
if name:
|
||||
tags_set.add(name)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return tags_set
|
||||
|
||||
async def get_model_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ModelModel]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
|
||||
@@ -138,18 +138,25 @@ async def get_models(
|
||||
db=db,
|
||||
)
|
||||
|
||||
return ModelAccessListResponse(
|
||||
items=[
|
||||
# Strip profile_image_url from meta — images are served via /model/profile/image.
|
||||
items = []
|
||||
for model in result.items:
|
||||
data = model.model_dump()
|
||||
if data.get('meta'):
|
||||
data['meta'].pop('profile_image_url', None)
|
||||
items.append(
|
||||
ModelAccessResponse(
|
||||
**model.model_dump(),
|
||||
**data,
|
||||
write_access=(
|
||||
(user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or user.id == model.user_id
|
||||
or model.id in writable_model_ids
|
||||
),
|
||||
)
|
||||
for model in result.items
|
||||
],
|
||||
)
|
||||
|
||||
return ModelAccessListResponse(
|
||||
items=items,
|
||||
total=result.total,
|
||||
)
|
||||
|
||||
@@ -171,25 +178,12 @@ async def get_base_models(user=Depends(get_admin_user), db: AsyncSession = Depen
|
||||
|
||||
@router.get('/tags', response_model=list[str])
|
||||
async def get_model_tags(user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
models = await Models.get_models(db=db)
|
||||
else:
|
||||
models = await Models.get_models_by_user_id(user.id, db=db)
|
||||
|
||||
tags_set = set()
|
||||
for model in models:
|
||||
if model.meta:
|
||||
meta = model.meta.model_dump()
|
||||
for tag in meta.get('tags', []):
|
||||
try:
|
||||
name = tag.get('name') if isinstance(tag, dict) else str(tag)
|
||||
if name:
|
||||
tags_set.add(name)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
tags = sorted(tags_set)
|
||||
return tags
|
||||
tags = await Models.get_all_tags(
|
||||
user_id=user.id,
|
||||
is_admin=(user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL),
|
||||
db=db,
|
||||
)
|
||||
return sorted(tags)
|
||||
|
||||
|
||||
############################
|
||||
@@ -466,54 +460,48 @@ async def get_model_profile_image(
|
||||
user=Depends(get_verified_user),
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
model = await Models.get_model_by_id(id, db=db)
|
||||
model_meta = await Models.get_model_meta_by_id(id, db=db)
|
||||
|
||||
if model:
|
||||
etag = f'"{model.updated_at}"' if model.updated_at else None
|
||||
if model_meta:
|
||||
meta, updated_at = model_meta
|
||||
profile_image_url = (meta or {}).get('profile_image_url')
|
||||
|
||||
if model.meta.profile_image_url:
|
||||
if model.meta.profile_image_url.startswith('http'):
|
||||
if profile_image_url:
|
||||
if profile_image_url.startswith('http'):
|
||||
return Response(
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
headers={'Location': model.meta.profile_image_url},
|
||||
headers={'Location': profile_image_url},
|
||||
)
|
||||
elif model.meta.profile_image_url.startswith('data:image'):
|
||||
elif profile_image_url.startswith('data:image'):
|
||||
try:
|
||||
header, base64_data = model.meta.profile_image_url.split(',', 1)
|
||||
header, base64_data = profile_image_url.split(',', 1)
|
||||
image_data = base64.b64decode(base64_data)
|
||||
image_buffer = io.BytesIO(image_data)
|
||||
media_type = header.split(';')[0].lstrip('data:')
|
||||
|
||||
headers = {'Content-Disposition': 'inline'}
|
||||
if etag:
|
||||
headers['ETag'] = etag
|
||||
if updated_at:
|
||||
headers['ETag'] = f'"{updated_at}"'
|
||||
|
||||
return StreamingResponse(
|
||||
image_buffer,
|
||||
media_type=media_type,
|
||||
headers=headers,
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
safe_static = _safe_static_redirect_path(model.meta.profile_image_url)
|
||||
safe_static = _safe_static_redirect_path(profile_image_url)
|
||||
if safe_static:
|
||||
return RedirectResponse(
|
||||
url=safe_static,
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Canonical URL so browsers cache one asset for all default model avatars
|
||||
# (distinct /profile/image?id=... URLs would otherwise re-download the same bytes).
|
||||
return RedirectResponse(
|
||||
url='/static/favicon.png',
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
url='/static/favicon.png',
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
return RedirectResponse(
|
||||
url='/static/favicon.png',
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
|
||||
@@ -14,7 +14,7 @@ async def get_function_module(request, function_id, load_from_db=True):
|
||||
"""
|
||||
Get the function module by its ID.
|
||||
"""
|
||||
function_module, _, _ = await get_function_module_from_cache(request, function_id, load_from_db)
|
||||
function_module, _, _ = await get_function_module_from_cache(request, function_id, load_from_db=load_from_db)
|
||||
return function_module
|
||||
|
||||
|
||||
|
||||
@@ -287,9 +287,9 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
|
||||
# imported/custom model configs may reference tools or filters the user
|
||||
# hasn't installed, and trying to load those would cause persistent
|
||||
# "Failed to load function module" log spam on every model refresh.
|
||||
for function_id in functions_by_id:
|
||||
for function_id, function in functions_by_id.items():
|
||||
try:
|
||||
await get_function_module_from_cache(request, function_id)
|
||||
await get_function_module_from_cache(request, function_id, function=function)
|
||||
except Exception as e:
|
||||
log.debug(f'Failed to load function module for {function_id}: {e}')
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from open_webui.env import (
|
||||
OFFLINE_MODE,
|
||||
ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS,
|
||||
)
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.functions import FunctionModel, Functions
|
||||
from open_webui.models.tools import Tools
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -335,13 +335,14 @@ async def get_tool_module_from_cache(request, tool_id, load_from_db=True):
|
||||
return tool_module, frontmatter
|
||||
|
||||
|
||||
async def get_function_module_from_cache(request, function_id, load_from_db=True):
|
||||
async def get_function_module_from_cache(request, function_id, function: FunctionModel | None = None, load_from_db=True):
|
||||
if load_from_db:
|
||||
# Always load from the database by default
|
||||
# This is useful for hooks like "inlet" or "outlet" where the content might change
|
||||
# and we want to ensure the latest content is used.
|
||||
|
||||
function = await Functions.get_function_by_id(function_id)
|
||||
if function is None:
|
||||
function = await Functions.get_function_by_id(function_id)
|
||||
if not function:
|
||||
raise Exception(f'Function not found: {function_id}')
|
||||
content = function.content
|
||||
|
||||
@@ -28,7 +28,7 @@ starsessions[redis]==2.2.1
|
||||
|
||||
sqlalchemy==2.0.48
|
||||
aiosqlite==0.21.0
|
||||
asyncpg==0.30.0
|
||||
psycopg[binary]==3.2.9
|
||||
alembic==1.18.4
|
||||
peewee==3.19.0
|
||||
peewee-migrate==1.14.3
|
||||
|
||||
@@ -26,7 +26,7 @@ python-mimeparse==2.0.0
|
||||
|
||||
sqlalchemy[asyncio]==2.0.48
|
||||
aiosqlite==0.21.0
|
||||
asyncpg==0.30.0
|
||||
psycopg[binary]==3.2.9
|
||||
alembic==1.18.4
|
||||
peewee==3.19.0
|
||||
peewee-migrate==1.14.3
|
||||
|
||||
@@ -34,7 +34,7 @@ dependencies = [
|
||||
|
||||
"sqlalchemy[asyncio]==2.0.48",
|
||||
"aiosqlite==0.21.0",
|
||||
"asyncpg==0.30.0",
|
||||
"psycopg[binary]==3.2.9",
|
||||
"alembic==1.18.4",
|
||||
"peewee==3.19.0",
|
||||
"peewee-migrate==1.14.3",
|
||||
|
||||
@@ -1,27 +1,14 @@
|
||||
<script lang="ts">
|
||||
import { toast } from 'svelte-sonner';
|
||||
|
||||
import DOMPurify from 'dompurify';
|
||||
import { marked } from 'marked';
|
||||
|
||||
import { getContext, tick, onDestroy } from 'svelte';
|
||||
import { getContext, tick } from 'svelte';
|
||||
const i18n = getContext('i18n');
|
||||
|
||||
import { chatCompletion } from '$lib/apis/openai';
|
||||
|
||||
import ChatBubble from '$lib/components/icons/ChatBubble.svelte';
|
||||
import LightBulb from '$lib/components/icons/LightBulb.svelte';
|
||||
import Markdown from '../Messages/Markdown.svelte';
|
||||
import Skeleton from '../Messages/Skeleton.svelte';
|
||||
import { chatId, models, socket } from '$lib/stores';
|
||||
|
||||
export let id = '';
|
||||
export let messageId = '';
|
||||
|
||||
export let model = null;
|
||||
export let messages = [];
|
||||
export let actions = [];
|
||||
export let onAdd = (e) => {};
|
||||
export let onSetInputText = (text) => {};
|
||||
|
||||
let floatingInput = false;
|
||||
let selectedAction = null;
|
||||
@@ -29,11 +16,6 @@
|
||||
let selectedText = '';
|
||||
let floatingInputValue = '';
|
||||
|
||||
let content = '';
|
||||
let responseContent = null;
|
||||
let responseDone = false;
|
||||
let controller = null;
|
||||
|
||||
$: if (actions.length === 0) {
|
||||
actions = DEFAULT_ACTIONS;
|
||||
}
|
||||
@@ -54,25 +36,7 @@
|
||||
}
|
||||
];
|
||||
|
||||
const autoScroll = async () => {
|
||||
const responseContainer = document.getElementById('response-container');
|
||||
if (responseContainer) {
|
||||
// Scroll to bottom only if the scroll is at the bottom give 50px buffer
|
||||
if (
|
||||
responseContainer.scrollHeight - responseContainer.clientHeight <=
|
||||
responseContainer.scrollTop + 50
|
||||
) {
|
||||
responseContainer.scrollTop = responseContainer.scrollHeight;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const actionHandler = async (actionId) => {
|
||||
if (!model) {
|
||||
toast.error($i18n.t('Model not selected'));
|
||||
return;
|
||||
}
|
||||
|
||||
const actionHandler = (actionId) => {
|
||||
let selectedContent = selectedText
|
||||
.split('\n')
|
||||
.map((line) => `> ${line}`)
|
||||
@@ -80,27 +44,20 @@
|
||||
|
||||
let selectedAction = actions.find((action) => action.id === actionId);
|
||||
if (!selectedAction) {
|
||||
toast.error($i18n.t('Action not found'));
|
||||
return;
|
||||
}
|
||||
|
||||
let prompt = selectedAction?.prompt ?? '';
|
||||
let toolIds = [];
|
||||
|
||||
// Handle: {{variableId|tool:id="toolId"}} pattern
|
||||
// This regex captures variableId and toolId from {{variableId|tool:id="toolId"}}
|
||||
const varToolPattern = /\{\{(.*?)\|tool:id="([^"]+)"\}\}/g;
|
||||
prompt = prompt.replace(varToolPattern, (match, variableId, toolId) => {
|
||||
toolIds.push(toolId);
|
||||
return variableId; // Replace with just variableId
|
||||
});
|
||||
|
||||
// legacy {{TOOL:toolId}} pattern (for backward compatibility)
|
||||
let toolIdPattern = /\{\{TOOL:([^\}]+)\}\}/g;
|
||||
let match;
|
||||
while ((match = toolIdPattern.exec(prompt)) !== null) {
|
||||
toolIds.push(match[1]);
|
||||
}
|
||||
|
||||
// Remove all TOOL placeholders from the prompt
|
||||
prompt = prompt.replace(toolIdPattern, '');
|
||||
@@ -113,131 +70,17 @@
|
||||
prompt = prompt.replace('{{CONTENT}}', selectedText);
|
||||
prompt = prompt.replace('{{SELECTED_CONTENT}}', selectedContent);
|
||||
|
||||
content = prompt;
|
||||
responseContent = '';
|
||||
|
||||
let res;
|
||||
[res, controller] = await chatCompletion(localStorage.token, {
|
||||
model: model,
|
||||
model_item: $models.find((m) => m.id === model),
|
||||
messages: [
|
||||
...messages,
|
||||
{
|
||||
role: 'user',
|
||||
content: content
|
||||
}
|
||||
].map((message) => ({
|
||||
role: message.role,
|
||||
content: message.content
|
||||
})),
|
||||
...(toolIds.length > 0
|
||||
? {
|
||||
tool_ids: toolIds
|
||||
// params: {
|
||||
// function_calling: 'native'
|
||||
// }
|
||||
}
|
||||
: {}),
|
||||
|
||||
stream: true // Enable streaming
|
||||
});
|
||||
|
||||
if (res && res.ok) {
|
||||
const reader = res.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
const processStream = async () => {
|
||||
while (true) {
|
||||
// Read data chunks from the response stream
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Decode the received chunk
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
|
||||
// Process lines within the chunk
|
||||
const lines = chunk.split('\n').filter((line) => line.trim() !== '');
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
if (line.startsWith('data: [DONE]')) {
|
||||
responseDone = true;
|
||||
|
||||
await tick();
|
||||
autoScroll();
|
||||
continue;
|
||||
} else {
|
||||
// Parse the JSON chunk
|
||||
try {
|
||||
const data = JSON.parse(line.slice(6));
|
||||
|
||||
// Append the `content` field from the "choices" object
|
||||
if (data.choices && data.choices[0]?.delta?.content) {
|
||||
responseContent += data.choices[0].delta.content;
|
||||
|
||||
autoScroll();
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Process the stream in the background
|
||||
try {
|
||||
await processStream();
|
||||
} catch (e) {
|
||||
if (e.name !== 'AbortError') {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
toast.error($i18n.t('An error occurred while fetching the explanation'));
|
||||
}
|
||||
};
|
||||
|
||||
const addHandler = async () => {
|
||||
const messages = [
|
||||
{
|
||||
role: 'user',
|
||||
content: content
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: responseContent
|
||||
}
|
||||
];
|
||||
|
||||
onAdd({
|
||||
modelId: model,
|
||||
parentId: messageId,
|
||||
messages: messages
|
||||
});
|
||||
// Prepopulate the main chat input instead of inline streaming
|
||||
onSetInputText(prompt);
|
||||
closeHandler();
|
||||
};
|
||||
|
||||
export const closeHandler = () => {
|
||||
if (controller) {
|
||||
controller.abort();
|
||||
}
|
||||
|
||||
selectedAction = null;
|
||||
selectedText = '';
|
||||
responseContent = null;
|
||||
responseDone = false;
|
||||
floatingInput = false;
|
||||
floatingInputValue = '';
|
||||
};
|
||||
|
||||
onDestroy(() => {
|
||||
if (controller) {
|
||||
controller.abort();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div
|
||||
@@ -245,120 +88,82 @@
|
||||
class="absolute rounded-lg mt-1 text-xs z-9999"
|
||||
style="display: none"
|
||||
>
|
||||
{#if responseContent === null}
|
||||
{#if !floatingInput}
|
||||
<div
|
||||
class="flex flex-row shrink-0 p-0.5 bg-white dark:bg-gray-850 dark:text-gray-100 text-medium rounded-xl shadow-xl border border-gray-100 dark:border-gray-800"
|
||||
>
|
||||
{#each actions as action}
|
||||
<button
|
||||
aria-label={action.label}
|
||||
class="px-1.5 py-[1px] hover:bg-gray-50 dark:hover:bg-gray-800 rounded-xl flex items-center gap-1 min-w-fit transition"
|
||||
on:click={async () => {
|
||||
selectedText = window.getSelection().toString();
|
||||
selectedAction = action;
|
||||
{#if !floatingInput}
|
||||
<div
|
||||
class="flex flex-row shrink-0 p-0.5 bg-white dark:bg-gray-850 dark:text-gray-100 text-medium rounded-xl shadow-xl border border-gray-100 dark:border-gray-800"
|
||||
>
|
||||
{#each actions as action}
|
||||
<button
|
||||
aria-label={action.label}
|
||||
class="px-1.5 py-[1px] hover:bg-gray-50 dark:hover:bg-gray-800 rounded-xl flex items-center gap-1 min-w-fit transition"
|
||||
on:click={async () => {
|
||||
selectedText = window.getSelection().toString();
|
||||
selectedAction = action;
|
||||
|
||||
if (action.prompt.includes('{{INPUT_CONTENT}}')) {
|
||||
floatingInput = true;
|
||||
floatingInputValue = '';
|
||||
if (action.prompt.includes('{{INPUT_CONTENT}}')) {
|
||||
floatingInput = true;
|
||||
floatingInputValue = '';
|
||||
|
||||
await tick();
|
||||
setTimeout(() => {
|
||||
const input = document.getElementById('floating-message-input');
|
||||
if (input) {
|
||||
input.focus();
|
||||
}
|
||||
}, 0);
|
||||
} else {
|
||||
actionHandler(action.id);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{#if action.icon}
|
||||
<svelte:component this={action.icon} className="size-3 shrink-0" />
|
||||
{/if}
|
||||
<div class="shrink-0">{action.label}</div>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
{:else}
|
||||
<div
|
||||
class="py-1 flex dark:text-gray-100 bg-white dark:bg-gray-850 border border-gray-100 dark:border-gray-800 w-72 rounded-full shadow-xl"
|
||||
>
|
||||
<input
|
||||
type="text"
|
||||
id="floating-message-input"
|
||||
class="ml-5 bg-transparent outline-hidden w-full flex-1 text-sm"
|
||||
placeholder={$i18n.t('Ask a question')}
|
||||
aria-label={$i18n.t('Ask a question')}
|
||||
bind:value={floatingInputValue}
|
||||
on:keydown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
actionHandler(selectedAction?.id);
|
||||
await tick();
|
||||
setTimeout(() => {
|
||||
const input = document.getElementById('floating-message-input');
|
||||
if (input) {
|
||||
input.focus();
|
||||
}
|
||||
}, 0);
|
||||
} else {
|
||||
actionHandler(action.id);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
<div class="ml-1 mr-1">
|
||||
<button
|
||||
aria-label={$i18n.t('Submit question')}
|
||||
class="{floatingInputValue !== ''
|
||||
? 'bg-black text-white hover:bg-gray-900 dark:bg-white dark:text-black dark:hover:bg-gray-100 '
|
||||
: 'text-white bg-gray-200 dark:text-gray-900 dark:bg-gray-700 disabled'} transition rounded-full p-1.5 m-0.5 self-center"
|
||||
on:click={() => {
|
||||
actionHandler(selectedAction?.id);
|
||||
}}
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 16 16"
|
||||
fill="currentColor"
|
||||
class="size-4"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M8 14a.75.75 0 0 1-.75-.75V4.56L4.03 7.78a.75.75 0 0 1-1.06-1.06l4.5-4.5a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1-1.06 1.06L8.75 4.56v8.69A.75.75 0 0 1 8 14Z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
>
|
||||
{#if action.icon}
|
||||
<svelte:component this={action.icon} className="size-3 shrink-0" />
|
||||
{/if}
|
||||
<div class="shrink-0">{action.label}</div>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
{:else}
|
||||
<div
|
||||
class="bg-white dark:bg-gray-850 dark:text-gray-100 rounded-3xl shadow-xl w-80 max-w-full border border-gray-100 dark:border-gray-800"
|
||||
class="py-1 flex dark:text-gray-100 bg-white dark:bg-gray-850 border border-gray-100 dark:border-gray-800 w-72 rounded-full shadow-xl"
|
||||
>
|
||||
<div
|
||||
class="bg-white dark:bg-gray-850 dark:text-gray-100 text-medium rounded-3xl px-3.5 pt-3 w-full"
|
||||
>
|
||||
<div class="font-medium">
|
||||
<Markdown id={`${id}-float-prompt`} {content} />
|
||||
</div>
|
||||
</div>
|
||||
<input
|
||||
type="text"
|
||||
id="floating-message-input"
|
||||
class="ml-5 bg-transparent outline-hidden w-full flex-1 text-sm"
|
||||
placeholder={$i18n.t('Ask a question')}
|
||||
aria-label={$i18n.t('Ask a question')}
|
||||
bind:value={floatingInputValue}
|
||||
on:keydown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
actionHandler(selectedAction?.id);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
<div class="bg-white dark:bg-gray-850 dark:text-gray-100 text-medium rounded-4xl w-full">
|
||||
<div
|
||||
class=" max-h-80 overflow-y-auto w-full markdown-prose-xs px-3.5 py-3"
|
||||
id="response-container"
|
||||
<div class="ml-1 mr-1">
|
||||
<button
|
||||
aria-label={$i18n.t('Submit question')}
|
||||
class="{floatingInputValue !== ''
|
||||
? 'bg-black text-white hover:bg-gray-900 dark:bg-white dark:text-black dark:hover:bg-gray-100 '
|
||||
: 'text-white bg-gray-200 dark:text-gray-900 dark:bg-gray-700 disabled'} transition rounded-full p-1.5 m-0.5 self-center"
|
||||
on:click={() => {
|
||||
actionHandler(selectedAction?.id);
|
||||
}}
|
||||
>
|
||||
{#if !responseContent || responseContent?.trim() === ''}
|
||||
<Skeleton size="sm" />
|
||||
{:else}
|
||||
<Markdown id={`${id}-float-response`} content={responseContent} />
|
||||
{/if}
|
||||
|
||||
{#if responseDone}
|
||||
<div class="flex justify-end pt-3 text-sm font-medium">
|
||||
<button
|
||||
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
|
||||
on:click={addHandler}
|
||||
>
|
||||
{$i18n.t('Add')}
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 16 16"
|
||||
fill="currentColor"
|
||||
class="size-4"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M8 14a.75.75 0 0 1-.75-.75V4.56L4.03 7.78a.75.75 0 0 1-1.06-1.06l4.5-4.5a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1-1.06 1.06L8.75 4.56v8.69A.75.75 0 0 1 8 14Z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
@@ -37,7 +37,7 @@
|
||||
export let onSave = (e) => {};
|
||||
export let onSourceClick = (e) => {};
|
||||
export let onTaskClick = (e) => {};
|
||||
export let onAddMessages = (e) => {};
|
||||
export let onSetInputText = (text) => {};
|
||||
|
||||
let contentContainerElement;
|
||||
let floatingButtonsElement;
|
||||
@@ -140,20 +140,36 @@
|
||||
}
|
||||
};
|
||||
|
||||
onMount(() => {
|
||||
if (floatingButtons) {
|
||||
contentContainerElement?.addEventListener('mouseup', updateButtonPosition);
|
||||
// Reactive listener attachment: re-attaches when floatingButtons
|
||||
// transitions from false → true (e.g. when message.done flips).
|
||||
let listenersAttached = false;
|
||||
|
||||
function attachListeners() {
|
||||
if (!listenersAttached && contentContainerElement) {
|
||||
contentContainerElement.addEventListener('mouseup', updateButtonPosition);
|
||||
document.addEventListener('mouseup', updateButtonPosition);
|
||||
document.addEventListener('keydown', keydownHandler);
|
||||
listenersAttached = true;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
onDestroy(() => {
|
||||
if (floatingButtons) {
|
||||
function detachListeners() {
|
||||
if (listenersAttached) {
|
||||
contentContainerElement?.removeEventListener('mouseup', updateButtonPosition);
|
||||
document.removeEventListener('mouseup', updateButtonPosition);
|
||||
document.removeEventListener('keydown', keydownHandler);
|
||||
listenersAttached = false;
|
||||
}
|
||||
}
|
||||
|
||||
$: if (floatingButtons && contentContainerElement) {
|
||||
attachListeners();
|
||||
} else {
|
||||
detachListeners();
|
||||
}
|
||||
|
||||
onDestroy(() => {
|
||||
detachListeners();
|
||||
});
|
||||
</script>
|
||||
|
||||
@@ -201,17 +217,9 @@
|
||||
<FloatingButtons
|
||||
bind:this={floatingButtonsElement}
|
||||
{id}
|
||||
{messageId}
|
||||
actions={$settings?.floatingActionButtons ?? []}
|
||||
model={(selectedModels ?? []).includes(model?.id)
|
||||
? model?.id
|
||||
: (selectedModels ?? []).length > 0
|
||||
? selectedModels.at(0)
|
||||
: (model?.id ?? null)}
|
||||
messages={createMessagesList(history, messageId)}
|
||||
onAdd={({ modelId, parentId, messages }) => {
|
||||
console.log(modelId, parentId, messages);
|
||||
onAddMessages({ modelId, parentId, messages });
|
||||
onSetInputText={(text) => {
|
||||
onSetInputText(text);
|
||||
closeFloatingButtons();
|
||||
}}
|
||||
/>
|
||||
|
||||
@@ -788,9 +788,6 @@
|
||||
<!-- unless message.error === true which is legacy error handling, where the error message is stored in message.content -->
|
||||
<ContentRenderer
|
||||
id={`${chatId}-${message.id}`}
|
||||
messageId={message.id}
|
||||
{history}
|
||||
{selectedModels}
|
||||
content={message.content}
|
||||
sources={message.sources}
|
||||
floatingButtons={message?.done &&
|
||||
@@ -814,8 +811,8 @@
|
||||
citationsElement?.showSourceModal(id);
|
||||
}
|
||||
}}
|
||||
onAddMessages={({ modelId, parentId, messages }) => {
|
||||
addMessages({ modelId, parentId, messages });
|
||||
onSetInputText={(text) => {
|
||||
setInputText(text);
|
||||
}}
|
||||
onSave={({ raw, oldContent, newContent }) => {
|
||||
history.messages[message.id].content = history.messages[
|
||||
|
||||
@@ -601,6 +601,8 @@
|
||||
src={`${WEBUI_API_BASE_URL}/models/model/profile/image?id=${model.id}&lang=${$i18n.language}`}
|
||||
alt="modelfile profile"
|
||||
class=" rounded-2xl size-12 object-cover"
|
||||
loading="lazy"
|
||||
decoding="async"
|
||||
on:error={(e) => {
|
||||
e.target.src = '/favicon.png';
|
||||
}}
|
||||
|
||||
Reference in New Issue
Block a user