feat: action function

This commit is contained in:
Timothy J. Baek
2024-07-11 18:41:00 -07:00
parent 90c3d68f00
commit eb10001eb7
7 changed files with 211 additions and 4 deletions

View File

@@ -926,6 +926,7 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models():
# TODO: Optimize this function
pipe_models = []
openai_models = []
ollama_models = []
@@ -952,6 +953,14 @@ async def get_all_models():
models = pipe_models + openai_models + ollama_models
global_action_ids = [
function.id for function in Functions.get_global_action_functions()
]
enabled_action_ids = [
function.id
for function in Functions.get_functions_by_type("action", active_only=True)
]
custom_models = Models.get_all_models()
for custom_model in custom_models:
if custom_model.base_model_id == None:
@@ -962,9 +971,32 @@ async def get_all_models():
):
model["name"] = custom_model.name
model["info"] = custom_model.model_dump()
action_ids = [] + global_action_ids
if "info" in model and "meta" in model["info"]:
action_ids.extend(model["info"]["meta"].get("actionIds", []))
action_ids = list(set(action_ids))
action_ids = [
action_id
for action_id in action_ids
if action_id in enabled_action_ids
]
model["actions"] = [
{
"id": action_id,
"name": Functions.get_function_by_id(action_id).name,
"description": Functions.get_function_by_id(
action_id
).meta.description,
}
for action_id in action_ids
]
else:
owned_by = "openai"
pipe = None
actions = []
for model in models:
if (
@@ -974,6 +1006,27 @@ async def get_all_models():
owned_by = model["owned_by"]
if "pipe" in model:
pipe = model["pipe"]
action_ids = [] + global_action_ids
if "info" in model and "meta" in model["info"]:
action_ids.extend(model["info"]["meta"].get("actionIds", []))
action_ids = list(set(action_ids))
action_ids = [
action_id
for action_id in action_ids
if action_id in enabled_action_ids
]
actions = [
{
"id": action_id,
"name": Functions.get_function_by_id(action_id).name,
"description": Functions.get_function_by_id(
action_id
).meta.description,
}
for action_id in action_ids
]
break
models.append(
@@ -986,6 +1039,7 @@ async def get_all_models():
"info": custom_model.model_dump(),
"preset": True,
**({"pipe": pipe} if pipe is not None else {}),
"actions": actions,
}
)
@@ -1221,6 +1275,107 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
return data
@app.post("/api/chat/actions/{action_id}")
async def chat_completed(
action_id: str, form_data: dict, user=Depends(get_verified_user)
):
action = Functions.get_function_by_id(action_id)
if not action:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Action not found",
)
data = form_data
model_id = data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
__event_emitter__ = await get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
__event_call__ = await get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
if action_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[action_id]
else:
function_module, _, _ = load_function_module_by_id(action_id)
webui_app.state.FUNCTIONS[action_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
if hasattr(function_module, "action"):
try:
action = function_module.action
# Get the signature of the function
sig = inspect.signature(action)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": action_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
action_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(action):
data = await action(**params)
else:
data = action(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data
##################################
#
# Task Endpoints