This commit is contained in:
Timothy J. Baek
2024-09-10 02:27:50 +01:00
parent 1023ff8454
commit 4354f270ce
7 changed files with 138 additions and 62 deletions

View File

@@ -1,12 +1,13 @@
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
import logging
from typing import Optional
from open_webui.apps.webui.models.memories import Memories, MemoryModel
from open_webui.config import CHROMA_CLIENT
from open_webui.env import SRC_LOG_LEVELS
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.utils import get_verified_user
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -51,7 +52,9 @@ async def add_memory(
memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.upsert(
documents=[memory.content],
ids=[memory.id],
@@ -77,7 +80,9 @@ async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
):
query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
results = collection.query(
query_embeddings=[query_embedding],
@@ -94,8 +99,10 @@ async def query_memory(
async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user)
):
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
memories = Memories.get_memories_by_user_id(user.id)
for memory in memories:
@@ -119,7 +126,7 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)):
if result:
try:
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
except Exception as e:
log.error(e)
return True
@@ -145,7 +152,7 @@ async def update_memory_by_id(
if form_data.content is not None:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = CHROMA_CLIENT.get_or_create_collection(
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.upsert(
@@ -170,7 +177,7 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
if result:
collection = CHROMA_CLIENT.get_or_create_collection(
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.delete(ids=[memory_id])