mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 11:57:51 +01:00
Merge remote-tracking branch 'origin/dev' into feat/backend-web-search
This commit is contained in:
@@ -397,7 +397,7 @@ def generate_image(
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
|
||||
width, height = tuple(map(int, app.state.config.IMAGE_SIZE).split("x"))
|
||||
width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
|
||||
|
||||
r = None
|
||||
try:
|
||||
|
||||
@@ -75,6 +75,10 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
|
||||
litellm_config = yaml.safe_load(file)
|
||||
|
||||
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
|
||||
|
||||
|
||||
app.state.ENABLE = ENABLE_LITELLM
|
||||
app.state.CONFIG = litellm_config
|
||||
|
||||
@@ -151,10 +155,6 @@ async def shutdown_litellm_background():
|
||||
background_process = None
|
||||
|
||||
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def get_status():
|
||||
return {"status": True}
|
||||
|
||||
@@ -64,8 +64,8 @@ app.add_middleware(
|
||||
|
||||
app.state.config = AppConfig()
|
||||
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
||||
app.state.MODELS = {}
|
||||
@@ -124,8 +124,9 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user))
|
||||
|
||||
|
||||
async def fetch_url(url):
|
||||
timeout = aiohttp.ClientTimeout(total=5)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url) as response:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
@@ -177,11 +178,12 @@ async def get_ollama_tags(
|
||||
if url_idx == None:
|
||||
models = await get_all_models()
|
||||
|
||||
if app.state.ENABLE_MODEL_FILTER:
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models["models"] = list(
|
||||
filter(
|
||||
lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
|
||||
lambda model: model["name"]
|
||||
in app.state.config.MODEL_FILTER_LIST,
|
||||
models["models"],
|
||||
)
|
||||
)
|
||||
@@ -1045,11 +1047,12 @@ async def get_openai_models(
|
||||
if url_idx == None:
|
||||
models = await get_all_models()
|
||||
|
||||
if app.state.ENABLE_MODEL_FILTER:
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models["models"] = list(
|
||||
filter(
|
||||
lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
|
||||
lambda model: model["name"]
|
||||
in app.state.config.MODEL_FILTER_LIST,
|
||||
models["models"],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from utils.utils import (
|
||||
)
|
||||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
ENABLE_OPENAI_API,
|
||||
OPENAI_API_BASE_URLS,
|
||||
OPENAI_API_KEYS,
|
||||
CACHE_DIR,
|
||||
@@ -46,11 +47,14 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
app.state.config = AppConfig()
|
||||
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
|
||||
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
|
||||
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
||||
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
|
||||
|
||||
@@ -68,6 +72,21 @@ async def check_url(request: Request, call_next):
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
async def get_config(user=Depends(get_admin_user)):
|
||||
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
|
||||
|
||||
|
||||
class OpenAIConfigForm(BaseModel):
|
||||
enable_openai_api: Optional[bool] = None
|
||||
|
||||
|
||||
@app.post("/config/update")
|
||||
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
|
||||
app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
|
||||
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
|
||||
|
||||
|
||||
class UrlsUpdateForm(BaseModel):
|
||||
urls: List[str]
|
||||
|
||||
@@ -164,11 +183,15 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
|
||||
async def fetch_url(url, key):
|
||||
timeout = aiohttp.ClientTimeout(total=5)
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
return await response.json()
|
||||
if key != "":
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
return await response.json()
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
log.error(f"Connection error: {e}")
|
||||
@@ -200,7 +223,7 @@ async def get_all_models():
|
||||
if (
|
||||
len(app.state.config.OPENAI_API_KEYS) == 1
|
||||
and app.state.config.OPENAI_API_KEYS[0] == ""
|
||||
):
|
||||
) or not app.state.config.ENABLE_OPENAI_API:
|
||||
models = {"data": []}
|
||||
else:
|
||||
tasks = [
|
||||
@@ -237,11 +260,11 @@ async def get_all_models():
|
||||
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
|
||||
if url_idx == None:
|
||||
models = await get_all_models()
|
||||
if app.state.ENABLE_MODEL_FILTER:
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models["data"] = list(
|
||||
filter(
|
||||
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
|
||||
lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
|
||||
models["data"],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -70,6 +70,7 @@ from utils.misc import (
|
||||
from utils.utils import get_current_user, get_admin_user
|
||||
|
||||
from config import (
|
||||
ENV,
|
||||
SRC_LOG_LEVELS,
|
||||
UPLOAD_DIR,
|
||||
DOCS_DIR,
|
||||
@@ -266,7 +267,7 @@ async def update_embedding_config(
|
||||
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
|
||||
|
||||
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL), True
|
||||
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
|
||||
|
||||
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
@@ -439,12 +440,12 @@ async def update_query_settings(
|
||||
form_data: QuerySettingsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
app.state.config.RAG_TEMPLATE = (
|
||||
form_data.template if form_data.template else RAG_TEMPLATE,
|
||||
form_data.template if form_data.template else RAG_TEMPLATE
|
||||
)
|
||||
app.state.config.TOP_K = form_data.k if form_data.k else 4
|
||||
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
|
||||
form_data.hybrid if form_data.hybrid else False,
|
||||
form_data.hybrid if form_data.hybrid else False
|
||||
)
|
||||
return {
|
||||
"status": True,
|
||||
@@ -1006,3 +1007,14 @@ def reset(user=Depends(get_admin_user)) -> bool:
|
||||
log.exception(e)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if ENV == "dev":
|
||||
|
||||
@app.get("/ef")
|
||||
async def get_embeddings():
|
||||
return {"result": app.state.EMBEDDING_FUNCTION("hello world")}
|
||||
|
||||
@app.get("/ef/{text}")
|
||||
async def get_embeddings_text(text: str):
|
||||
return {"result": app.state.EMBEDDING_FUNCTION(text)}
|
||||
|
||||
53
backend/apps/web/internal/migrations/008_add_memory.py
Normal file
53
backend/apps/web/internal/migrations/008_add_memory.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Peewee migrations -- 002_add_local_sharing.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
@migrator.create_model
|
||||
class Memory(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
content = pw.TextField(null=False)
|
||||
updated_at = pw.BigIntegerField(null=False)
|
||||
created_at = pw.BigIntegerField(null=False)
|
||||
|
||||
class Meta:
|
||||
table_name = "memory"
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model("memory")
|
||||
@@ -9,6 +9,7 @@ from apps.web.routers import (
|
||||
modelfiles,
|
||||
prompts,
|
||||
configs,
|
||||
memories,
|
||||
utils,
|
||||
)
|
||||
from config import (
|
||||
@@ -41,6 +42,7 @@ app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
|
||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
@@ -52,9 +54,12 @@ app.add_middleware(
|
||||
app.include_router(auths.router, prefix="/auths", tags=["auths"])
|
||||
app.include_router(users.router, prefix="/users", tags=["users"])
|
||||
app.include_router(chats.router, prefix="/chats", tags=["chats"])
|
||||
|
||||
app.include_router(documents.router, prefix="/documents", tags=["documents"])
|
||||
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
|
||||
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
|
||||
app.include_router(memories.router, prefix="/memories", tags=["memories"])
|
||||
|
||||
|
||||
app.include_router(configs.router, prefix="/configs", tags=["configs"])
|
||||
app.include_router(utils.router, prefix="/utils", tags=["utils"])
|
||||
|
||||
118
backend/apps/web/models/memories.py
Normal file
118
backend/apps/web/models/memories.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from pydantic import BaseModel
|
||||
from peewee import *
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
from typing import List, Union, Optional
|
||||
|
||||
from apps.web.internal.db import DB
|
||||
from apps.web.models.chats import Chats
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
####################
|
||||
# Memory DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Memory(Model):
|
||||
id = CharField(unique=True)
|
||||
user_id = CharField()
|
||||
content = TextField()
|
||||
updated_at = BigIntegerField()
|
||||
created_at = BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
||||
|
||||
class MemoryModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
content: str
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class MemoriesTable:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
self.db.create_tables([Memory])
|
||||
|
||||
def insert_new_memory(
|
||||
self,
|
||||
user_id: str,
|
||||
content: str,
|
||||
) -> Optional[MemoryModel]:
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
memory = MemoryModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"content": content,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
result = Memory.create(**memory.model_dump())
|
||||
if result:
|
||||
return memory
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_memories(self) -> List[MemoryModel]:
|
||||
try:
|
||||
memories = Memory.select()
|
||||
return [MemoryModel(**model_to_dict(memory)) for memory in memories]
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
|
||||
try:
|
||||
memories = Memory.select().where(Memory.user_id == user_id)
|
||||
return [MemoryModel(**model_to_dict(memory)) for memory in memories]
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_memory_by_id(self, id) -> Optional[MemoryModel]:
|
||||
try:
|
||||
memory = Memory.get(Memory.id == id)
|
||||
return MemoryModel(**model_to_dict(memory))
|
||||
except:
|
||||
return None
|
||||
|
||||
def delete_memory_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
query = Memory.delete().where(Memory.id == id)
|
||||
query.execute() # Remove the rows, return number of rows removed.
|
||||
|
||||
return True
|
||||
|
||||
except:
|
||||
return False
|
||||
|
||||
def delete_memories_by_user_id(self, user_id: str) -> bool:
|
||||
try:
|
||||
query = Memory.delete().where(Memory.user_id == user_id)
|
||||
query.execute()
|
||||
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
try:
|
||||
query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id)
|
||||
query.execute()
|
||||
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
Memories = MemoriesTable(DB)
|
||||
145
backend/apps/web/routers/memories.py
Normal file
145
backend/apps/web/routers/memories.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from fastapi import Response, Request
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Union, Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
from apps.web.models.memories import Memories, MemoryModel
|
||||
|
||||
from utils.utils import get_verified_user
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/ef")
|
||||
async def get_embeddings(request: Request):
|
||||
return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
|
||||
|
||||
|
||||
############################
|
||||
# GetMemories
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=List[MemoryModel])
|
||||
async def get_memories(user=Depends(get_verified_user)):
|
||||
return Memories.get_memories_by_user_id(user.id)
|
||||
|
||||
|
||||
############################
|
||||
# AddMemory
|
||||
############################
|
||||
|
||||
|
||||
class AddMemoryForm(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
@router.post("/add", response_model=Optional[MemoryModel])
|
||||
async def add_memory(
|
||||
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
|
||||
):
|
||||
memory = Memories.insert_new_memory(user.id, form_data.content)
|
||||
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
|
||||
|
||||
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
|
||||
collection.upsert(
|
||||
documents=[memory.content],
|
||||
ids=[memory.id],
|
||||
embeddings=[memory_embedding],
|
||||
metadatas=[{"created_at": memory.created_at}],
|
||||
)
|
||||
|
||||
return memory
|
||||
|
||||
|
||||
############################
|
||||
# QueryMemory
|
||||
############################
|
||||
|
||||
|
||||
class QueryMemoryForm(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
@router.post("/query")
|
||||
async def query_memory(
|
||||
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
|
||||
):
|
||||
query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
|
||||
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
|
||||
|
||||
results = collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=1, # how many results to return
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
############################
|
||||
# ResetMemoryFromVectorDB
|
||||
############################
|
||||
@router.get("/reset", response_model=bool)
|
||||
async def reset_memory_from_vector_db(
|
||||
request: Request, user=Depends(get_verified_user)
|
||||
):
|
||||
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
|
||||
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
|
||||
|
||||
memories = Memories.get_memories_by_user_id(user.id)
|
||||
for memory in memories:
|
||||
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
|
||||
collection.upsert(
|
||||
documents=[memory.content],
|
||||
ids=[memory.id],
|
||||
embeddings=[memory_embedding],
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
############################
|
||||
# DeleteMemoriesByUserId
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/user", response_model=bool)
|
||||
async def delete_memory_by_user_id(user=Depends(get_verified_user)):
|
||||
result = Memories.delete_memories_by_user_id(user.id)
|
||||
|
||||
if result:
|
||||
try:
|
||||
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
|
||||
except Exception as e:
|
||||
log.error(e)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
############################
|
||||
# DeleteMemoryById
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{memory_id}", response_model=bool)
|
||||
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
|
||||
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
|
||||
|
||||
if result:
|
||||
collection = CHROMA_CLIENT.get_or_create_collection(
|
||||
name=f"user-memory-{user.id}"
|
||||
)
|
||||
collection.delete(ids=[memory_id])
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -11,8 +11,9 @@ import logging
|
||||
|
||||
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
|
||||
from apps.web.models.auths import Auths
|
||||
from apps.web.models.chats import Chats
|
||||
|
||||
from utils.utils import get_current_user, get_password_hash, get_admin_user
|
||||
from utils.utils import get_verified_user, get_password_hash, get_admin_user
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
@@ -67,6 +68,41 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserById
|
||||
############################
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
name: str
|
||||
profile_image_url: str
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||
|
||||
if user_id.startswith("shared-"):
|
||||
chat_id = user_id.replace("shared-", "")
|
||||
chat = Chats.get_chat_by_id(chat_id)
|
||||
if chat:
|
||||
user_id = chat.user_id
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
user = Users.get_user_by_id(user_id)
|
||||
|
||||
if user:
|
||||
return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateUserById
|
||||
############################
|
||||
|
||||
@@ -417,6 +417,14 @@ OLLAMA_BASE_URLS = PersistentConfig(
|
||||
# OPENAI_API
|
||||
####################################
|
||||
|
||||
|
||||
ENABLE_OPENAI_API = PersistentConfig(
|
||||
"ENABLE_OPENAI_API",
|
||||
"openai.enable",
|
||||
os.environ.get("ENABLE_OPENAI_API", "True").lower() == "true",
|
||||
)
|
||||
|
||||
|
||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
||||
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
|
||||
|
||||
|
||||
@@ -118,6 +118,18 @@ app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||
origins = ["*"]
|
||||
|
||||
|
||||
# Custom middleware to add security headers
|
||||
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
# async def dispatch(self, request: Request, call_next):
|
||||
# response: Response = await call_next(request)
|
||||
# response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
|
||||
# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
|
||||
# return response
|
||||
|
||||
|
||||
# app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
class RAGMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
return_citations = False
|
||||
@@ -227,9 +239,15 @@ async def check_url(request: Request, call_next):
|
||||
return response
|
||||
|
||||
|
||||
app.mount("/api/v1", webui_app)
|
||||
app.mount("/litellm/api", litellm_app)
|
||||
@app.middleware("http")
|
||||
async def update_embedding_function(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
if "/embedding/update" in request.url.path:
|
||||
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
|
||||
return response
|
||||
|
||||
|
||||
app.mount("/litellm/api", litellm_app)
|
||||
app.mount("/ollama", ollama_app)
|
||||
app.mount("/openai/api", openai_app)
|
||||
|
||||
@@ -237,6 +255,10 @@ app.mount("/images/api/v1", images_app)
|
||||
app.mount("/audio/api/v1", audio_app)
|
||||
app.mount("/rag/api/v1", rag_app)
|
||||
|
||||
app.mount("/api/v1", webui_app)
|
||||
|
||||
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
|
||||
|
||||
|
||||
@app.get("/api/config")
|
||||
async def get_app_config():
|
||||
@@ -279,14 +301,14 @@ class ModelFilterConfigForm(BaseModel):
|
||||
async def update_model_filter_config(
|
||||
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
app.state.config.ENABLE_MODEL_FILTER, form_data.enabled
|
||||
app.state.config.MODEL_FILTER_LIST, form_data.models
|
||||
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
|
||||
app.state.config.MODEL_FILTER_LIST = form_data.models
|
||||
|
||||
ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
|
||||
ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
|
||||
ollama_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
|
||||
ollama_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
|
||||
|
||||
openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
|
||||
openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
|
||||
openai_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
|
||||
openai_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
|
||||
|
||||
litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
|
||||
litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
|
||||
|
||||
Reference in New Issue
Block a user