feat: preset backend logic

Timothy J. Baek
2024-05-25 02:05:05 -07:00
parent 7d2ab168f1
commit 88d053833d
2 changed files with 206 additions and 56 deletions


@@ -875,15 +875,88 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    model_id = get_model_id_from_custom_model_id(form_data.model)
-    model = model_id
+    log.debug(
+        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
+            form_data.model_dump_json(exclude_none=True).encode()
+        )
+    )
+
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }
+
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["options"] = {}
+            payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
+            payload["options"]["mirostat_eta"] = model_info.params.get(
+                "mirostat_eta", None
+            )
+            payload["options"]["mirostat_tau"] = model_info.params.get(
+                "mirostat_tau", None
+            )
+            payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
+            payload["options"]["repeat_last_n"] = model_info.params.get(
+                "repeat_last_n", None
+            )
+            payload["options"]["repeat_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+            payload["options"]["temperature"] = model_info.params.get(
+                "temperature", None
+            )
+            payload["options"]["seed"] = model_info.params.get("seed", None)
+
+            # TODO: add "stop" back in
+            # payload["stop"] = model_info.params.get("stop", None)
+
+            payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
+            payload["options"]["num_predict"] = model_info.params.get(
+                "max_tokens", None
+            )
+            payload["options"]["top_k"] = model_info.params.get("top_k", None)
+            payload["options"]["top_p"] = model_info.params.get("top_p", None)
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
 
     if url_idx == None:
-        if ":" not in model:
-            model = f"{model}:latest"
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
 
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
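Note: the hunk above copies each preset field into payload["options"] one assignment at a time, and Ollama's repeat_penalty is filled from the OpenAI-style "frequency_penalty" preset key, just as num_predict is filled from "max_tokens". A minimal table-driven sketch of the same mapping (PARAM_MAP and apply_preset_options are illustrative names, not part of this commit):

# Hypothetical helper mirroring the per-key assignments in the diff above.
PARAM_MAP = {
    "mirostat": "mirostat",
    "mirostat_eta": "mirostat_eta",
    "mirostat_tau": "mirostat_tau",
    "num_ctx": "num_ctx",
    "repeat_last_n": "repeat_last_n",
    "frequency_penalty": "repeat_penalty",  # OpenAI-style name -> Ollama option
    "temperature": "temperature",
    "seed": "seed",
    "tfs_z": "tfs_z",
    "max_tokens": "num_predict",  # OpenAI-style name -> Ollama option
    "top_k": "top_k",
    "top_p": "top_p",
}

def apply_preset_options(payload: dict, params: dict) -> dict:
    # Like the diff, unset preset keys still land in options as None,
    # because .get(key, None) is used for every field.
    payload["options"] = {
        ollama_key: params.get(preset_key, None)
        for preset_key, ollama_key in PARAM_MAP.items()
    }
    return payload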
@@ -893,23 +966,12 @@ async def generate_chat_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
+    print(payload)
 
     r = None
 
-    # payload = {
-    #     **form_data.model_dump_json(exclude_none=True).encode(),
-    #     "model": model,
-    #     "messages": form_data.messages,
-    # }
-
-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
-
     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r
 
         request_id = str(uuid.uuid4())
@@ -918,7 +980,7 @@ async def generate_chat_completion(
             def stream_content():
                 try:
-                    if form_data.stream:
+                    if payload.get("stream", None):
                         yield json.dumps({"id": request_id, "done": False}) + "\n"
 
                     for chunk in r.iter_content(chunk_size=8192):
@@ -936,7 +998,7 @@ async def generate_chat_completion(
             r = requests.request(
                 method="POST",
                 url=f"{url}/api/chat",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )
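Both endpoints in this commit share the system-prompt logic: a for/else scans the messages, prepends the preset system prompt to an existing system message, and otherwise (loop ends without break) inserts a fresh one at index 0. A standalone sketch of that behavior (prepend_system is a hypothetical name, not from the commit):

# Hypothetical helper mirroring the for/else in the diff above.
def prepend_system(messages: list, system: str) -> list:
    for message in messages:
        if message.get("role") == "system":
            # Existing system message: prepend the preset prompt to it.
            message["content"] = system + message["content"]
            break
    else:
        # No system message found: insert one at the front.
        messages.insert(0, {"role": "system", "content": system})
    return messages

# Example:
# prepend_system([{"role": "user", "content": "hi"}], "Be concise. ")
# -> [{"role": "system", "content": "Be concise. "},
#     {"role": "user", "content": "hi"}]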
@@ -992,14 +1054,56 @@ async def generate_openai_chat_completion(
     user=Depends(get_verified_user),
 ):
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }
+
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["temperature"] = model_info.params.get("temperature", None)
+            payload["top_p"] = model_info.params.get("top_p", None)
+            payload["max_tokens"] = model_info.params.get("max_tokens", None)
+            payload["frequency_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+            payload["seed"] = model_info.params.get("seed", None)
+
+            # TODO: add "stop" back in
+            # payload["stop"] = model_info.params.get("stop", None)
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
+
     if url_idx == None:
-        model = form_data.model
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
 
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
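Note the asymmetry between the two code paths: the Ollama endpoint nests sampling parameters under payload["options"] (with renames such as max_tokens -> num_predict), while the OpenAI endpoint sets them at the top level of the payload. Illustrative shapes for the same preset (model names and values below are made up, not from the commit):

# Shapes only; every concrete value here is invented for illustration.
ollama_payload = {
    "model": "llama3:latest",
    "messages": [{"role": "user", "content": "hi"}],
    "options": {"temperature": 0.2, "num_predict": 128},
}

openai_payload = {
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "hi"}],
    "temperature": 0.2,
    "max_tokens": 128,
}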
@@ -1012,7 +1116,7 @@ async def generate_openai_chat_completion(
     r = None
 
     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r
 
         request_id = str(uuid.uuid4())
@@ -1021,7 +1125,7 @@ async def generate_openai_chat_completion(
             def stream_content():
                 try:
-                    if form_data.stream:
+                    if payload.get("stream"):
                         yield json.dumps(
                             {"request_id": request_id, "done": False}
                         ) + "\n"
@@ -1041,7 +1145,7 @@ async def generate_openai_chat_completion(
             r = requests.request(
                 method="POST",
                 url=f"{url}/v1/chat/completions",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )
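Finally, both endpoints resolve the backend the same way once the payload is built: an untagged model name gets ":latest" appended, then a URL index is chosen at random from the model's registered urls. A condensed sketch, assuming app.state.MODELS maps "name:tag" to {"urls": [...]} as the diff implies (resolve_url_idx is a hypothetical name):

import random

def resolve_url_idx(models: dict, model_name: str) -> int:
    # Untagged Ollama model names default to the ":latest" tag.
    if ":" not in model_name:
        model_name = f"{model_name}:latest"
    if model_name not in models:
        raise ValueError(f"model not found: {model_name}")
    # Load-balance across the base URLs that serve this model.
    return random.choice(models[model_name]["urls"])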