This commit is contained in:
Timothy Jaeryang Baek
2026-04-24 18:20:10 +09:00
parent 7102a63c82
commit 3d1e355df7
14 changed files with 258 additions and 478 deletions

View File

@@ -1,9 +1,7 @@
import os
import json
import logging
import ssl as _stdlib_ssl
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from typing import Any, Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
@@ -38,31 +36,17 @@ from typing_extensions import Self
log = logging.getLogger(__name__)
@dataclass
class SSLParams:
"""SSL parameters extracted from a PostgreSQL ``DATABASE_URL``.
Holds the connection-mode flag and optional certificate file paths
so that each driver (asyncpg, psycopg2/libpq) can receive them in
the format it expects.
"""
mode: str | None = None
rootcert: str | None = None
cert: str | None = None
key: str | None = None
crl: str | None = None
def __bool__(self) -> bool:
return self.mode is not None
@property
def has_any(self) -> bool:
"""True when *any* SSL-related field is set (mode or cert files)."""
return any((self.mode, self.rootcert, self.cert, self.key, self.crl))
# ── URL extraction / reattachment ────────────────────────────────────
# ── SSL URL normalization (used by sync engine & Alembic migrations) ─
#
# psycopg2 (sync) needs ``sslmode=`` in the connection string (it does
# not recognise the bare ``ssl=`` key that some ORMs emit). The helpers
# below strip all SSL-related query params, normalise them, and
# reattach them in the canonical libpq form.
#
# The **async** engine now uses psycopg (v3), which speaks libpq
# natively, so it needs no translation at all — the DATABASE_URL is
# passed through as-is.
# ─────────────────────────────────────────────────────────────────────
def _pop_first(params: dict[str, list[str]], key: str) -> str | None:
@@ -71,63 +55,57 @@ def _pop_first(params: dict[str, list[str]], key: str) -> str | None:
return values[0] if values else None
def extract_ssl_params_from_url(url: str) -> tuple[str, SSLParams]:
"""Strip all SSL query-string parameters from a PostgreSQL URL.
asyncpg does not accept libpq-style certificate-file keys
(``sslrootcert``, ``sslcert``, ``sslkey``, ``sslcrl``), so every
SSL-related key is removed and returned as a structured
:class:`SSLParams` object.
Returns ``(url_without_ssl, ssl_params)``. Non-PostgreSQL URLs are
returned unchanged with an empty ``SSLParams``.
"""
if not url or not any(
def _is_postgres_url(url: str) -> bool:
"""Return True if *url* looks like a PostgreSQL connection string."""
return bool(url) and any(
url.startswith(p) for p in ('postgresql://', 'postgresql+', 'postgres://')
):
return url, SSLParams()
)
def extract_ssl_params_from_url(url: str) -> tuple[str, dict[str, str]]:
"""Strip SSL query-string parameters from a PostgreSQL URL.
Returns ``(url_without_ssl, ssl_dict)`` where *ssl_dict* maps
canonical libpq key names (``sslmode``, ``sslrootcert``, …) to
their values. Non-PostgreSQL URLs are returned unchanged with an
empty dict.
"""
if not _is_postgres_url(url):
return url, {}
parsed = urlparse(url)
qp = parse_qs(parsed.query, keep_blank_values=True)
# Prefer sslmode (libpq canonical) over the asyncpg-only ``ssl`` key.
# Both must be popped unconditionally so neither leaks into the cleaned URL.
# Prefer sslmode (libpq canonical) over the bare ``ssl`` key.
sslmode_val = _pop_first(qp, 'sslmode')
ssl_val = _pop_first(qp, 'ssl')
ssl_mode = sslmode_val or ssl_val
params = SSLParams(
mode=ssl_mode,
rootcert=_pop_first(qp, 'sslrootcert'),
cert=_pop_first(qp, 'sslcert'),
key=_pop_first(qp, 'sslkey'),
crl=_pop_first(qp, 'sslcrl'),
)
ssl_dict: dict[str, str] = {}
if ssl_mode:
ssl_dict['sslmode'] = ssl_mode
for key in ('sslrootcert', 'sslcert', 'sslkey', 'sslcrl'):
val = _pop_first(qp, key)
if val:
ssl_dict[key] = val
if not params.has_any:
return url, params
if not ssl_dict:
return url, ssl_dict
cleaned_query = urlencode(qp, doseq=True)
return urlunparse(parsed._replace(query=cleaned_query)), params
return urlunparse(parsed._replace(query=cleaned_query)), ssl_dict
def reattach_ssl_params_to_url(url_without_ssl: str, ssl_params: SSLParams) -> str:
def reattach_ssl_params_to_url(url_without_ssl: str, ssl_dict: dict[str, str]) -> str:
"""Re-append SSL query-string parameters to a cleaned PostgreSQL URL.
Used for psycopg2/libpq consumers that expect ``sslmode`` and the
certificate-file keys in the connection string.
"""
if not ssl_params:
if not ssl_dict:
return url_without_ssl
mapping = (
('sslmode', ssl_params.mode),
('sslrootcert', ssl_params.rootcert),
('sslcert', ssl_params.cert),
('sslkey', ssl_params.key),
('sslcrl', ssl_params.crl),
)
parts = [f'{k}={v}' for k, v in mapping if v]
parts = [f'{k}={v}' for k, v in ssl_dict.items() if v]
if not parts:
return url_without_ssl
@@ -135,54 +113,6 @@ def reattach_ssl_params_to_url(url_without_ssl: str, ssl_params: SSLParams) -> s
return f'{url_without_ssl}{sep}{"&".join(parts)}'
# ── asyncpg SSLContext builder ───────────────────────────────────────
def _make_ssl_context(ssl_params: SSLParams, *, verify: bool) -> _stdlib_ssl.SSLContext:
"""Create an :class:`ssl.SSLContext` from *ssl_params*.
When *verify* is ``False``, hostname checking and certificate
verification are disabled (matching libpq ``require`` semantics).
"""
ctx = _stdlib_ssl.create_default_context(cafile=ssl_params.rootcert)
if not verify:
ctx.check_hostname = False
ctx.verify_mode = _stdlib_ssl.CERT_NONE
if ssl_params.cert and ssl_params.key:
ctx.load_cert_chain(certfile=ssl_params.cert, keyfile=ssl_params.key)
if verify and ssl_params.crl:
ctx.load_verify_locations(cafile=ssl_params.crl)
ctx.verify_flags |= _stdlib_ssl.VERIFY_CRL_CHECK_LEAF
return ctx
def build_asyncpg_ssl_args(ssl_params: SSLParams) -> dict:
"""Convert :class:`SSLParams` to asyncpg-compatible ``connect_args``.
Returns a dict suitable for unpacking into
``create_async_engine(...)``.
"""
if not ssl_params:
return {}
mode = (ssl_params.mode or 'require').lower()
if mode == 'disable':
return {'connect_args': {'ssl': False}}
if mode in ('allow', 'prefer'):
return {}
if mode == 'require':
return {'connect_args': {'ssl': _make_ssl_context(ssl_params, verify=False)}}
if mode in ('verify-ca', 'verify-full'):
ctx = _make_ssl_context(ssl_params, verify=True)
if mode == 'verify-ca':
ctx.check_hostname = False
return {'connect_args': {'ssl': ctx}}
# Unknown value — pass through as-is and let asyncpg decide.
return {'connect_args': {'ssl': ssl_params.mode}}
# Backwards-compatible aliases for external callers.
extract_ssl_mode_from_url = extract_ssl_params_from_url
reattach_ssl_mode_to_url = reattach_ssl_params_to_url
@@ -245,32 +175,38 @@ if ENABLE_DB_MIGRATIONS:
handle_peewee_migration(DATABASE_URL)
# Normalize SSL params from the URL once; each engine branch re-injects
# the driver-appropriate form.
DATABASE_URL_WITHOUT_SSL, DATABASE_SSL_PARAMS = extract_ssl_params_from_url(DATABASE_URL)
# Normalize SSL params from the URL once; the sync engine needs them
# reattached in canonical libpq form for psycopg2.
_url_without_ssl, _ssl_dict = extract_ssl_params_from_url(DATABASE_URL)
# For psycopg2 (sync engine), re-append sslmode + cert-file params.
SQLALCHEMY_DATABASE_URL = (
reattach_ssl_params_to_url(DATABASE_URL_WITHOUT_SSL, DATABASE_SSL_PARAMS) if DATABASE_SSL_PARAMS else DATABASE_URL
reattach_ssl_params_to_url(_url_without_ssl, _ssl_dict) if _ssl_dict else DATABASE_URL
)
def _make_async_url(url: str) -> str:
"""Convert a sync database URL to its async driver equivalent."""
"""Convert a sync database URL to its async driver equivalent.
The async engine uses psycopg (v3) which speaks libpq natively,
so all standard connection-string parameters (``sslmode``,
``options``, ``target_session_attrs``, etc.) are passed through
without any translation.
"""
if url.startswith('sqlite+sqlcipher://'):
# SQLCipher has no async driver — not supported for async
raise ValueError(
'sqlite+sqlcipher:// URLs are not supported with async engine. '
'Use standard sqlite:// or postgresql:// instead.'
)
if url.startswith('sqlite:///') or url.startswith('sqlite://'):
return url.replace('sqlite://', 'sqlite+aiosqlite://', 1)
# psycopg v3 — auto-selects async mode with create_async_engine
if url.startswith('postgresql+psycopg2://'):
return url.replace('postgresql+psycopg2://', 'postgresql+asyncpg://', 1)
return url.replace('postgresql+psycopg2://', 'postgresql+psycopg://', 1)
if url.startswith('postgresql://'):
return url.replace('postgresql://', 'postgresql+asyncpg://', 1)
return url.replace('postgresql://', 'postgresql+psycopg://', 1)
if url.startswith('postgres://'):
return url.replace('postgres://', 'postgresql+asyncpg://', 1)
return url.replace('postgres://', 'postgresql+psycopg://', 1)
# For other dialects, return as-is and let SQLAlchemy handle it
return url
@@ -395,10 +331,10 @@ get_db = contextmanager(get_session)
# ASYNC ENGINE (used for ALL runtime database operations)
# ============================================================
# Use the SSL-stripped URL for asyncpg — SSL is injected via connect_args.
ASYNC_SQLALCHEMY_DATABASE_URL = _make_async_url(
DATABASE_URL_WITHOUT_SSL if DATABASE_SSL_PARAMS else SQLALCHEMY_DATABASE_URL
)
# psycopg (v3) speaks libpq natively — the full DATABASE_URL is passed
# through as-is. SSL params, ``options``, ``target_session_attrs``, etc.
# all work without any stripping or translation.
ASYNC_SQLALCHEMY_DATABASE_URL = _make_async_url(SQLALCHEMY_DATABASE_URL)
if 'sqlite' in ASYNC_SQLALCHEMY_DATABASE_URL:
# Generous default — async coroutines + no session sharing = high connection demand.
@@ -416,10 +352,6 @@ if 'sqlite' in ASYNC_SQLALCHEMY_DATABASE_URL:
def _set_sqlite_pragmas(dbapi_connection, connection_record):
_apply_sqlite_pragmas(dbapi_connection)
else:
# Inject asyncpg-compatible SSL connect_args when the user specified
# sslmode/ssl in DATABASE_URL.
asyncpg_ssl_args = build_asyncpg_ssl_args(DATABASE_SSL_PARAMS)
if isinstance(DATABASE_POOL_SIZE, int):
if DATABASE_POOL_SIZE > 0:
async_engine = create_async_engine(
@@ -429,20 +361,17 @@ else:
pool_timeout=DATABASE_POOL_TIMEOUT,
pool_recycle=DATABASE_POOL_RECYCLE,
pool_pre_ping=True,
**asyncpg_ssl_args,
)
else:
async_engine = create_async_engine(
ASYNC_SQLALCHEMY_DATABASE_URL,
pool_pre_ping=True,
poolclass=NullPool,
**asyncpg_ssl_args,
)
else:
async_engine = create_async_engine(
ASYNC_SQLALCHEMY_DATABASE_URL,
pool_pre_ping=True,
**asyncpg_ssl_args,
)

