enh: custom reasoning tags

This commit is contained in:
Timothy Jaeryang Baek
2025-08-27 17:24:16 +04:00
parent 688bfac4ba
commit e39ce16a86
4 changed files with 109 additions and 29 deletions

View File

@@ -111,6 +111,20 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
DEFAULT_REASONING_TAGS = [
("<think>", "</think>"),
("<thinking>", "</thinking>"),
("<reason>", "</reason>"),
("<reasoning>", "</reasoning>"),
("<thought>", "</thought>"),
("<Thought>", "</Thought>"),
("<|begin_of_thought|>", "<|end_of_thought|>"),
("◁think▷", "◁/think▷"),
]
DEFAULT_SOLUTION_TAGS = [("<|begin_of_solution|>", "<|end_of_solution|>")]
DEFAULT_CODE_INTERPRETER_TAGS = [("<code_interpreter>", "</code_interpreter>")]
async def chat_completion_tools_handler(
request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
@@ -694,6 +708,7 @@ def apply_params_to_form_data(form_data, model):
"stream_response": bool,
"stream_delta_chunk_size": int,
"function_calling": str,
"reasoning_tags": list,
"system": str,
}
@@ -1811,27 +1826,23 @@ async def process_chat_response(
}
]
# We might want to disable this by default
DETECT_REASONING = True
DETECT_SOLUTION = True
reasoning_tags_param = metadata.get("params", {}).get("reasoning_tags")
DETECT_REASONING_TAGS = reasoning_tags_param is not False
DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
"code_interpreter", False
)
reasoning_tags = [
("<think>", "</think>"),
("<thinking>", "</thinking>"),
("<reason>", "</reason>"),
("<reasoning>", "</reasoning>"),
("<thought>", "</thought>"),
("<Thought>", "</Thought>"),
("<|begin_of_thought|>", "<|end_of_thought|>"),
("◁think▷", "◁/think▷"),
]
code_interpreter_tags = [("<code_interpreter>", "</code_interpreter>")]
solution_tags = [("<|begin_of_solution|>", "<|end_of_solution|>")]
reasoning_tags = []
if DETECT_REASONING_TAGS:
if (
isinstance(reasoning_tags_param, list)
and len(reasoning_tags_param) == 2
):
reasoning_tags = [
(reasoning_tags_param[0], reasoning_tags_param[1])
]
else:
reasoning_tags = DEFAULT_REASONING_TAGS
try:
for event in events:
@@ -2083,7 +2094,7 @@ async def process_chat_response(
content_blocks[-1]["content"] + value
)
if DETECT_REASONING:
if DETECT_REASONING_TAGS:
content, content_blocks, _ = (
tag_content_handler(
"reasoning",
@@ -2093,11 +2104,20 @@ async def process_chat_response(
)
)
content, content_blocks, _ = (
tag_content_handler(
"solution",
DEFAULT_SOLUTION_TAGS,
content,
content_blocks,
)
)
if DETECT_CODE_INTERPRETER:
content, content_blocks, end = (
tag_content_handler(
"code_interpreter",
code_interpreter_tags,
DEFAULT_CODE_INTERPRETER_TAGS,
content,
content_blocks,
)
@@ -2106,16 +2126,6 @@ async def process_chat_response(
if end:
break
if DETECT_SOLUTION:
content, content_blocks, _ = (
tag_content_handler(
"solution",
solution_tags,
content,
content_blocks,
)
)
if ENABLE_REALTIME_CHAT_SAVE:
# Save message in the database
Chats.upsert_message_to_chat_by_id_and_message_id(