refac: Refactor functions

This commit is contained in:
Michael Poluektov
2024-07-31 13:35:02 +01:00
parent 9d58bb1c66
commit 3978efd710
4 changed files with 215 additions and 292 deletions

View File

@@ -1,9 +1,6 @@
from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from sqlalchemy.orm import Session
from apps.webui.routers import (
auths,
users,
@@ -27,7 +24,6 @@ from utils.task import prompt_template
from config import (
WEBUI_BUILD_HASH,
SHOW_ADMIN_DETAILS,
ADMIN_EMAIL,
WEBUI_AUTH,
@@ -55,7 +51,7 @@ import uuid
import time
import json
from typing import Iterator, Generator, AsyncGenerator, Optional
from typing import Iterator, Generator, AsyncGenerator
from pydantic import BaseModel
app = FastAPI()
@@ -127,60 +123,60 @@ async def get_status():
}
def get_function_module(pipe_id: str):
# Check if function is already loaded
if pipe_id not in app.state.FUNCTIONS:
function_module, _, _ = load_function_module_by_id(pipe_id)
app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
return function_module
async def get_pipe_models():
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = []
for pipe in pipes:
# Check if function is already loaded
if pipe.id not in app.state.FUNCTIONS:
function_module, function_type, frontmatter = load_function_module_by_id(
pipe.id
)
app.state.FUNCTIONS[pipe.id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe.id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe.id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
function_module = get_function_module(pipe.id)
# Check if function is a manifold
if hasattr(function_module, "type"):
if function_module.type == "manifold":
manifold_pipes = []
if not function_module.type == "manifold":
continue
manifold_pipes = []
# Check if pipes is a function or a list
if callable(function_module.pipes):
manifold_pipes = function_module.pipes()
else:
manifold_pipes = function_module.pipes
# Check if pipes is a function or a list
if callable(function_module.pipes):
manifold_pipes = function_module.pipes()
else:
manifold_pipes = function_module.pipes
for p in manifold_pipes:
manifold_pipe_id = f'{pipe.id}.{p["id"]}'
manifold_pipe_name = p["name"]
for p in manifold_pipes:
manifold_pipe_id = f'{pipe.id}.{p["id"]}'
manifold_pipe_name = p["name"]
if hasattr(function_module, "name"):
manifold_pipe_name = (
f"{function_module.name}{manifold_pipe_name}"
)
if hasattr(function_module, "name"):
manifold_pipe_name = f"{function_module.name}{manifold_pipe_name}"
pipe_flag = {"type": pipe.type}
if hasattr(function_module, "ChatValves"):
pipe_flag["valves_spec"] = function_module.ChatValves.schema()
pipe_flag = {"type": pipe.type}
if hasattr(function_module, "ChatValves"):
pipe_flag["valves_spec"] = function_module.ChatValves.schema()
pipe_models.append(
{
"id": manifold_pipe_id,
"name": manifold_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
pipe_models.append(
{
"id": manifold_pipe_id,
"name": manifold_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
else:
pipe_flag = {"type": "pipe"}
if hasattr(function_module, "ChatValves"):
@@ -200,162 +196,179 @@ async def get_pipe_models():
return pipe_models
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
async def get_message(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
return res
if isinstance(res, Generator):
return "".join(map(str, res))
if isinstance(res, AsyncGenerator):
return "".join([str(stream) async for stream in res])
def get_final_message(form_data: dict, message: str | None = None) -> dict:
choice = {
"index": 0,
"logprobs": None,
"finish_reason": "stop",
}
# If message is None, we're dealing with a chunk
if not message:
choice["delta"] = {}
else:
choice["message"] = {"role": "assistant", "content": message}
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"created": int(time.time()),
"model": form_data["model"],
"object": "chat.completion" if message is not None else "chat.completion.chunk",
"choices": [choice],
}
def process_line(form_data: dict, line):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except Exception:
pass
if line.startswith("data:"):
return f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
return f"data: {json.dumps(line)}\n\n"
def get_pipe_id(form_data: dict) -> str:
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, _ = pipe_id.split(".", 1)
print(pipe_id)
return pipe_id
def get_params_dict(pipe, form_data, user, extra_params, function_module):
pipe_id = get_pipe_id(form_data)
# Get the signature of the function
sig = inspect.signature(pipe)
params = {"body": form_data}
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
)
except Exception as e:
print(e)
params["__user__"] = __user__
return params
async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = None
if "metadata" in form_data:
metadata = form_data["metadata"]
del form_data["metadata"]
metadata = form_data.pop("metadata", None)
__event_emitter__ = None
__event_call__ = None
__task__ = None
__event_emitter__ = __event_call__ = __task__ = None
if metadata:
if (
metadata.get("session_id")
and metadata.get("chat_id")
and metadata.get("message_id")
):
__event_emitter__ = await get_event_emitter(metadata)
__event_call__ = await get_event_call(metadata)
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
__event_emitter__ = get_event_emitter(metadata)
__event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None)
if metadata.get("task"):
__task__ = metadata.get("task")
if not model_info:
return
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
params = model_info.params.model_dump()
if model_info.params:
if model_info.params.get("temperature", None) is not None:
form_data["temperature"] = float(model_info.params.get("temperature"))
if params:
mappings = {
"temperature": float,
"top_p": int,
"max_tokens": int,
"frequency_penalty": int,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
if model_info.params.get("top_p", None):
form_data["top_p"] = int(model_info.params.get("top_p", None))
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
if model_info.params.get("max_tokens", None):
form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
if model_info.params.get("frequency_penalty", None):
form_data["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if model_info.params.get("seed", None):
form_data["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
form_data["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if form_data.get("messages"):
for message in form_data["messages"]:
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
form_data["messages"].insert(
0,
{
"role": "system",
"content": system,
},
)
system = params.get("system", None)
if not system:
return
if user:
template_params = {
"user_name": user.name,
"user_location": user.info.get("location") if user.info else None,
}
else:
pass
template_params = {}
system = prompt_template(system, **template_params)
# Check if the payload already has a system message
# If not, add a system message to the payload
for message in form_data.get("messages", []):
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
if form_data.get("messages"):
form_data["messages"].insert(0, {"role": "system", "content": system})
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
}
async def job():
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id)
# Check if function is already loaded
if pipe_id not in app.state.FUNCTIONS:
function_module, function_type, frontmatter = load_function_module_by_id(
pipe_id
)
app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
pipe_id = get_pipe_id(form_data)
function_module = get_function_module(pipe_id)
pipe = function_module.pipe
# Get the signature of the function
sig = inspect.signature(pipe)
params = {"body": form_data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__event_emitter__" in sig.parameters:
params = {**params, "__event_emitter__": __event_emitter__}
if "__event_call__" in sig.parameters:
params = {**params, "__event_call__": __event_call__}
if "__task__" in sig.parameters:
params = {**params, "__task__": __task__}
params = get_params_dict(pipe, form_data, user, extra_params, function_module)
if form_data["stream"]:
async def stream_content():
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
res = await execute_pipe(pipe, params)
# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
@@ -377,107 +390,32 @@ async def generate_function_chat_completion(form_data, user):
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
yield process_line(form_data, line)
if isinstance(res, AsyncGenerator):
async for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
yield process_line(form_data, line)
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = get_final_message(form_data)
yield f"data: {json.dumps(finish_message)}\n\n"
yield "data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
res = await execute_pipe(pipe, params)
if isinstance(res, StreamingResponse):
return res
except Exception as e:
print(f"Error: {e}")
return {"error": {"detail": str(e)}}
if isinstance(res, dict):
if isinstance(res, StreamingResponse) or isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
if isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
if isinstance(res, str):
message = res
elif isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
elif isinstance(res, AsyncGenerator):
async for stream in res:
message = f"{message}{stream}"
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
message = await get_message(res)
return get_final_message(form_data, message)
return await job()