This commit is contained in:
Timothy Jaeryang Baek
2026-01-01 02:23:42 +04:00
parent f7f8a263b9
commit 4eadf84e1f

View File

@@ -1,4 +1,3 @@
import json
import time
import uuid
@@ -386,7 +385,9 @@ class ChannelTable:
return query
def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
def get_channels_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> list[ChannelModel]:
with get_db_context(db) as db:
user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
@@ -423,7 +424,9 @@ class ChannelTable:
all_channels = membership_channels + standard_channels
return [ChannelModel.model_validate(c) for c in all_channels]
def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]:
def get_dm_channel_by_user_ids(
self, user_ids: list[str], db: Optional[Session] = None
) -> Optional[ChannelModel]:
with get_db_context(db) as db:
# Ensure uniqueness in case a list with duplicates is passed
unique_user_ids = list(set(user_ids))
@@ -511,7 +514,9 @@ class ChannelTable:
db.commit()
return result # number of rows deleted
def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
def is_user_channel_manager(
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db:
# Check if the user is the creator of the channel
# or has a 'manager' role in ChannelMember
@@ -569,7 +574,9 @@ class ChannelTable:
db.commit()
return channel_member
def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
def leave_channel(
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
@@ -604,7 +611,9 @@ class ChannelTable:
)
return ChannelMemberModel.model_validate(membership) if membership else None
def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]:
def get_members_by_channel_id(
self, channel_id: str, db: Optional[Session] = None
) -> list[ChannelMemberModel]:
with get_db_context(db) as db:
memberships = (
db.query(ChannelMember)
@@ -616,7 +625,13 @@ class ChannelTable:
for membership in memberships
]
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool, db: Optional[Session] = None) -> bool:
def pin_channel(
self,
channel_id: str,
user_id: str,
is_pinned: bool,
db: Optional[Session] = None,
) -> bool:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
@@ -635,7 +650,9 @@ class ChannelTable:
db.commit()
return True
def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
def update_member_last_read_at(
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
@@ -655,7 +672,11 @@ class ChannelTable:
return True
def update_member_active_status(
self, channel_id: str, user_id: str, is_active: bool, db: Optional[Session] = None
self,
channel_id: str,
user_id: str,
is_active: bool,
db: Optional[Session] = None,
) -> bool:
with get_db_context(db) as db:
membership = (
@@ -675,7 +696,9 @@ class ChannelTable:
db.commit()
return True
def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
def is_user_channel_member(
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
@@ -687,7 +710,9 @@ class ChannelTable:
)
return membership is not None
def get_channel_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChannelModel]:
def get_channel_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[ChannelModel]:
try:
with get_db_context(db) as db:
channel = db.query(Channel).filter(Channel.id == id).first()
@@ -695,7 +720,9 @@ class ChannelTable:
except Exception:
return None
def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
def get_channels_by_file_id(
self, file_id: str, db: Optional[Session] = None
) -> list[ChannelModel]:
with get_db_context(db) as db:
channel_files = (
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
@@ -731,7 +758,9 @@ class ChannelTable:
return []
# Preload user's group membership
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)]
user_group_ids = [
g.id for g in Groups.get_groups_by_member_id(user_id, db=db)
]
allowed_channels = []
@@ -825,9 +854,9 @@ class ChannelTable:
)
def update_channel_by_id(
self, id: str, form_data: ChannelForm
self, id: str, form_data: ChannelForm, db: Optional[Session] = None
) -> Optional[ChannelModel]:
with get_db() as db:
with get_db_context(db) as db:
channel = db.query(Channel).filter(Channel.id == id).first()
if not channel:
return None
@@ -846,9 +875,9 @@ class ChannelTable:
return ChannelModel.model_validate(channel) if channel else None
def add_file_to_channel_by_id(
self, channel_id: str, file_id: str, user_id: str
self, channel_id: str, file_id: str, user_id: str, db: Optional[Session] = None
) -> Optional[ChannelFileModel]:
with get_db() as db:
with get_db_context(db) as db:
channel_file = ChannelFileModel(
**{
"id": str(uuid.uuid4()),
@@ -873,10 +902,14 @@ class ChannelTable:
return None
def set_file_message_id_in_channel_by_id(
self, channel_id: str, file_id: str, message_id: str
self,
channel_id: str,
file_id: str,
message_id: str,
db: Optional[Session] = None,
) -> bool:
try:
with get_db() as db:
with get_db_context(db) as db:
channel_file = (
db.query(ChannelFile)
.filter_by(channel_id=channel_id, file_id=file_id)
@@ -893,9 +926,11 @@ class ChannelTable:
except Exception:
return False
def remove_file_from_channel_by_id(self, channel_id: str, file_id: str) -> bool:
def remove_file_from_channel_by_id(
self, channel_id: str, file_id: str, db: Optional[Session] = None
) -> bool:
try:
with get_db() as db:
with get_db_context(db) as db:
db.query(ChannelFile).filter_by(
channel_id=channel_id, file_id=file_id
).delete()
@@ -904,8 +939,8 @@ class ChannelTable:
except Exception:
return False
def delete_channel_by_id(self, id: str):
with get_db() as db:
def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
db.query(Channel).filter(Channel.id == id).delete()
db.commit()
return True