Mirror of https://github.com/open-webui/open-webui.git (synced 2025-12-16 11:57:51 +01:00)
@@ -405,14 +405,19 @@ async def generate_chat_completion(
            "role": user.role,
        }

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    # Change max_completion_tokens to max_tokens (Backward compatible)
    if "api.openai.com" not in url and not payload["model"].lower().startswith("o1-"):
        if "max_completion_tokens" in payload:
            payload["max_tokens"] = payload.pop("max_completion_tokens")

    # Convert the modified body back to JSON
    payload = json.dumps(payload)

    log.debug(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    headers = {}
    headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"
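The backward-compatibility shim above only rewrites the parameter for non-OpenAI endpoints and non-o1 models. A minimal sketch of that rename, with an illustrative payload and base URL:

# Illustrative payload and URL; mirrors the rename logic shown in the hunk above
payload = {"model": "gpt-4o", "max_completion_tokens": 256}
url = "https://example-openai-compatible.host/v1"

if "api.openai.com" not in url and not payload["model"].lower().startswith("o1-"):
    if "max_completion_tokens" in payload:
        payload["max_tokens"] = payload.pop("max_completion_tokens")

assert payload == {"model": "gpt-4o", "max_tokens": 256}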
@@ -1099,35 +1099,35 @@ def store_docs_in_vector_db(
            log.info(f"deleting existing collection {collection_name}")
            VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)

        embedding_function = get_embedding_function(
            app.state.config.RAG_EMBEDDING_ENGINE,
            app.state.config.RAG_EMBEDDING_MODEL,
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        VECTOR_DB_CLIENT.insert(
            collection_name=collection_name,
            items=[
                {
                    "id": str(uuid.uuid4()),
                    "text": text,
                    "vector": embedding_function(text.replace("\n", " ")),
                    "metadata": metadatas[idx],
                }
                for idx, text in enumerate(texts)
            ],
        )

        return True
    except Exception as e:
        if e.__class__.__name__ == "UniqueConstraintError":
            if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
                log.info(f"collection {collection_name} already exists")
                return True
        else:
            embedding_function = get_embedding_function(
                app.state.config.RAG_EMBEDDING_ENGINE,
                app.state.config.RAG_EMBEDDING_MODEL,
                app.state.sentence_transformer_ef,
                app.state.config.OPENAI_API_KEY,
                app.state.config.OPENAI_API_BASE_URL,
                app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
            )

            VECTOR_DB_CLIENT.insert(
                collection_name=collection_name,
                items=[
                    {
                        "id": str(uuid.uuid4()),
                        "text": text,
                        "vector": embedding_function(text.replace("\n", " ")),
                        "metadata": metadatas[idx],
                    }
                    for idx, text in enumerate(texts)
                ],
            )

            return True
    except Exception as e:
        log.exception(e)

        return False
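Each chunk is embedded with newlines collapsed and stored as an id/text/vector/metadata item. A minimal sketch of the item shape built above, using a stand-in embedding function and hypothetical texts and metadatas:

# Stand-in for get_embedding_function(...); real vectors come from the configured engine
import uuid

def embedding_function(text: str) -> list[float]:
    return [0.0, 0.0, 0.0]

texts = ["first chunk\nof a document", "second chunk"]
metadatas = [{"source": "doc.txt", "page": 1}, {"source": "doc.txt", "page": 2}]

items = [
    {
        "id": str(uuid.uuid4()),
        "text": text,
        "vector": embedding_function(text.replace("\n", " ")),
        "metadata": metadatas[idx],
    }
    for idx, text in enumerate(texts)
]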
@@ -2,16 +2,38 @@ import asyncio
import socketio
from open_webui.apps.webui.models.users import Users
from open_webui.env import ENABLE_WEBSOCKET_SUPPORT
from open_webui.env import (
    ENABLE_WEBSOCKET_SUPPORT,
    WEBSOCKET_MANAGER,
    WEBSOCKET_REDIS_URL,
)
from open_webui.utils.utils import decode_token

sio = socketio.AsyncServer(
    cors_allowed_origins=[],
    async_mode="asgi",
    transports=(["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
    allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
    always_connect=True,
)

if WEBSOCKET_MANAGER == "redis":
    mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
    sio = socketio.AsyncServer(
        cors_allowed_origins=[],
        async_mode="asgi",
        transports=(
            ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
        ),
        allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
        always_connect=True,
        client_manager=mgr,
    )
else:
    sio = socketio.AsyncServer(
        cors_allowed_origins=[],
        async_mode="asgi",
        transports=(
            ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
        ),
        allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
        always_connect=True,
    )


app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")

# Dictionary to maintain the user pool
@@ -302,3 +302,7 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
ENABLE_WEBSOCKET_SUPPORT = (
    os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
)

WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")

WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", "redis://localhost:6379/0")
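Together with the socket changes above, these settings let multiple workers or nodes share Socket.IO messaging through Redis. An illustrative way to point the server at Redis (example values; they must be set in the process environment before these modules are imported):

import os

# Illustrative values only: WEBSOCKET_MANAGER selects the Redis-backed client_manager
# in the socket app, and WEBSOCKET_REDIS_URL tells it which Redis instance to use.
os.environ["ENABLE_WEBSOCKET_SUPPORT"] = "True"
os.environ["WEBSOCKET_MANAGER"] = "redis"
os.environ["WEBSOCKET_REDIS_URL"] = "redis://localhost:6379/0"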
@@ -19,7 +19,9 @@ from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import app as ollama_app
from open_webui.apps.ollama.main import (
    generate_openai_chat_completion as generate_ollama_chat_completion,
    GenerateChatCompletionForm,
    generate_chat_completion as generate_ollama_chat_completion,
    generate_openai_chat_completion as generate_ollama_openai_chat_completion,
)
from open_webui.apps.ollama.main import get_all_models as get_ollama_models
from open_webui.apps.openai.main import app as openai_app
@@ -135,6 +137,12 @@ from open_webui.utils.utils import (
)
from open_webui.utils.webhook import post_webhook

from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
    convert_response_ollama_to_openai,
    convert_streaming_response_ollama_to_openai,
)

if SAFE_MODE:
    print("SAFE MODE ENABLED")
    Functions.deactivate_all_functions()
@@ -1048,7 +1056,18 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
    if model.get("pipe"):
        return await generate_function_chat_completion(form_data, user=user)
    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(form_data, user=user)
        # Using /ollama/api/chat endpoint
        form_data = convert_payload_openai_to_ollama(form_data)
        form_data = GenerateChatCompletionForm(**form_data)
        response = await generate_ollama_chat_completion(form_data=form_data, user=user)
        if form_data.stream:
            response.headers["content-type"] = "text/event-stream"
            return StreamingResponse(
                convert_streaming_response_ollama_to_openai(response),
                headers=dict(response.headers),
            )
        else:
            return convert_response_ollama_to_openai(response)
    else:
        return await generate_openai_chat_completion(form_data, user=user)
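The new branch above routes OpenAI-style requests for Ollama-owned models through Ollama's native /api/chat endpoint and converts the result back into OpenAI's response format. A rough sketch of the non-streaming round trip with a hypothetical request body (the awaited calls are commented out because they need the running app's state and an authenticated user):

# Hypothetical OpenAI-style request body
form_data = {
    "model": "llama3",
    "messages": [{"role": "user", "content": "Hi"}],
    "stream": False,
    "max_tokens": 128,
}

ollama_form = GenerateChatCompletionForm(**convert_payload_openai_to_ollama(form_data))
# response = await generate_ollama_chat_completion(form_data=ollama_form, user=user)
# openai_response = convert_response_ollama_to_openai(response)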
@@ -1399,9 +1418,10 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(model_id)

    print(task_model_id)

    model = app.state.MODELS[task_model_id]

    if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
        template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    else:
@@ -1440,9 +1460,9 @@ Prompt: {{prompt:middletruncate:8000}}"""
        "chat_id": form_data.get("chat_id", None),
        "metadata": {"task": str(TASKS.TITLE_GENERATION)},
    }

    log.debug(payload)

    # Handle pipeline filters
    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
@@ -1456,7 +1476,6 @@ Prompt: {{prompt:middletruncate:8000}}"""
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

    if "chat_id" in payload:
        del payload["chat_id"]
@@ -1484,6 +1503,8 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
    task_model_id = get_task_model_id(model_id)
    print(task_model_id)

    model = app.state.MODELS[task_model_id]

    if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
        template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
    else:
@@ -1516,9 +1537,9 @@ Search Query:"""
        ),
        "metadata": {"task": str(TASKS.QUERY_GENERATION)},
    }
    log.debug(payload)

    print(payload)

    # Handle pipeline filters
    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
@@ -1532,7 +1553,6 @@ Search Query:"""
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

    if "chat_id" in payload:
        del payload["chat_id"]
@@ -1555,12 +1575,13 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
    task_model_id = get_task_model_id(model_id)
    print(task_model_id)

    model = app.state.MODELS[task_model_id]

    template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).

Message: """{{prompt}}"""
'''

    content = title_generation_template(
        template,
        form_data["prompt"],
@@ -1584,9 +1605,9 @@ Message: """{{prompt}}"""
        "chat_id": form_data.get("chat_id", None),
        "metadata": {"task": str(TASKS.EMOJI_GENERATION)},
    }

    log.debug(payload)

    # Handle pipeline filters
    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
@@ -1600,7 +1621,6 @@ Message: """{{prompt}}"""
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

    if "chat_id" in payload:
        del payload["chat_id"]
@@ -1620,8 +1640,10 @@ async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)
    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    model_id = get_task_model_id(model_id)
    print(model_id)
    task_model_id = get_task_model_id(model_id)
    print(task_model_id)

    model = app.state.MODELS[task_model_id]

    template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
@@ -1636,13 +1658,12 @@ Responses from models: {{responses}}"""
    )

    payload = {
        "model": model_id,
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        "chat_id": form_data.get("chat_id", None),
        "metadata": {"task": str(TASKS.MOA_RESPONSE_GENERATION)},
    }

    log.debug(payload)

    try:
@@ -1658,7 +1679,6 @@ Responses from models: {{responses}}"""
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

    if "chat_id" in payload:
        del payload["chat_id"]
@@ -105,17 +105,25 @@ def openai_chat_message_template(model: str):
    }


def openai_chat_chunk_message_template(model: str, message: str) -> dict:
def openai_chat_chunk_message_template(
    model: str, message: Optional[str] = None
) -> dict:
    template = openai_chat_message_template(model)
    template["object"] = "chat.completion.chunk"
    template["choices"][0]["delta"] = {"content": message}
    if message:
        template["choices"][0]["delta"] = {"content": message}
    else:
        template["choices"][0]["finish_reason"] = "stop"
    return template


def openai_chat_completion_message_template(model: str, message: str) -> dict:
def openai_chat_completion_message_template(
    model: str, message: Optional[str] = None
) -> dict:
    template = openai_chat_message_template(model)
    template["object"] = "chat.completion"
    template["choices"][0]["message"] = {"content": message, "role": "assistant"}
    if message:
        template["choices"][0]["message"] = {"content": message, "role": "assistant"}
        template["choices"][0]["finish_reason"] = "stop"
    return template
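With the now-optional message, a chunk that carries content fills the delta, while a final chunk (message=None) carries only the stop marker. A minimal sketch of the expected shapes (the model name is illustrative):

from open_webui.utils.misc import openai_chat_chunk_message_template

# Content chunk: delta holds the text
chunk = openai_chat_chunk_message_template("llama3", "Hello")
assert chunk["object"] == "chat.completion.chunk"
assert chunk["choices"][0]["delta"] == {"content": "Hello"}

# Final chunk: no content, finish_reason is set instead
final = openai_chat_chunk_message_template("llama3", None)
assert final["choices"][0]["finish_reason"] == "stop"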
@@ -86,3 +86,49 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
            form_data[value] = param

    return form_data


def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
    """
    Converts a payload formatted for OpenAI's API to be compatible with Ollama's API endpoint for chat completions.

    Args:
        openai_payload (dict): The payload originally designed for OpenAI API usage.

    Returns:
        dict: A modified payload compatible with the Ollama API.
    """
    ollama_payload = {}

    # Mapping basic model and message details
    ollama_payload["model"] = openai_payload.get("model")
    ollama_payload["messages"] = openai_payload.get("messages")
    ollama_payload["stream"] = openai_payload.get("stream", False)

    # If there are advanced parameters in the payload, format them in Ollama's options field
    ollama_options = {}

    # Handle parameters which map directly
    for param in ["temperature", "top_p", "seed"]:
        if param in openai_payload:
            ollama_options[param] = openai_payload[param]

    # Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
    if "max_completion_tokens" in openai_payload:
        ollama_options["num_predict"] = openai_payload["max_completion_tokens"]
    elif "max_tokens" in openai_payload:
        ollama_options["num_predict"] = openai_payload["max_tokens"]

    # Handle frequency / presence_penalty, which needs renaming and checking
    if "frequency_penalty" in openai_payload:
        ollama_options["repeat_penalty"] = openai_payload["frequency_penalty"]

    if "presence_penalty" in openai_payload and "penalty" not in ollama_options:
        # We are assuming presence penalty uses a similar concept in Ollama, which needs custom handling if exists.
        ollama_options["new_topic_penalty"] = openai_payload["presence_penalty"]

    # Add options to payload if any have been set
    if ollama_options:
        ollama_payload["options"] = ollama_options

    return ollama_payload
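A minimal sketch of the conversion above on an illustrative OpenAI-style payload, following the mappings in the function body:

openai_payload = {
    "model": "llama3",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
    "temperature": 0.7,
    "max_tokens": 128,
    "frequency_penalty": 1.1,
}

ollama_payload = convert_payload_openai_to_ollama(openai_payload)
# -> {
#     "model": "llama3",
#     "messages": [{"role": "user", "content": "Hello"}],
#     "stream": True,
#     "options": {"temperature": 0.7, "num_predict": 128, "repeat_penalty": 1.1},
# }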
backend/open_webui/utils/response.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import json
from open_webui.utils.misc import (
    openai_chat_chunk_message_template,
    openai_chat_completion_message_template,
)


def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
    model = ollama_response.get("model", "ollama")
    message_content = ollama_response.get("message", {}).get("content", "")

    response = openai_chat_completion_message_template(model, message_content)
    return response


async def convert_streaming_response_ollama_to_openai(ollama_streaming_response):
    async for data in ollama_streaming_response.body_iterator:
        data = json.loads(data)

        model = data.get("model", "ollama")
        message_content = data.get("message", {}).get("content", "")
        done = data.get("done", False)

        data = openai_chat_chunk_message_template(
            model, message_content if not done else None
        )

        line = f"data: {json.dumps(data)}\n\n"
        if done:
            line += "data: [DONE]\n\n"

        yield line
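A minimal sketch of the non-streaming conversion above on an illustrative Ollama response, showing the fields the OpenAI-format result exposes:

ollama_response = {
    "model": "llama3",
    "message": {"role": "assistant", "content": "Hi there!"},
    "done": True,
}

openai_response = convert_response_ollama_to_openai(ollama_response)
# openai_response["object"] == "chat.completion"
# openai_response["choices"][0]["message"] == {"content": "Hi there!", "role": "assistant"}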