Merge remote-tracking branch 'origin' into logit_bias

Author: dannyl1u
Date: 2025-02-27 23:48:22 -08:00
181 changed files with 10428 additions and 5218 deletions


@@ -322,78 +322,95 @@ async def chat_web_search_handler(
         )
         return form_data

-    searchQuery = queries[0]
-
-    await event_emitter(
-        {
-            "type": "status",
-            "data": {
-                "action": "web_search",
-                "description": 'Searching "{{searchQuery}}"',
-                "query": searchQuery,
-                "done": False,
-            },
-        }
-    )
-
-    try:
-        results = await process_web_search(
-            request,
-            SearchForm(
-                **{
-                    "query": searchQuery,
-                }
-            ),
-            user,
-        )
-
-        if results:
-            await event_emitter(
-                {
-                    "type": "status",
-                    "data": {
-                        "action": "web_search",
-                        "description": "Searched {{count}} sites",
-                        "query": searchQuery,
-                        "urls": results["filenames"],
-                        "done": True,
-                    },
-                }
-            )
-
-            files = form_data.get("files", [])
-            files.append(
-                {
-                    "collection_name": results["collection_name"],
-                    "name": searchQuery,
-                    "type": "web_search_results",
-                    "urls": results["filenames"],
-                }
-            )
-            form_data["files"] = files
-        else:
-            await event_emitter(
-                {
-                    "type": "status",
-                    "data": {
-                        "action": "web_search",
-                        "description": "No search results found",
-                        "query": searchQuery,
-                        "done": True,
-                        "error": True,
-                    },
-                }
-            )
-    except Exception as e:
-        log.exception(e)
-        await event_emitter(
-            {
-                "type": "status",
-                "data": {
-                    "action": "web_search",
-                    "description": 'Error searching "{{searchQuery}}"',
-                    "query": searchQuery,
-                    "done": True,
-                    "error": True,
-                },
-            }
-        )
+    all_results = []
+
+    for searchQuery in queries:
+        await event_emitter(
+            {
+                "type": "status",
+                "data": {
+                    "action": "web_search",
+                    "description": 'Searching "{{searchQuery}}"',
+                    "query": searchQuery,
+                    "done": False,
+                },
+            }
+        )
+
+        try:
+            results = await process_web_search(
+                request,
+                SearchForm(
+                    **{
+                        "query": searchQuery,
+                    }
+                ),
+                user=user,
+            )
+
+            if results:
+                all_results.append(results)
+                files = form_data.get("files", [])
+
+                if results.get("collection_name"):
+                    files.append(
+                        {
+                            "collection_name": results["collection_name"],
+                            "name": searchQuery,
+                            "type": "web_search",
+                            "urls": results["filenames"],
+                        }
+                    )
+                elif results.get("docs"):
+                    files.append(
+                        {
+                            "docs": results.get("docs", []),
+                            "name": searchQuery,
+                            "type": "web_search",
+                            "urls": results["filenames"],
+                        }
+                    )
+
+                form_data["files"] = files
+        except Exception as e:
+            log.exception(e)
+            await event_emitter(
+                {
+                    "type": "status",
+                    "data": {
+                        "action": "web_search",
+                        "description": 'Error searching "{{searchQuery}}"',
+                        "query": searchQuery,
+                        "done": True,
+                        "error": True,
+                    },
+                }
+            )
+
+    if all_results:
+        urls = []
+        for results in all_results:
+            if "filenames" in results:
+                urls.extend(results["filenames"])
+
+        await event_emitter(
+            {
+                "type": "status",
+                "data": {
+                    "action": "web_search",
+                    "description": "Searched {{count}} sites",
+                    "urls": urls,
+                    "done": True,
+                },
+            }
+        )
+    else:
+        await event_emitter(
+            {
+                "type": "status",
+                "data": {
+                    "action": "web_search",
+                    "description": "No search results found",
+                    "done": True,
+                    "error": True,
+                },
+            }
+        )
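Note: the hunk above replaces the old single-query flow (searchQuery = queries[0]) with a loop over every generated query, collecting results in all_results and emitting one aggregate status at the end, while a per-query failure only logs and continues. A minimal standalone sketch of that pattern, with emit and search as hypothetical stand-ins for the project's event_emitter and process_web_search:

    from typing import Any, Awaitable, Callable

    async def search_all(
        queries: list[str],
        search: Callable[[str], Awaitable[dict[str, Any] | None]],
        emit: Callable[[dict[str, Any]], Awaitable[None]],
    ) -> list[dict[str, Any]]:
        """Run every query; tolerate per-query failures; emit one summary."""
        all_results: list[dict[str, Any]] = []
        for query in queries:
            await emit({"status": "searching", "query": query})
            try:
                result = await search(query)
                if result:
                    all_results.append(result)
            except Exception:
                # A failed query no longer aborts the remaining searches.
                await emit({"status": "error", "query": query})

        urls = [u for r in all_results for u in r.get("filenames", [])]
        if all_results:
            await emit({"status": "done", "count": len(urls), "urls": urls})
        else:
            await emit({"status": "no_results"})
        return all_results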
@@ -503,6 +520,7 @@ async def chat_completion_files_handler(
     sources = []

     if files := body.get("metadata", {}).get("files", None):
+        queries = []
         try:
             queries_response = await generate_queries(
                 request,
@@ -528,8 +546,8 @@ async def chat_completion_files_handler(
                 queries_response = {"queries": [queries_response]}

             queries = queries_response.get("queries", [])
-        except Exception as e:
-            queries = []
+        except:
+            pass

         if len(queries) == 0:
             queries = [get_last_user_message(body["messages"])]
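The two hunks above tighten chat_completion_files_handler: queries is now initialized before the try block, a failed generate_queries call degrades silently, and the pre-existing fallback to the last user message covers both the failure and the empty-result case. A hedged sketch of the resulting control flow, with generate_queries and get_last_user_message as placeholders for the real helpers:

    async def build_queries(body: dict, generate_queries, get_last_user_message) -> list[str]:
        queries = []  # initialized up front so a failure below cannot leave it unbound
        try:
            response = await generate_queries(body)
            queries = response.get("queries", [])
        except Exception:
            pass  # fall through to the last-user-message fallback

        if len(queries) == 0:
            queries = [get_last_user_message(body["messages"])]
        return queries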
@@ -541,6 +559,7 @@ async def chat_completion_files_handler(
             sources = await loop.run_in_executor(
                 executor,
                 lambda: get_sources_from_files(
+                    request=request,
                     files=files,
                     queries=queries,
                     embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
@@ -550,9 +569,9 @@ async def chat_completion_files_handler(
                     reranking_function=request.app.state.rf,
                     r=request.app.state.config.RELEVANCE_THRESHOLD,
                     hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+                    full_context=request.app.state.config.RAG_FULL_CONTEXT,
                 ),
             )
         except Exception as e:
             log.exception(e)
@@ -728,6 +747,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
     tool_ids = form_data.pop("tool_ids", None)
     files = form_data.pop("files", None)

+    # Remove files duplicates
+    if files:
+        files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
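The added deduplication keys each file dict by its canonical JSON serialization; sort_keys=True makes two dicts with the same entries in a different order collide on the same key, so only one copy survives. A self-contained demonstration:

    import json

    files = [
        {"id": "a", "name": "doc.pdf"},
        {"name": "doc.pdf", "id": "a"},  # same content, different key order
        {"id": "b", "name": "notes.txt"},
    ]

    # Dicts are unhashable, so serialize each to a canonical JSON string and
    # let dict-key uniqueness drop the duplicates (last occurrence wins).
    deduped = list({json.dumps(f, sort_keys=True): f for f in files}.values())
    assert len(deduped) == 2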
@@ -785,8 +805,6 @@ async def process_chat_payload(request, form_data, metadata, user, model):
     if len(sources) > 0:
         context_string = ""
         for source_idx, source in enumerate(sources):
-            source_id = source.get("source", {}).get("name", "")
-
             if "document" in source:
                 for doc_idx, doc_context in enumerate(source["document"]):
                     context_string += f"<source><source_id>{source_idx}</source_id><source_context>{doc_context}</source_context></source>\n"
@@ -806,7 +824,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
             # Workaround for Ollama 2.0+ system prompt issue
             # TODO: replace with add_or_update_system_message
-            if model["owned_by"] == "ollama":
+            if model.get("owned_by") == "ollama":
                 form_data["messages"] = prepend_to_first_user_message_content(
                     rag_template(
                         request.app.state.config.RAG_TEMPLATE, context_string, prompt
@@ -1038,6 +1056,21 @@ async def process_chat_response(
     ):
         return response

+    extra_params = {
+        "__event_emitter__": event_emitter,
+        "__event_call__": event_caller,
+        "__user__": {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        },
+        "__metadata__": metadata,
+        "__request__": request,
+        "__model__": metadata.get("model"),
+    }
+
+    filter_ids = get_sorted_filter_ids(form_data.get("model"))
+
     # Streaming response
     if event_emitter and event_caller:
         task_id = str(uuid4())  # Create a unique task ID.
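The double-underscore keys in extra_params are the extra keyword arguments handed to each stream filter through process_filter_functions, which this diff now calls on every streamed chunk. A sketch of a filter that consumes them; the stream hook name matches the filter_type passed below, but the injected-parameter convention shown here is an assumption:

    class Filter:
        def stream(self, event: dict, __user__: dict, __metadata__: dict) -> dict:
            # Called once per streamed chunk; __user__ and __metadata__ are
            # assumed to be injected by name from the extra_params dict above.
            if __user__.get("role") != "admin":
                for choice in event.get("choices", []):
                    delta = choice.get("delta", {})
                    if delta.get("content"):
                        delta["content"] = delta["content"].replace("secret", "*****")
            return event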
@@ -1117,12 +1150,12 @@ async def process_chat_response(
                 if reasoning_duration is not None:
                     if raw:
-                        content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
+                        content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
                     else:
                         content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
                 else:
                     if raw:
-                        content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
+                        content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
                     else:
                         content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
@@ -1218,9 +1251,9 @@ async def process_chat_response(
                 return attributes

             if content_blocks[-1]["type"] == "text":
-                for tag in tags:
+                for start_tag, end_tag in tags:
                     # Match start tag e.g., <tag> or <tag attr="value">
-                    start_tag_pattern = rf"<{tag}(\s.*?)?>"
+                    start_tag_pattern = rf"<{re.escape(start_tag)}(\s.*?)?>"
                     match = re.search(start_tag_pattern, content)
                     if match:
                         attr_content = (
@@ -1253,7 +1286,8 @@ async def process_chat_response(
                         content_blocks.append(
                             {
                                 "type": content_type,
-                                "tag": tag,
+                                "start_tag": start_tag,
+                                "end_tag": end_tag,
                                 "attributes": attributes,
                                 "content": "",
                                 "started_at": time.time(),
@@ -1265,9 +1299,10 @@ async def process_chat_response(
                     break
             elif content_blocks[-1]["type"] == content_type:
-                tag = content_blocks[-1]["tag"]
+                start_tag = content_blocks[-1]["start_tag"]
+                end_tag = content_blocks[-1]["end_tag"]

                 # Match end tag e.g., </tag>
-                end_tag_pattern = rf"</{tag}>"
+                end_tag_pattern = rf"<{re.escape(end_tag)}>"

                 # Check if the content has the end tag
                 if re.search(end_tag_pattern, content):
@@ -1275,7 +1310,7 @@ async def process_chat_response(
                     block_content = content_blocks[-1]["content"]
                     # Strip start and end tags from the content
-                    start_tag_pattern = rf"<{tag}(.*?)>"
+                    start_tag_pattern = rf"<{re.escape(start_tag)}(.*?)>"
                     block_content = re.sub(
                         start_tag_pattern, "", block_content
                     ).strip()
@@ -1340,7 +1375,7 @@ async def process_chat_response(
                     # Clean processed content
                     content = re.sub(
-                        rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
+                        rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>",
                         "",
                         content,
                         flags=re.DOTALL,
@@ -1353,7 +1388,22 @@ async def process_chat_response(
         )

         tool_calls = []

-        content = message.get("content", "") if message else ""
+        last_assistant_message = None
+        try:
+            if form_data["messages"][-1]["role"] == "assistant":
+                last_assistant_message = get_last_assistant_message(
+                    form_data["messages"]
+                )
+        except Exception as e:
+            pass
+
+        content = (
+            message.get("content", "")
+            if message
+            else last_assistant_message if last_assistant_message else ""
+        )
+
         content_blocks = [
             {
                 "type": "text",
@@ -1363,19 +1413,24 @@ async def process_chat_response(
         # We might want to disable this by default
         DETECT_REASONING = True
+        DETECT_SOLUTION = True
         DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
             "code_interpreter", False
         )

         reasoning_tags = [
-            "think",
-            "thinking",
-            "reason",
-            "reasoning",
-            "thought",
-            "Thought",
+            ("think", "/think"),
+            ("thinking", "/thinking"),
+            ("reason", "/reason"),
+            ("reasoning", "/reasoning"),
+            ("thought", "/thought"),
+            ("Thought", "/Thought"),
+            ("|begin_of_thought|", "|end_of_thought|"),
         ]
-        code_interpreter_tags = ["code_interpreter"]
+
+        code_interpreter_tags = [("code_interpreter", "/code_interpreter")]
+
+        solution_tags = [("|begin_of_solution|", "|end_of_solution|")]

         try:
             for event in events:
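Switching from bare tag strings to (start_tag, end_tag) pairs, combined with the re.escape changes in the hunks above, is what makes non-XML delimiters such as |begin_of_thought| safe: unescaped, the pipes would act as regex alternation and match the wrong spans. A quick illustration:

    import re

    start_tag, end_tag = "|begin_of_thought|", "|end_of_thought|"
    content = "<|begin_of_thought|>step 1<|end_of_thought|>answer"

    # Escaped patterns match the literal delimiters.
    assert re.search(rf"<{re.escape(start_tag)}(\s.*?)?>", content)
    assert re.search(rf"<{re.escape(end_tag)}>", content)

    # Stripping both tags leaves the block content plus the trailing text,
    # mirroring what tag_content_handler does with these patterns.
    inner = re.sub(rf"<{re.escape(start_tag)}(.*?)>", "", content)
    inner = re.sub(rf"<{re.escape(end_tag)}>", "", inner)
    assert inner == "step 1answer"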
@@ -1419,119 +1474,154 @@ async def process_chat_response(
                         try:
                             data = json.loads(data)

-                            if "selected_model_id" in data:
-                                model_id = data["selected_model_id"]
-                                Chats.upsert_message_to_chat_by_id_and_message_id(
-                                    metadata["chat_id"],
-                                    metadata["message_id"],
-                                    {
-                                        "selectedModelId": model_id,
-                                    },
-                                )
-                            else:
-                                choices = data.get("choices", [])
-                                if not choices:
-                                    continue
-
-                                delta = choices[0].get("delta", {})
-                                delta_tool_calls = delta.get("tool_calls", None)
-
-                                if delta_tool_calls:
-                                    for delta_tool_call in delta_tool_calls:
-                                        tool_call_index = delta_tool_call.get("index")
-
-                                        if tool_call_index is not None:
-                                            if (
-                                                len(response_tool_calls)
-                                                <= tool_call_index
-                                            ):
-                                                response_tool_calls.append(
-                                                    delta_tool_call
-                                                )
-                                            else:
-                                                delta_name = delta_tool_call.get(
-                                                    "function", {}
-                                                ).get("name")
-                                                delta_arguments = delta_tool_call.get(
-                                                    "function", {}
-                                                ).get("arguments")
-
-                                                if delta_name:
-                                                    response_tool_calls[
-                                                        tool_call_index
-                                                    ]["function"]["name"] += delta_name
-
-                                                if delta_arguments:
-                                                    response_tool_calls[
-                                                        tool_call_index
-                                                    ]["function"][
-                                                        "arguments"
-                                                    ] += delta_arguments
-
-                                value = delta.get("content")
-
-                                if value:
-                                    content = f"{content}{value}"
-
-                                    if not content_blocks:
-                                        content_blocks.append(
-                                            {
-                                                "type": "text",
-                                                "content": "",
-                                            }
-                                        )
-
-                                    content_blocks[-1]["content"] = (
-                                        content_blocks[-1]["content"] + value
-                                    )
-
-                                    if DETECT_REASONING:
-                                        content, content_blocks, _ = (
-                                            tag_content_handler(
-                                                "reasoning",
-                                                reasoning_tags,
-                                                content,
-                                                content_blocks,
-                                            )
-                                        )
-
-                                    if DETECT_CODE_INTERPRETER:
-                                        content, content_blocks, end = (
-                                            tag_content_handler(
-                                                "code_interpreter",
-                                                code_interpreter_tags,
-                                                content,
-                                                content_blocks,
-                                            )
-                                        )
-
-                                        if end:
-                                            break
-
-                                    if ENABLE_REALTIME_CHAT_SAVE:
-                                        # Save message in the database
-                                        Chats.upsert_message_to_chat_by_id_and_message_id(
-                                            metadata["chat_id"],
-                                            metadata["message_id"],
-                                            {
-                                                "content": serialize_content_blocks(
-                                                    content_blocks
-                                                ),
-                                            },
-                                        )
-                                    else:
-                                        data = {
-                                            "content": serialize_content_blocks(
-                                                content_blocks
-                                            ),
-                                        }
-
-                            await event_emitter(
-                                {
-                                    "type": "chat:completion",
-                                    "data": data,
-                                }
-                            )
+                            data, _ = await process_filter_functions(
+                                request=request,
+                                filter_ids=filter_ids,
+                                filter_type="stream",
+                                form_data=data,
+                                extra_params=extra_params,
+                            )
+
+                            if data:
+                                if "selected_model_id" in data:
+                                    model_id = data["selected_model_id"]
+                                    Chats.upsert_message_to_chat_by_id_and_message_id(
+                                        metadata["chat_id"],
+                                        metadata["message_id"],
+                                        {
+                                            "selectedModelId": model_id,
+                                        },
+                                    )
+                                else:
+                                    choices = data.get("choices", [])
+                                    if not choices:
+                                        usage = data.get("usage", {})
+                                        if usage:
+                                            await event_emitter(
+                                                {
+                                                    "type": "chat:completion",
+                                                    "data": {
+                                                        "usage": usage,
+                                                    },
+                                                }
+                                            )
+                                        continue
+
+                                    delta = choices[0].get("delta", {})
+                                    delta_tool_calls = delta.get("tool_calls", None)
+
+                                    if delta_tool_calls:
+                                        for delta_tool_call in delta_tool_calls:
+                                            tool_call_index = delta_tool_call.get(
+                                                "index"
+                                            )
+
+                                            if tool_call_index is not None:
+                                                if (
+                                                    len(response_tool_calls)
+                                                    <= tool_call_index
+                                                ):
+                                                    response_tool_calls.append(
+                                                        delta_tool_call
+                                                    )
+                                                else:
+                                                    delta_name = delta_tool_call.get(
+                                                        "function", {}
+                                                    ).get("name")
+                                                    delta_arguments = (
+                                                        delta_tool_call.get(
+                                                            "function", {}
+                                                        ).get("arguments")
+                                                    )
+
+                                                    if delta_name:
+                                                        response_tool_calls[
+                                                            tool_call_index
+                                                        ]["function"][
+                                                            "name"
+                                                        ] += delta_name
+
+                                                    if delta_arguments:
+                                                        response_tool_calls[
+                                                            tool_call_index
+                                                        ]["function"][
+                                                            "arguments"
+                                                        ] += delta_arguments
+
+                                    value = delta.get("content")
+
+                                    if value:
+                                        content = f"{content}{value}"
+
+                                        if not content_blocks:
+                                            content_blocks.append(
+                                                {
+                                                    "type": "text",
+                                                    "content": "",
+                                                }
+                                            )
+
+                                        content_blocks[-1]["content"] = (
+                                            content_blocks[-1]["content"] + value
+                                        )
+
+                                        if DETECT_REASONING:
+                                            content, content_blocks, _ = (
+                                                tag_content_handler(
+                                                    "reasoning",
+                                                    reasoning_tags,
+                                                    content,
+                                                    content_blocks,
+                                                )
+                                            )
+
+                                        if DETECT_CODE_INTERPRETER:
+                                            content, content_blocks, end = (
+                                                tag_content_handler(
+                                                    "code_interpreter",
+                                                    code_interpreter_tags,
+                                                    content,
+                                                    content_blocks,
+                                                )
+                                            )
+
+                                            if end:
+                                                break
+
+                                        if DETECT_SOLUTION:
+                                            content, content_blocks, _ = (
+                                                tag_content_handler(
+                                                    "solution",
+                                                    solution_tags,
+                                                    content,
+                                                    content_blocks,
+                                                )
+                                            )
+
+                                        if ENABLE_REALTIME_CHAT_SAVE:
+                                            # Save message in the database
+                                            Chats.upsert_message_to_chat_by_id_and_message_id(
+                                                metadata["chat_id"],
+                                                metadata["message_id"],
+                                                {
+                                                    "content": serialize_content_blocks(
+                                                        content_blocks
+                                                    ),
+                                                },
+                                            )
+                                        else:
+                                            data = {
+                                                "content": serialize_content_blocks(
+                                                    content_blocks
+                                                ),
+                                            }
+
+                                await event_emitter(
+                                    {
+                                        "type": "chat:completion",
+                                        "data": data,
+                                    }
+                                )
                         except Exception as e:
                             done = "data: [DONE]" in line
                             if done:
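A behavioral change worth flagging in the hunk above: chunks with an empty choices list are no longer skipped unconditionally. If such a chunk carries a usage payload, as OpenAI-compatible backends send in a final chunk when usage reporting is enabled via stream_options, the token counts are forwarded to the client before the continue. A reduced sketch of that branch:

    async def handle_chunk(data: dict, event_emitter) -> dict | None:
        """Return the first choice's delta, or None for metadata-only chunks."""
        choices = data.get("choices", [])
        if not choices:
            usage = data.get("usage", {})
            if usage:
                # Usage-only chunk: nothing to append, but surface token counts.
                await event_emitter(
                    {"type": "chat:completion", "data": {"usage": usage}}
                )
            return None
        return choices[0].get("delta", {})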
@@ -1736,6 +1826,7 @@ async def process_chat_response(
== "password"
else None
),
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
)
else:
output = {
@@ -1829,7 +1920,10 @@ async def process_chat_response(
                             }
                         )

-                    print(content_blocks, serialize_content_blocks(content_blocks))
+                    log.info(f"content_blocks={content_blocks}")
+                    log.info(
+                        f"serialize_content_blocks={serialize_content_blocks(content_blocks)}"
+                    )

                     try:
                         res = await generate_chat_completion(
@@ -1900,7 +1994,7 @@ async def process_chat_response(
                     await background_tasks_handler()
             except asyncio.CancelledError:
-                print("Task was cancelled!")
+                log.warning("Task was cancelled!")
                 await event_emitter({"type": "task-cancelled"})

                 if not ENABLE_REALTIME_CHAT_SAVE:
@@ -1921,17 +2015,34 @@ async def process_chat_response(
return {"status": True, "task_id": task_id}
else:
# Fallback to the original response
async def stream_wrapper(original_generator, events):
def wrap_item(item):
return f"data: {item}\n\n"
for event in events:
yield wrap_item(json.dumps(event))
event, _ = await process_filter_functions(
request=request,
filter_ids=filter_ids,
filter_type="stream",
form_data=event,
extra_params=extra_params,
)
if event:
yield wrap_item(json.dumps(event))
async for data in original_generator:
yield data
data, _ = await process_filter_functions(
request=request,
filter_ids=filter_ids,
filter_type="stream",
form_data=data,
extra_params=extra_params,
)
if data:
yield data
return StreamingResponse(
stream_wrapper(response.body_iterator, events),
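Net effect of this final hunk: the fallback path (no event_emitter/event_caller pair) now routes both the pre-queued events and every upstream chunk through the stream filters, and a filter can drop an item by returning a falsy value. A compact sketch of the same wrapper shape, with apply_filters standing in for the process_filter_functions call:

    import json

    async def stream_wrapper(original_generator, events, apply_filters):
        def wrap_item(item: str) -> str:
            return f"data: {item}\n\n"  # SSE framing

        for event in events:  # emit pre-queued events first
            event = await apply_filters(event)
            if event:  # a filter may suppress an item by returning None
                yield wrap_item(json.dumps(event))

        async for data in original_generator:
            data = await apply_filters(data)
            if data:
                yield data  # upstream chunks are already SSE-framed strings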