refac: PLEASE follow existing convention

This commit is contained in:
Timothy Jaeryang Baek
2025-05-30 00:34:18 +04:00
parent ff353578db
commit e1e2c096e2
6 changed files with 119 additions and 157 deletions

View File

@@ -401,8 +401,7 @@ def get_embedding_function(
url,
key,
embedding_batch_size,
deployment=None,
version=None,
azure_api_version=None,
):
if embedding_engine == "":
return lambda query, prefix=None, user=None: embedding_function.encode(
@@ -417,8 +416,7 @@ def get_embedding_function(
url=url,
key=key,
user=user,
deployment=deployment,
version=version,
azure_api_version=azure_api_version,
)
def generate_multiple(query, prefix, user, func):
@@ -703,24 +701,23 @@ def generate_openai_batch_embeddings(
def generate_azure_openai_batch_embeddings(
deployment: str,
model: str,
texts: list[str],
url: str,
key: str = "",
model: str = "",
version: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_azure_openai_batch_embeddings:deployment {deployment} batch size: {len(texts)}"
f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
json_data = {"input": texts}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
url = f"{url}/openai/deployments/{deployment}/embeddings?api-version={version}"
url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
for _ in range(5):
r = requests.post(
@@ -855,27 +852,26 @@ def generate_embeddings(
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "azure_openai":
deployment = kwargs.get("deployment", "")
version = kwargs.get("version", "")
azure_api_version = kwargs.get("azure_api_version", "")
if isinstance(text, list):
embeddings = generate_azure_openai_batch_embeddings(
deployment,
model,
text,
url,
key,
model,
version,
azure_api_version,
prefix,
user,
)
else:
embeddings = generate_azure_openai_batch_embeddings(
deployment,
model,
[text],
url,
key,
model,
version,
azure_api_version,
prefix,
user,
)