refactor: Update GenerateEmbeddingsForm to support batch processing
refactor: Update embedding batch size handling in RAG configuration
refactor: add query_doc query caching
refactor: update logging statements in generate_chat_completion function
change embedding_batch_size to Optional
@@ -12,8 +12,8 @@ from langchain_core.documents import Document
 from open_webui.apps.ollama.main import (
-    GenerateEmbeddingsForm,
-    generate_ollama_embeddings,
+    GenerateEmbedForm,
+    generate_ollama_batch_embeddings,
 )
 from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
@@ -71,7 +71,7 @@ def query_doc(
     try:
         result = VECTOR_DB_CLIENT.search(
             collection_name=collection_name,
-            vectors=[query_embedding],
+            vectors=query_embedding,
             limit=k,
         )
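The query_doc change dovetails with the "query caching" item in the commit message: the caller now supplies a precomputed query_embedding, so one embedding can be reused across collections instead of being recomputed per search. A minimal sketch of that idea, assuming a plain string query; embed_query and embed_query_cached are hypothetical stand-ins for illustration, not helpers from the codebase:

    from functools import lru_cache

    def embed_query(query: str) -> list[float]:
        # Hypothetical stand-in for the real embedding backend call.
        return [float(len(query)), 0.0]

    @lru_cache(maxsize=256)
    def embed_query_cached(query: str) -> tuple[float, ...]:
        # lru_cache needs hashable arguments and benefits from an
        # immutable return value, hence the tuple conversion.
        return tuple(embed_query(query))

    # Repeated searches for the same query hit the cache, not the model:
    vec = list(embed_query_cached("what is open webui"))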
@@ -265,19 +265,15 @@ def get_embedding_function(
     embedding_function,
     openai_key,
     openai_url,
-    batch_size,
+    embedding_batch_size,
 ):
     if embedding_engine == "":
         return lambda query: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
         if embedding_engine == "ollama":
             func = lambda query: generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": embedding_model,
-                        "prompt": query,
-                    }
-                )
+                model=embedding_model,
+                input=query,
             )
         elif embedding_engine == "openai":
             func = lambda query: generate_openai_embeddings(
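The call-site change above replaces a GenerateEmbeddingsForm object carrying a single "prompt" with plain keyword arguments, so the same path can pass either one string or a list of strings. A small sketch of the closure-based dispatch, with a stub in place of the real Ollama call; make_embedding_fn and ollama_embed_stub are illustrative names, not the project's API:

    from typing import Callable, Union

    def ollama_embed_stub(model: str, input: Union[str, list[str]]) -> list[list[float]]:
        # Stand-in for generate_ollama_embeddings(model=..., input=...).
        batch = input if isinstance(input, list) else [input]
        return [[0.0, 0.0] for _ in batch]

    def make_embedding_fn(engine: str, model: str) -> Callable:
        if engine == "ollama":
            # The lambda closes over model, so callers only pass text.
            return lambda query: ollama_embed_stub(model=model, input=query)
        raise ValueError(f"unsupported engine: {engine}")

    embed = make_embedding_fn("ollama", "nomic-embed-text")
    print(len(embed(["hello", "world"])))  # -> 2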
@@ -289,13 +285,10 @@ def get_embedding_function(
 
         def generate_multiple(query, f):
             if isinstance(query, list):
-                if embedding_engine == "openai":
-                    embeddings = []
-                    for i in range(0, len(query), batch_size):
-                        embeddings.extend(f(query[i : i + batch_size]))
-                    return embeddings
-                else:
-                    return [f(q) for q in query]
+                embeddings = []
+                for i in range(0, len(query), embedding_batch_size):
+                    embeddings.extend(f(query[i : i + embedding_batch_size]))
+                return embeddings
             else:
                 return f(query)
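With the engine-specific branching gone, every engine's list input flows through the same slice-and-extend loop, chunked by embedding_batch_size. A standalone, runnable sketch of that slicing logic with a fake batch embedder:

    def batched_embed(texts, batch_size, embed_batch):
        # Walk the list in strides of batch_size; the last slice may be shorter.
        embeddings = []
        for i in range(0, len(texts), batch_size):
            embeddings.extend(embed_batch(texts[i : i + batch_size]))
        return embeddings

    # Five texts with batch_size=2 produce batches of sizes 2, 2, 1.
    fake = lambda batch: [[float(len(t))] for t in batch]
    print(batched_embed(["a", "bb", "ccc", "dddd", "e"], 2, fake))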
@@ -481,6 +474,21 @@ def generate_openai_batch_embeddings(
         return None
 
 
+def generate_ollama_embeddings(
+    model: str, input: list[str]
+) -> Optional[list[list[float]]]:
+    if isinstance(input, list):
+        embeddings = generate_ollama_batch_embeddings(
+            GenerateEmbedForm(**{"model": model, "input": input})
+        )
+    else:
+        embeddings = generate_ollama_batch_embeddings(
+            GenerateEmbedForm(**{"model": model, "input": [input]})
+        )
+
+    return embeddings["embeddings"]
+
+
 import operator
 from typing import Optional, Sequence
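One caveat in the new wrapper: the annotation says input: list[str], but the isinstance check means a bare string is also accepted and wrapped into a one-element batch, and the ["embeddings"] lookup would raise if the underlying batch call ever returned None. A hedged usage sketch with a stub for the HTTP-backed batch call; the stub and its response shape are assumptions for illustration:

    from typing import Optional, Union

    def batch_embed_stub(model: str, input: list[str]) -> Optional[dict]:
        # Stand-in for the real batch call; assume it returns
        # {"embeddings": [...]} on success and None on failure.
        return {"embeddings": [[0.1, 0.2] for _ in input]}

    def ollama_embeddings(
        model: str, input: Union[str, list[str]]
    ) -> Optional[list[list[float]]]:
        batch = input if isinstance(input, list) else [input]
        result = batch_embed_stub(model, batch)
        # Guard before indexing, since the real call can fail.
        return result["embeddings"] if result else None

    print(ollama_embeddings("nomic-embed-text", "hello"))     # one vector
    print(ollama_embeddings("nomic-embed-text", ["a", "b"]))  # two vectors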