This commit is contained in:
Timothy J. Baek
2024-09-10 04:37:06 +01:00
parent d5f13dd9e0
commit 522afbb0a0
7 changed files with 240 additions and 127 deletions

View File

@@ -50,16 +50,17 @@ async def add_memory(
user=Depends(get_verified_user),
):
memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.upsert(
documents=[memory.content],
ids=[memory.id],
embeddings=[memory_embedding],
metadatas=[{"created_at": memory.created_at}],
VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}",
items=[
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
"metadata": {"created_at": memory.created_at},
}
],
)
return memory
@@ -79,14 +80,10 @@ class QueryMemoryForm(BaseModel):
async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
):
query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
results = collection.query(
query_embeddings=[query_embedding],
n_results=form_data.k, # how many results to return
results = VECTOR_DB_CLIENT.search(
name=f"user-memory-{user.id}",
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
limit=form_data.k,
)
return results
@@ -100,18 +97,24 @@ async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user)
):
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
memories = Memories.get_memories_by_user_id(user.id)
for memory in memories:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection.upsert(
documents=[memory.content],
ids=[memory.id],
embeddings=[memory_embedding],
)
VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}",
items=[
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
"metadata": {
"created_at": memory.created_at,
"updated_at": memory.updated_at,
},
}
for memory in memories
],
)
return True
@@ -151,16 +154,18 @@ async def update_memory_by_id(
raise HTTPException(status_code=404, detail="Memory not found")
if form_data.content is not None:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
)
collection.upsert(
documents=[form_data.content],
ids=[memory.id],
embeddings=[memory_embedding],
metadatas=[
{"created_at": memory.created_at, "updated_at": memory.updated_at}
VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}",
items=[
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
"metadata": {
"created_at": memory.created_at,
"updated_at": memory.updated_at,
},
}
],
)
@@ -177,10 +182,9 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
if result:
collection = VECTOR_DB_CLIENT.get_or_create_collection(
name=f"user-memory-{user.id}"
VECTOR_DB_CLIENT.delete(
collection_name=f"user-memory-{user.id}", ids=[memory_id]
)
collection.delete(ids=[memory_id])
return True
return False