View File

@@ -37,7 +37,7 @@ target_metadata = Auth.metadata
DB_URL = DATABASE_URL
# Normalize SSL query params for psycopg2 (Alembic uses psycopg2, not asyncpg).
# Normalize SSL query params for psycopg2 (Alembic uses psycopg2 for sync migrations).
url_without_ssl, ssl_params = extract_ssl_params_from_url(DB_URL)
DB_URL = reattach_ssl_params_to_url(url_without_ssl, ssl_params) if ssl_params else DB_URL

View File

@@ -391,6 +391,56 @@ class ModelsTable:
return ModelListResponse(items=models, total=total)
async def get_model_meta_by_id(
self, id: str, db: Optional[AsyncSession] = None
) -> Optional[tuple[dict, int]]:
"""Return (meta, updated_at) for a model, skipping access grant resolution."""
try:
async with get_async_db_context(db) as db:
result = await db.execute(
select(Model.meta, Model.updated_at).filter_by(id=id)
)
return result.first()
except Exception:
return None
async def get_all_tags(
self,
user_id: str,
is_admin: bool = False,
db: Optional[AsyncSession] = None,
) -> set[str]:
"""Extract unique tag names from model meta, querying only the meta column."""
async with get_async_db_context(db) as db:
stmt = select(Model.meta).filter(Model.base_model_id != None)
if not is_admin:
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
user_group_ids = [group.id for group in user_groups]
filter_dict = {'user_id': user_id}
if user_group_ids:
filter_dict['group_ids'] = user_group_ids
stmt = self._has_permission(db, stmt, filter_dict, permission='read')
result = await db.execute(stmt)
rows = result.scalars().all()
tags_set: set[str] = set()
for meta in rows:
if not meta:
continue
for tag in meta.get('tags', []):
try:
name = tag.get('name') if isinstance(tag, dict) else str(tag)
if name:
tags_set.add(name)
except Exception:
continue
return tags_set
async def get_model_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ModelModel]:
try:
async with get_async_db_context(db) as db:

