This commit is contained in:
Timothy Jaeryang Baek
2024-12-11 19:52:46 -08:00
parent 772f5ccd60
commit fe5519e0a2
5 changed files with 236 additions and 212 deletions

View File

@@ -16,6 +16,22 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def get_task_model_id(
default_model_id: str, task_model: str, task_model_external: str, models
) -> str:
# Set the task model
task_model_id = default_model_id
# Check if the user has a custom task model and use that model
if models[task_model_id]["owned_by"] == "ollama":
if task_model and task_model in models:
task_model_id = task_model
else:
if task_model_external and task_model_external in models:
task_model_id = task_model_external
return task_model_id
def prompt_template(
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
) -> str: