mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 20:07:49 +01:00
Merge pull request #14370 from daw/feat/add-azure-openai-embeddings-option
feat:Add Azure OpenAI embedding support
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import Optional, Union
|
||||
import requests
|
||||
import hashlib
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import time
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
|
||||
@@ -400,12 +401,14 @@ def get_embedding_function(
|
||||
url,
|
||||
key,
|
||||
embedding_batch_size,
|
||||
deployment=None,
|
||||
version=None,
|
||||
):
|
||||
if embedding_engine == "":
|
||||
return lambda query, prefix=None, user=None: embedding_function.encode(
|
||||
query, **({"prompt": prefix} if prefix else {})
|
||||
).tolist()
|
||||
elif embedding_engine in ["ollama", "openai"]:
|
||||
elif embedding_engine in ["ollama", "openai", "azure_openai"]:
|
||||
func = lambda query, prefix=None, user=None: generate_embeddings(
|
||||
engine=embedding_engine,
|
||||
model=embedding_model,
|
||||
@@ -414,6 +417,8 @@ def get_embedding_function(
|
||||
url=url,
|
||||
key=key,
|
||||
user=user,
|
||||
deployment=deployment,
|
||||
version=version,
|
||||
)
|
||||
|
||||
def generate_multiple(query, prefix, user, func):
|
||||
@@ -697,6 +702,61 @@ def generate_openai_batch_embeddings(
|
||||
return None
|
||||
|
||||
|
||||
def generate_azure_openai_batch_embeddings(
|
||||
deployment: str,
|
||||
texts: list[str],
|
||||
url: str,
|
||||
key: str = "",
|
||||
model: str = "",
|
||||
version: str = "",
|
||||
prefix: str = None,
|
||||
user: UserModel = None,
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
log.debug(
|
||||
f"generate_azure_openai_batch_embeddings:deployment {deployment} batch size: {len(texts)}"
|
||||
)
|
||||
json_data = {"input": texts, "model": model}
|
||||
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
||||
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
||||
|
||||
url = f"{url}/openai/deployments/{deployment}/embeddings?api-version={version}"
|
||||
|
||||
for _ in range(5):
|
||||
r = requests.post(
|
||||
url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"api-key": key,
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
json=json_data,
|
||||
)
|
||||
if r.status_code == 429:
|
||||
retry = float(r.headers.get("Retry-After", "1"))
|
||||
time.sleep(retry)
|
||||
continue
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "data" in data:
|
||||
return [elem["embedding"] for elem in data["data"]]
|
||||
else:
|
||||
raise Exception("Something went wrong :/")
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error generating azure openai batch embeddings: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def generate_ollama_batch_embeddings(
|
||||
model: str,
|
||||
texts: list[str],
|
||||
@@ -794,6 +854,32 @@ def generate_embeddings(
|
||||
model, [text], url, key, prefix, user
|
||||
)
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
elif engine == "azure_openai":
|
||||
deployment = kwargs.get("deployment", "")
|
||||
version = kwargs.get("version", "")
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_azure_openai_batch_embeddings(
|
||||
deployment,
|
||||
text,
|
||||
url,
|
||||
key,
|
||||
model,
|
||||
version,
|
||||
prefix,
|
||||
user,
|
||||
)
|
||||
else:
|
||||
embeddings = generate_azure_openai_batch_embeddings(
|
||||
deployment,
|
||||
[text],
|
||||
url,
|
||||
key,
|
||||
model,
|
||||
version,
|
||||
prefix,
|
||||
user,
|
||||
)
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
|
||||
|
||||
import operator
|
||||
|
||||
Reference in New Issue
Block a user