Merge pull request #4295 from michaelpoluektov/refactor-tools

refactor: Refactor OpenAI API to use helper functions, silence LSP/linter warnings
This commit is contained in:
Timothy Jaeryang Baek
2024-08-04 14:17:52 +02:00
committed by GitHub
8 changed files with 158 additions and 238 deletions

View File

@@ -6,6 +6,8 @@ from typing import Optional, List, Tuple
import uuid
import time
from utils.task import prompt_template
def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
for message in reversed(messages):
@@ -112,6 +114,47 @@ def openai_chat_completion_message_template(model: str, message: str) -> dict:
return template
# inplace function: form_data is modified
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
system = params.get("system", None)
if not system:
return form_data
if user:
template_params = {
"user_name": user.name,
"user_location": user.info.get("location") if user.info else None,
}
else:
template_params = {}
system = prompt_template(system, **template_params)
form_data["messages"] = add_or_update_system_message(
system, form_data.get("messages", [])
)
return form_data
# inplace function: form_data is modified
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
if not params:
return form_data
mappings = {
"temperature": float,
"top_p": int,
"max_tokens": int,
"frequency_penalty": int,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
return form_data
def get_gravatar_url(email):
# Trim leading and trailing whitespace from
# an email address and force all characters