View File

@@ -138,18 +138,25 @@ async def get_models(
db=db,
)
return ModelAccessListResponse(
items=[
# Strip profile_image_url from meta — images are served via /model/profile/image.
items = []
for model in result.items:
data = model.model_dump()
if data.get('meta'):
data['meta'].pop('profile_image_url', None)
items.append(
ModelAccessResponse(
**model.model_dump(),
**data,
write_access=(
(user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL)
or user.id == model.user_id
or model.id in writable_model_ids
),
)
for model in result.items
],
)
return ModelAccessListResponse(
items=items,
total=result.total,
)
@@ -171,25 +178,12 @@ async def get_base_models(user=Depends(get_admin_user), db: AsyncSession = Depen
@router.get('/tags', response_model=list[str])
async def get_model_tags(user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
if user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL:
models = await Models.get_models(db=db)
else:
models = await Models.get_models_by_user_id(user.id, db=db)
tags_set = set()
for model in models:
if model.meta:
meta = model.meta.model_dump()
for tag in meta.get('tags', []):
try:
name = tag.get('name') if isinstance(tag, dict) else str(tag)
if name:
tags_set.add(name)
except Exception:
continue
tags = sorted(tags_set)
return tags
tags = await Models.get_all_tags(
user_id=user.id,
is_admin=(user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL),
db=db,
)
return sorted(tags)
############################
@@ -466,54 +460,48 @@ async def get_model_profile_image(
user=Depends(get_verified_user),
db: AsyncSession = Depends(get_async_session),
):
model = await Models.get_model_by_id(id, db=db)
model_meta = await Models.get_model_meta_by_id(id, db=db)
if model:
etag = f'"{model.updated_at}"' if model.updated_at else None
if model_meta:
meta, updated_at = model_meta
profile_image_url = (meta or {}).get('profile_image_url')
if model.meta.profile_image_url:
if model.meta.profile_image_url.startswith('http'):
if profile_image_url:
if profile_image_url.startswith('http'):
return Response(
status_code=status.HTTP_302_FOUND,
headers={'Location': model.meta.profile_image_url},
headers={'Location': profile_image_url},
)
elif model.meta.profile_image_url.startswith('data:image'):
elif profile_image_url.startswith('data:image'):
try:
header, base64_data = model.meta.profile_image_url.split(',', 1)
header, base64_data = profile_image_url.split(',', 1)
image_data = base64.b64decode(base64_data)
image_buffer = io.BytesIO(image_data)
media_type = header.split(';')[0].lstrip('data:')
headers = {'Content-Disposition': 'inline'}
if etag:
headers['ETag'] = etag
if updated_at:
headers['ETag'] = f'"{updated_at}"'
return StreamingResponse(
image_buffer,
media_type=media_type,
headers=headers,
)
except Exception as e:
except Exception:
pass
else:
safe_static = _safe_static_redirect_path(model.meta.profile_image_url)
safe_static = _safe_static_redirect_path(profile_image_url)
if safe_static:
return RedirectResponse(
url=safe_static,
status_code=status.HTTP_302_FOUND,
)
# Canonical URL so browsers cache one asset for all default model avatars
# (distinct /profile/image?id=... URLs would otherwise re-download the same bytes).
return RedirectResponse(
url='/static/favicon.png',
status_code=status.HTTP_302_FOUND,
)
else:
return RedirectResponse(
url='/static/favicon.png',
status_code=status.HTTP_302_FOUND,
)
return RedirectResponse(
url='/static/favicon.png',
status_code=status.HTTP_302_FOUND,
)
############################

View File

@@ -14,7 +14,7 @@ async def get_function_module(request, function_id, load_from_db=True):
"""
Get the function module by its ID.
"""
function_module, _, _ = await get_function_module_from_cache(request, function_id, load_from_db)
function_module, _, _ = await get_function_module_from_cache(request, function_id, load_from_db=load_from_db)
return function_module

View File

@@ -287,9 +287,9 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
# imported/custom model configs may reference tools or filters the user
# hasn't installed, and trying to load those would cause persistent
# "Failed to load function module" log spam on every model refresh.
for function_id in functions_by_id:
for function_id, function in functions_by_id.items():
try:
await get_function_module_from_cache(request, function_id)
await get_function_module_from_cache(request, function_id, function=function)
except Exception as e:
log.debug(f'Failed to load function module for {function_id}: {e}')

View File

@@ -14,7 +14,7 @@ from open_webui.env import (
OFFLINE_MODE,
ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS,
)
from open_webui.models.functions import Functions
from open_webui.models.functions import FunctionModel, Functions
from open_webui.models.tools import Tools
log = logging.getLogger(__name__)
@@ -335,13 +335,14 @@ async def get_tool_module_from_cache(request, tool_id, load_from_db=True):
return tool_module, frontmatter
async def get_function_module_from_cache(request, function_id, load_from_db=True):
async def get_function_module_from_cache(request, function_id, function: FunctionModel | None = None, load_from_db=True):
if load_from_db:
# Always load from the database by default
# This is useful for hooks like "inlet" or "outlet" where the content might change
# and we want to ensure the latest content is used.
function = await Functions.get_function_by_id(function_id)
if function is None:
function = await Functions.get_function_by_id(function_id)
if not function:
raise Exception(f'Function not found: {function_id}')
content = function.content

View File

@@ -28,7 +28,7 @@ starsessions[redis]==2.2.1
sqlalchemy==2.0.48
aiosqlite==0.21.0
asyncpg==0.30.0
psycopg[binary]==3.2.9
alembic==1.18.4
peewee==3.19.0
peewee-migrate==1.14.3

View File

@@ -26,7 +26,7 @@ python-mimeparse==2.0.0
sqlalchemy[asyncio]==2.0.48
aiosqlite==0.21.0
asyncpg==0.30.0
psycopg[binary]==3.2.9
alembic==1.18.4
peewee==3.19.0
peewee-migrate==1.14.3

View File

@@ -34,7 +34,7 @@ dependencies = [
"sqlalchemy[asyncio]==2.0.48",
"aiosqlite==0.21.0",
"asyncpg==0.30.0",
"psycopg[binary]==3.2.9",
"alembic==1.18.4",
"peewee==3.19.0",
"peewee-migrate==1.14.3",

View File

@@ -1,27 +1,14 @@
<script lang="ts">
import { toast } from 'svelte-sonner';
import DOMPurify from 'dompurify';
import { marked } from 'marked';
import { getContext, tick, onDestroy } from 'svelte';
import { getContext, tick } from 'svelte';
const i18n = getContext('i18n');
import { chatCompletion } from '$lib/apis/openai';
import ChatBubble from '$lib/components/icons/ChatBubble.svelte';
import LightBulb from '$lib/components/icons/LightBulb.svelte';
import Markdown from '../Messages/Markdown.svelte';
import Skeleton from '../Messages/Skeleton.svelte';
import { chatId, models, socket } from '$lib/stores';
export let id = '';
export let messageId = '';
export let model = null;
export let messages = [];
export let actions = [];
export let onAdd = (e) => {};
export let onSetInputText = (text) => {};
let floatingInput = false;
let selectedAction = null;
@@ -29,11 +16,6 @@
let selectedText = '';
let floatingInputValue = '';
let content = '';
let responseContent = null;
let responseDone = false;
let controller = null;
$: if (actions.length === 0) {
actions = DEFAULT_ACTIONS;
}
@@ -54,25 +36,7 @@
}
];
const autoScroll = async () => {
const responseContainer = document.getElementById('response-container');
if (responseContainer) {
// Scroll to bottom only if the scroll is at the bottom give 50px buffer
if (
responseContainer.scrollHeight - responseContainer.clientHeight <=
responseContainer.scrollTop + 50
) {
responseContainer.scrollTop = responseContainer.scrollHeight;
}
}
};
const actionHandler = async (actionId) => {
if (!model) {
toast.error($i18n.t('Model not selected'));
return;
}
const actionHandler = (actionId) => {
let selectedContent = selectedText
.split('\n')
.map((line) => `> ${line}`)
@@ -80,27 +44,20 @@
let selectedAction = actions.find((action) => action.id === actionId);
if (!selectedAction) {
toast.error($i18n.t('Action not found'));
return;
}
let prompt = selectedAction?.prompt ?? '';
let toolIds = [];
// Handle: {{variableId|tool:id="toolId"}} pattern
// This regex captures variableId and toolId from {{variableId|tool:id="toolId"}}
const varToolPattern = /\{\{(.*?)\|tool:id="([^"]+)"\}\}/g;
prompt = prompt.replace(varToolPattern, (match, variableId, toolId) => {
toolIds.push(toolId);
return variableId; // Replace with just variableId
});
// legacy {{TOOL:toolId}} pattern (for backward compatibility)
let toolIdPattern = /\{\{TOOL:([^\}]+)\}\}/g;
let match;
while ((match = toolIdPattern.exec(prompt)) !== null) {
toolIds.push(match[1]);
}
// Remove all TOOL placeholders from the prompt
prompt = prompt.replace(toolIdPattern, '');
@@ -113,131 +70,17 @@
prompt = prompt.replace('{{CONTENT}}', selectedText);
prompt = prompt.replace('{{SELECTED_CONTENT}}', selectedContent);
content = prompt;
responseContent = '';
let res;
[res, controller] = await chatCompletion(localStorage.token, {
model: model,
model_item: $models.find((m) => m.id === model),
messages: [
...messages,
{
role: 'user',
content: content
}
].map((message) => ({
role: message.role,
content: message.content
})),
...(toolIds.length > 0
? {
tool_ids: toolIds
// params: {
// function_calling: 'native'
// }
}
: {}),
stream: true // Enable streaming
});
if (res && res.ok) {
const reader = res.body.getReader();
const decoder = new TextDecoder();
const processStream = async () => {
while (true) {
// Read data chunks from the response stream
const { done, value } = await reader.read();
if (done) {
break;
}
// Decode the received chunk
const chunk = decoder.decode(value, { stream: true });
// Process lines within the chunk
const lines = chunk.split('\n').filter((line) => line.trim() !== '');
for (const line of lines) {
if (line.startsWith('data: ')) {
if (line.startsWith('data: [DONE]')) {
responseDone = true;
await tick();
autoScroll();
continue;
} else {
// Parse the JSON chunk
try {
const data = JSON.parse(line.slice(6));
// Append the `content` field from the "choices" object
if (data.choices && data.choices[0]?.delta?.content) {
responseContent += data.choices[0].delta.content;
autoScroll();
}
} catch (e) {
console.error(e);
}
}
}
}
}
};
// Process the stream in the background
try {
await processStream();
} catch (e) {
if (e.name !== 'AbortError') {
console.error(e);
}
}
} else {
toast.error($i18n.t('An error occurred while fetching the explanation'));
}
};
const addHandler = async () => {
const messages = [
{
role: 'user',
content: content
},
{
role: 'assistant',
content: responseContent
}
];
onAdd({
modelId: model,
parentId: messageId,
messages: messages
});
// Prepopulate the main chat input instead of inline streaming
onSetInputText(prompt);
closeHandler();
};
export const closeHandler = () => {
if (controller) {
controller.abort();
}
selectedAction = null;
selectedText = '';
responseContent = null;
responseDone = false;
floatingInput = false;
floatingInputValue = '';
};
onDestroy(() => {
if (controller) {
controller.abort();
}
});
</script>
<div
@@ -245,120 +88,82 @@
class="absolute rounded-lg mt-1 text-xs z-9999"
style="display: none"
>
{#if responseContent === null}
{#if !floatingInput}
<div
class="flex flex-row shrink-0 p-0.5 bg-white dark:bg-gray-850 dark:text-gray-100 text-medium rounded-xl shadow-xl border border-gray-100 dark:border-gray-800"
>
{#each actions as action}
<button
aria-label={action.label}
class="px-1.5 py-[1px] hover:bg-gray-50 dark:hover:bg-gray-800 rounded-xl flex items-center gap-1 min-w-fit transition"
on:click={async () => {
selectedText = window.getSelection().toString();
selectedAction = action;
{#if !floatingInput}
<div
class="flex flex-row shrink-0 p-0.5 bg-white dark:bg-gray-850 dark:text-gray-100 text-medium rounded-xl shadow-xl border border-gray-100 dark:border-gray-800"
>
{#each actions as action}
<button
aria-label={action.label}
class="px-1.5 py-[1px] hover:bg-gray-50 dark:hover:bg-gray-800 rounded-xl flex items-center gap-1 min-w-fit transition"
on:click={async () => {
selectedText = window.getSelection().toString();
selectedAction = action;
if (action.prompt.includes('{{INPUT_CONTENT}}')) {
floatingInput = true;
floatingInputValue = '';
if (action.prompt.includes('{{INPUT_CONTENT}}')) {
floatingInput = true;
floatingInputValue = '';
await tick();
setTimeout(() => {
const input = document.getElementById('floating-message-input');
if (input) {
input.focus();
}
}, 0);
} else {
actionHandler(action.id);
}
}}
>
{#if action.icon}
<svelte:component this={action.icon} className="size-3 shrink-0" />
{/if}
<div class="shrink-0">{action.label}</div>
</button>
{/each}
</div>
{:else}
<div
class="py-1 flex dark:text-gray-100 bg-white dark:bg-gray-850 border border-gray-100 dark:border-gray-800 w-72 rounded-full shadow-xl"
>
<input
type="text"
id="floating-message-input"
class="ml-5 bg-transparent outline-hidden w-full flex-1 text-sm"
placeholder={$i18n.t('Ask a question')}
aria-label={$i18n.t('Ask a question')}
bind:value={floatingInputValue}
on:keydown={(e) => {
if (e.key === 'Enter') {
actionHandler(selectedAction?.id);
await tick();
setTimeout(() => {
const input = document.getElementById('floating-message-input');
if (input) {
input.focus();
}
}, 0);
} else {
actionHandler(action.id);
}
}}
/>
<div class="ml-1 mr-1">
<button
aria-label={$i18n.t('Submit question')}
class="{floatingInputValue !== ''
? 'bg-black text-white hover:bg-gray-900 dark:bg-white dark:text-black dark:hover:bg-gray-100 '
: 'text-white bg-gray-200 dark:text-gray-900 dark:bg-gray-700 disabled'} transition rounded-full p-1.5 m-0.5 self-center"
on:click={() => {
actionHandler(selectedAction?.id);
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="size-4"
>
<path
fill-rule="evenodd"
d="M8 14a.75.75 0 0 1-.75-.75V4.56L4.03 7.78a.75.75 0 0 1-1.06-1.06l4.5-4.5a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1-1.06 1.06L8.75 4.56v8.69A.75.75 0 0 1 8 14Z"
clip-rule="evenodd"
/>
</svg>
</button>
</div>
</div>
{/if}
>
{#if action.icon}
<svelte:component this={action.icon} className="size-3 shrink-0" />
{/if}
<div class="shrink-0">{action.label}</div>
</button>
{/each}
</div>
{:else}
<div
class="bg-white dark:bg-gray-850 dark:text-gray-100 rounded-3xl shadow-xl w-80 max-w-full border border-gray-100 dark:border-gray-800"
class="py-1 flex dark:text-gray-100 bg-white dark:bg-gray-850 border border-gray-100 dark:border-gray-800 w-72 rounded-full shadow-xl"
>
<div
class="bg-white dark:bg-gray-850 dark:text-gray-100 text-medium rounded-3xl px-3.5 pt-3 w-full"
>
<div class="font-medium">
<Markdown id={`${id}-float-prompt`} {content} />
</div>
</div>
<input
type="text"
id="floating-message-input"
class="ml-5 bg-transparent outline-hidden w-full flex-1 text-sm"
placeholder={$i18n.t('Ask a question')}
aria-label={$i18n.t('Ask a question')}
bind:value={floatingInputValue}
on:keydown={(e) => {
if (e.key === 'Enter') {
actionHandler(selectedAction?.id);
}
}}
/>
<div class="bg-white dark:bg-gray-850 dark:text-gray-100 text-medium rounded-4xl w-full">
<div
class=" max-h-80 overflow-y-auto w-full markdown-prose-xs px-3.5 py-3"
id="response-container"
<div class="ml-1 mr-1">
<button
aria-label={$i18n.t('Submit question')}
class="{floatingInputValue !== ''
? 'bg-black text-white hover:bg-gray-900 dark:bg-white dark:text-black dark:hover:bg-gray-100 '
: 'text-white bg-gray-200 dark:text-gray-900 dark:bg-gray-700 disabled'} transition rounded-full p-1.5 m-0.5 self-center"
on:click={() => {
actionHandler(selectedAction?.id);
}}
>
{#if !responseContent || responseContent?.trim() === ''}
<Skeleton size="sm" />
{:else}
<Markdown id={`${id}-float-response`} content={responseContent} />
{/if}
{#if responseDone}
<div class="flex justify-end pt-3 text-sm font-medium">
<button
class="px-3.5 py-1.5 text-sm font-medium bg-black hover:bg-gray-900 text-white dark:bg-white dark:text-black dark:hover:bg-gray-100 transition rounded-full"
on:click={addHandler}
>
{$i18n.t('Add')}
</button>
</div>
{/if}
</div>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="size-4"
>
<path
fill-rule="evenodd"
d="M8 14a.75.75 0 0 1-.75-.75V4.56L4.03 7.78a.75.75 0 0 1-1.06-1.06l4.5-4.5a.75.75 0 0 1 1.06 0l4.5 4.5a.75.75 0 0 1-1.06 1.06L8.75 4.56v8.69A.75.75 0 0 1 8 14Z"
clip-rule="evenodd"
/>
</svg>
</button>
</div>
</div>
{/if}

View File

@@ -37,7 +37,7 @@
export let onSave = (e) => {};
export let onSourceClick = (e) => {};
export let onTaskClick = (e) => {};
export let onAddMessages = (e) => {};
export let onSetInputText = (text) => {};
let contentContainerElement;
let floatingButtonsElement;
@@ -140,20 +140,36 @@
}
};
onMount(() => {
if (floatingButtons) {
contentContainerElement?.addEventListener('mouseup', updateButtonPosition);
// Reactive listener attachment: re-attaches when floatingButtons
// transitions from false → true (e.g. when message.done flips).
let listenersAttached = false;
function attachListeners() {
if (!listenersAttached && contentContainerElement) {
contentContainerElement.addEventListener('mouseup', updateButtonPosition);
document.addEventListener('mouseup', updateButtonPosition);
document.addEventListener('keydown', keydownHandler);
listenersAttached = true;
}
});
}
onDestroy(() => {
if (floatingButtons) {
function detachListeners() {
if (listenersAttached) {
contentContainerElement?.removeEventListener('mouseup', updateButtonPosition);
document.removeEventListener('mouseup', updateButtonPosition);
document.removeEventListener('keydown', keydownHandler);
listenersAttached = false;
}
}
$: if (floatingButtons && contentContainerElement) {
attachListeners();
} else {
detachListeners();
}
onDestroy(() => {
detachListeners();
});
</script>
@@ -201,17 +217,9 @@
<FloatingButtons
bind:this={floatingButtonsElement}
{id}
{messageId}
actions={$settings?.floatingActionButtons ?? []}
model={(selectedModels ?? []).includes(model?.id)
? model?.id
: (selectedModels ?? []).length > 0
? selectedModels.at(0)
: (model?.id ?? null)}
messages={createMessagesList(history, messageId)}
onAdd={({ modelId, parentId, messages }) => {
console.log(modelId, parentId, messages);
onAddMessages({ modelId, parentId, messages });
onSetInputText={(text) => {
onSetInputText(text);
closeFloatingButtons();
}}
/>

View File

@@ -788,9 +788,6 @@
<!-- unless message.error === true which is legacy error handling, where the error message is stored in message.content -->
<ContentRenderer
id={`${chatId}-${message.id}`}
messageId={message.id}
{history}
{selectedModels}
content={message.content}
sources={message.sources}
floatingButtons={message?.done &&
@@ -814,8 +811,8 @@
citationsElement?.showSourceModal(id);
}
}}
onAddMessages={({ modelId, parentId, messages }) => {
addMessages({ modelId, parentId, messages });
onSetInputText={(text) => {
setInputText(text);
}}
onSave={({ raw, oldContent, newContent }) => {
history.messages[message.id].content = history.messages[

View File

@@ -601,6 +601,8 @@
src={`${WEBUI_API_BASE_URL}/models/model/profile/image?id=${model.id}&lang=${$i18n.language}`}
alt="modelfile profile"
class=" rounded-2xl size-12 object-cover"
loading="lazy"
decoding="async"
on:error={(e) => {
e.target.src = '/favicon.png';
}}