From 4eadf84e1f6f045ae47723d4da5bc98884e4694e Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 1 Jan 2026 02:23:42 +0400 Subject: [PATCH] refac --- backend/open_webui/models/channels.py | 81 +++++++++++++++++++-------- 1 file changed, 58 insertions(+), 23 deletions(-) diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 55275688a4..dc13d362d8 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -1,4 +1,3 @@ - import json import time import uuid @@ -386,7 +385,9 @@ class ChannelTable: return query - def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]: + def get_channels_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> list[ChannelModel]: with get_db_context(db) as db: user_group_ids = [ group.id for group in Groups.get_groups_by_member_id(user_id, db=db) @@ -423,7 +424,9 @@ class ChannelTable: all_channels = membership_channels + standard_channels return [ChannelModel.model_validate(c) for c in all_channels] - def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]: + def get_dm_channel_by_user_ids( + self, user_ids: list[str], db: Optional[Session] = None + ) -> Optional[ChannelModel]: with get_db_context(db) as db: # Ensure uniqueness in case a list with duplicates is passed unique_user_ids = list(set(user_ids)) @@ -511,7 +514,9 @@ class ChannelTable: db.commit() return result # number of rows deleted - def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool: + def is_user_channel_manager( + self, channel_id: str, user_id: str, db: Optional[Session] = None + ) -> bool: with get_db_context(db) as db: # Check if the user is the creator of the channel # or has a 'manager' role in ChannelMember @@ -569,7 +574,9 @@ class ChannelTable: db.commit() return channel_member - def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool: + def leave_channel( + self, channel_id: str, user_id: str, db: Optional[Session] = None + ) -> bool: with get_db_context(db) as db: membership = ( db.query(ChannelMember) @@ -604,7 +611,9 @@ class ChannelTable: ) return ChannelMemberModel.model_validate(membership) if membership else None - def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]: + def get_members_by_channel_id( + self, channel_id: str, db: Optional[Session] = None + ) -> list[ChannelMemberModel]: with get_db_context(db) as db: memberships = ( db.query(ChannelMember) @@ -616,7 +625,13 @@ class ChannelTable: for membership in memberships ] - def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool, db: Optional[Session] = None) -> bool: + def pin_channel( + self, + channel_id: str, + user_id: str, + is_pinned: bool, + db: Optional[Session] = None, + ) -> bool: with get_db_context(db) as db: membership = ( db.query(ChannelMember) @@ -635,7 +650,9 @@ class ChannelTable: db.commit() return True - def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool: + def update_member_last_read_at( + self, channel_id: str, user_id: str, db: Optional[Session] = None + ) -> bool: with get_db_context(db) as db: membership = ( db.query(ChannelMember) @@ -655,7 +672,11 @@ class ChannelTable: return True def update_member_active_status( - self, channel_id: str, user_id: str, is_active: bool, db: Optional[Session] = None + self, + channel_id: str, + user_id: str, + is_active: bool, + db: Optional[Session] = None, ) -> bool: with get_db_context(db) as db: membership = ( @@ -675,7 +696,9 @@ class ChannelTable: db.commit() return True - def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool: + def is_user_channel_member( + self, channel_id: str, user_id: str, db: Optional[Session] = None + ) -> bool: with get_db_context(db) as db: membership = ( db.query(ChannelMember) @@ -687,7 +710,9 @@ class ChannelTable: ) return membership is not None - def get_channel_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChannelModel]: + def get_channel_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[ChannelModel]: try: with get_db_context(db) as db: channel = db.query(Channel).filter(Channel.id == id).first() @@ -695,7 +720,9 @@ class ChannelTable: except Exception: return None - def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]: + def get_channels_by_file_id( + self, file_id: str, db: Optional[Session] = None + ) -> list[ChannelModel]: with get_db_context(db) as db: channel_files = ( db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all() @@ -731,7 +758,9 @@ class ChannelTable: return [] # Preload user's group membership - user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)] + user_group_ids = [ + g.id for g in Groups.get_groups_by_member_id(user_id, db=db) + ] allowed_channels = [] @@ -825,9 +854,9 @@ class ChannelTable: ) def update_channel_by_id( - self, id: str, form_data: ChannelForm + self, id: str, form_data: ChannelForm, db: Optional[Session] = None ) -> Optional[ChannelModel]: - with get_db() as db: + with get_db_context(db) as db: channel = db.query(Channel).filter(Channel.id == id).first() if not channel: return None @@ -846,9 +875,9 @@ class ChannelTable: return ChannelModel.model_validate(channel) if channel else None def add_file_to_channel_by_id( - self, channel_id: str, file_id: str, user_id: str + self, channel_id: str, file_id: str, user_id: str, db: Optional[Session] = None ) -> Optional[ChannelFileModel]: - with get_db() as db: + with get_db_context(db) as db: channel_file = ChannelFileModel( **{ "id": str(uuid.uuid4()), @@ -873,10 +902,14 @@ class ChannelTable: return None def set_file_message_id_in_channel_by_id( - self, channel_id: str, file_id: str, message_id: str + self, + channel_id: str, + file_id: str, + message_id: str, + db: Optional[Session] = None, ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: channel_file = ( db.query(ChannelFile) .filter_by(channel_id=channel_id, file_id=file_id) @@ -893,9 +926,11 @@ class ChannelTable: except Exception: return False - def remove_file_from_channel_by_id(self, channel_id: str, file_id: str) -> bool: + def remove_file_from_channel_by_id( + self, channel_id: str, file_id: str, db: Optional[Session] = None + ) -> bool: try: - with get_db() as db: + with get_db_context(db) as db: db.query(ChannelFile).filter_by( channel_id=channel_id, file_id=file_id ).delete() @@ -904,8 +939,8 @@ class ChannelTable: except Exception: return False - def delete_channel_by_id(self, id: str): - with get_db() as db: + def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: db.query(Channel).filter(Channel.id == id).delete() db.commit() return True