mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
refac: title generation
This commit is contained in:
200
backend/main.py
200
backend/main.py
@@ -53,6 +53,8 @@ from utils.utils import (
|
||||
get_current_user,
|
||||
get_http_authorization_cred,
|
||||
)
|
||||
from utils.task import title_generation_template
|
||||
|
||||
from apps.rag.utils import rag_messages
|
||||
|
||||
from config import (
|
||||
@@ -74,8 +76,9 @@ from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
WEBHOOK_URL,
|
||||
ENABLE_ADMIN_EXPORT,
|
||||
AppConfig,
|
||||
WEBUI_BUILD_HASH,
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
AppConfig,
|
||||
)
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
@@ -131,7 +134,7 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
|
||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||
|
||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
app.state.MODELS = {}
|
||||
|
||||
@@ -240,6 +243,78 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
app.add_middleware(RAGMiddleware)
|
||||
|
||||
|
||||
def filter_pipeline(payload, user):
|
||||
user = {"id": user.id, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
filters = [
|
||||
model
|
||||
for model in app.state.MODELS.values()
|
||||
if "pipeline" in model
|
||||
and "type" in model["pipeline"]
|
||||
and model["pipeline"]["type"] == "filter"
|
||||
and (
|
||||
model["pipeline"]["pipelines"] == ["*"]
|
||||
or any(
|
||||
model_id == target_model_id
|
||||
for target_model_id in model["pipeline"]["pipelines"]
|
||||
)
|
||||
)
|
||||
]
|
||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
||||
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters.append(model)
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": user,
|
||||
"body": payload,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
payload = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
return JSONResponse(
|
||||
status_code=r.status_code,
|
||||
content=res,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
if "pipeline" not in app.state.MODELS[model_id]:
|
||||
if "chat_id" in payload:
|
||||
del payload["chat_id"]
|
||||
|
||||
if "title" in payload:
|
||||
del payload["title"]
|
||||
return payload
|
||||
|
||||
|
||||
class PipelineMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if request.method == "POST" and (
|
||||
@@ -255,85 +330,10 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
||||
# Parse string to JSON
|
||||
data = json.loads(body_str) if body_str else {}
|
||||
|
||||
model_id = data["model"]
|
||||
filters = [
|
||||
model
|
||||
for model in app.state.MODELS.values()
|
||||
if "pipeline" in model
|
||||
and "type" in model["pipeline"]
|
||||
and model["pipeline"]["type"] == "filter"
|
||||
and (
|
||||
model["pipeline"]["pipelines"] == ["*"]
|
||||
or any(
|
||||
model_id == target_model_id
|
||||
for target_model_id in model["pipeline"]["pipelines"]
|
||||
)
|
||||
)
|
||||
]
|
||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
||||
|
||||
user = None
|
||||
if len(sorted_filters) > 0:
|
||||
try:
|
||||
user = get_current_user(
|
||||
get_http_authorization_cred(
|
||||
request.headers.get("Authorization")
|
||||
)
|
||||
)
|
||||
user = {"id": user.id, "name": user.name, "role": user.role}
|
||||
except:
|
||||
pass
|
||||
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters.append(model)
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": user,
|
||||
"body": data,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
return JSONResponse(
|
||||
status_code=r.status_code,
|
||||
content=res,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
if "pipeline" not in app.state.MODELS[model_id]:
|
||||
if "chat_id" in data:
|
||||
del data["chat_id"]
|
||||
|
||||
if "title" in data:
|
||||
del data["title"]
|
||||
user = get_current_user(
|
||||
get_http_authorization_cred(request.headers.get("Authorization"))
|
||||
)
|
||||
data = filter_pipeline(data, user)
|
||||
|
||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||
# Replace the request body with the modified one
|
||||
@@ -494,6 +494,44 @@ async def get_models(user=Depends(get_verified_user)):
|
||||
return {"data": models}
|
||||
|
||||
|
||||
@app.post("/api/title/completions")
|
||||
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_title")
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = title_generation_template(
|
||||
template, form_data["prompt"], user.model_dump()
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"max_tokens": 50,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
"title": True,
|
||||
}
|
||||
|
||||
print(payload)
|
||||
payload = filter_pipeline(payload, user)
|
||||
|
||||
if model["owned_by"] == "ollama":
|
||||
return await generate_ollama_chat_completion(
|
||||
OpenAIChatCompletionForm(**payload), user=user
|
||||
)
|
||||
else:
|
||||
return await generate_openai_chat_completion(payload, user=user)
|
||||
|
||||
|
||||
@app.post("/api/chat/completions")
|
||||
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
||||
model_id = form_data["model"]
|
||||
|
||||
Reference in New Issue
Block a user