enh: typing indicator

This commit is contained in:
Timothy Jaeryang Baek
2024-12-26 21:51:09 -08:00
parent 4f93ecf519
commit 6ff6d57507
3 changed files with 137 additions and 46 deletions

View File

@@ -4,7 +4,7 @@ import logging
import sys
import time
from open_webui.models.users import Users
from open_webui.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels
from open_webui.models.chats import Chats
@@ -152,7 +152,7 @@ async def connect(sid, environ, auth):
user = Users.get_user_by_id(data["id"])
if user:
SESSION_POOL[sid] = user.id
SESSION_POOL[sid] = user.model_dump()
if user.id in USER_POOL:
USER_POOL[user.id] = USER_POOL[user.id] + [sid]
else:
@@ -178,7 +178,7 @@ async def user_join(sid, data):
if not user:
return
SESSION_POOL[sid] = user.id
SESSION_POOL[sid] = user.model_dump()
if user.id in USER_POOL:
USER_POOL[user.id] = USER_POOL[user.id] + [sid]
else:
@@ -217,22 +217,45 @@ async def join_channel(sid, data):
await sio.enter_room(sid, f"channel:{channel.id}")
@sio.on("channel-events")
async def channel_events(sid, data):
room = f"channel:{data['channel_id']}"
participants = sio.manager.get_participants(
namespace="/",
room=room,
)
sids = [sid for sid, _ in participants]
if sid not in sids:
return
event_data = data["data"]
event_type = event_data["type"]
if event_type == "typing":
await sio.emit(
"channel-events",
{
"channel_id": data["channel_id"],
"data": event_data,
"user": UserNameResponse(**SESSION_POOL[sid]).model_dump(),
},
room=room,
)
@sio.on("user-count")
async def user_count(sid):
await sio.emit("user-count", {"count": len(USER_POOL.items())})
@sio.on("chat")
async def chat(sid, data):
print("chat", sid, SESSION_POOL[sid], data)
@sio.event
async def disconnect(sid):
if sid in SESSION_POOL:
user_id = SESSION_POOL[sid]
user = SESSION_POOL[sid]
del SESSION_POOL[sid]
user_id = user["id"]
USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
if len(USER_POOL[user_id]) == 0:
@@ -289,7 +312,10 @@ def get_event_call(request_info):
def get_user_id_from_session_pool(sid):
return SESSION_POOL.get(sid)
user = SESSION_POOL.get(sid)
if user:
return user["id"]
return None
def get_user_ids_from_room(room):
@@ -299,6 +325,8 @@ def get_user_ids_from_room(room):
)
active_user_ids = list(
set([SESSION_POOL.get(session_id[0]) for session_id in active_session_ids])
set(
[SESSION_POOL.get(session_id[0])["id"] for session_id in active_session_ids]
)
)
return active_user_ids