Merge pull request #14370 from daw/feat/add-azure-openai-embeddings-option

feat: Add Azure OpenAI embedding support
Tim Jaeryang Baek
2025-05-30 00:18:55 +04:00
committed by GitHub
6 changed files with 315 additions and 51 deletions


@@ -5,6 +5,7 @@ from typing import Optional, Union
import requests
import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
@@ -400,12 +401,14 @@ def get_embedding_function(
    url,
    key,
    embedding_batch_size,
    deployment=None,
    version=None,
):
    if embedding_engine == "":
        return lambda query, prefix=None, user=None: embedding_function.encode(
            query, **({"prompt": prefix} if prefix else {})
        ).tolist()
    elif embedding_engine in ["ollama", "openai"]:
    elif embedding_engine in ["ollama", "openai", "azure_openai"]:
        func = lambda query, prefix=None, user=None: generate_embeddings(
            engine=embedding_engine,
            model=embedding_model,
@@ -414,6 +417,8 @@ def get_embedding_function(
            url=url,
            key=key,
            user=user,
            deployment=deployment,
            version=version,
        )

        def generate_multiple(query, prefix, user, func):
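For orientation, a minimal sketch of how the extended signature might be wired for the new engine. The endpoint, key, deployment, and api-version values are placeholders, and the leading keyword names are taken from the function body shown above rather than from this hunk:

embed = get_embedding_function(
    embedding_engine="azure_openai",
    embedding_model="text-embedding-3-small",     # placeholder model name
    embedding_function=None,                      # local encoder is only used when embedding_engine == ""
    url="https://my-resource.openai.azure.com",   # placeholder Azure OpenAI endpoint
    key="<azure-api-key>",
    embedding_batch_size=32,
    deployment="text-embedding-3-small",          # placeholder Azure deployment name
    version="2023-05-15",                         # placeholder api-version
)

# The returned callable keeps the (query, prefix, user) signature shown above.
vector = embed("hello world")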
@@ -697,6 +702,61 @@ def generate_openai_batch_embeddings(
        return None


def generate_azure_openai_batch_embeddings(
    deployment: str,
    texts: list[str],
    url: str,
    key: str = "",
    model: str = "",
    version: str = "",
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"generate_azure_openai_batch_embeddings:deployment {deployment} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        url = f"{url}/openai/deployments/{deployment}/embeddings?api-version={version}"

        for _ in range(5):
            r = requests.post(
                url,
                headers={
                    "Content-Type": "application/json",
                    "api-key": key,
                    **(
                        {
                            "X-OpenWebUI-User-Name": user.name,
                            "X-OpenWebUI-User-Id": user.id,
                            "X-OpenWebUI-User-Email": user.email,
                            "X-OpenWebUI-User-Role": user.role,
                        }
                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
                        else {}
                    ),
                },
                json=json_data,
            )
            if r.status_code == 429:
                retry = float(r.headers.get("Retry-After", "1"))
                time.sleep(retry)
                continue
            r.raise_for_status()
            data = r.json()
            if "data" in data:
                return [elem["embedding"] for elem in data["data"]]
            else:
                raise Exception("Something went wrong :/")
        return None
    except Exception as e:
        log.exception(f"Error generating azure openai batch embeddings: {e}")
        return None


def generate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
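A hedged usage sketch for the new helper, with placeholder endpoint, deployment, and api-version values. The function appends the standard Azure path /openai/deployments/{deployment}/embeddings?api-version={version} to the base URL, authenticates with the api-key header, and retries up to five times on HTTP 429 using the Retry-After header before giving up:

vectors = generate_azure_openai_batch_embeddings(
    deployment="text-embedding-3-small",          # placeholder deployment name
    texts=["first chunk", "second chunk"],
    url="https://my-resource.openai.azure.com",   # placeholder resource endpoint, no trailing slash
    key="<azure-api-key>",
    model="text-embedding-3-small",               # sent as "model" in the request body
    version="2023-05-15",                         # placeholder api-version
)
# vectors is a list with one embedding per input text, or None on failure.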
@@ -794,6 +854,32 @@ def generate_embeddings(
            model, [text], url, key, prefix, user
        )
        return embeddings[0] if isinstance(text, str) else embeddings
    elif engine == "azure_openai":
        deployment = kwargs.get("deployment", "")
        version = kwargs.get("version", "")

        if isinstance(text, list):
            embeddings = generate_azure_openai_batch_embeddings(
                deployment,
                text,
                url,
                key,
                model,
                version,
                prefix,
                user,
            )
        else:
            embeddings = generate_azure_openai_batch_embeddings(
                deployment,
                [text],
                url,
                key,
                model,
                version,
                prefix,
                user,
            )
        return embeddings[0] if isinstance(text, str) else embeddings
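Finally, a sketch of the dispatch path through generate_embeddings for the new engine. The deployment and version values arrive via **kwargs, a list input returns one vector per item, and a plain string returns a single vector; all concrete values below are placeholders:

batch = generate_embeddings(
    engine="azure_openai",
    model="text-embedding-3-small",               # placeholder
    text=["chunk one", "chunk two"],              # list in -> list of vectors out
    prefix=None,
    url="https://my-resource.openai.azure.com",   # placeholder
    key="<azure-api-key>",
    user=None,
    deployment="text-embedding-3-small",          # forwarded via **kwargs
    version="2023-05-15",                         # forwarded via **kwargs
)
# Passing text="a single query" instead would return one vector.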
import operator