mirror of
https://github.com/open-webui/open-webui.git
synced 2026-02-24 12:11:56 +01:00
refac
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
@@ -386,7 +385,9 @@ class ChannelTable:
|
||||
|
||||
return query
|
||||
|
||||
def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||
def get_channels_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> list[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
user_group_ids = [
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
@@ -423,7 +424,9 @@ class ChannelTable:
|
||||
all_channels = membership_channels + standard_channels
|
||||
return [ChannelModel.model_validate(c) for c in all_channels]
|
||||
|
||||
def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]:
|
||||
def get_dm_channel_by_user_ids(
|
||||
self, user_ids: list[str], db: Optional[Session] = None
|
||||
) -> Optional[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
# Ensure uniqueness in case a list with duplicates is passed
|
||||
unique_user_ids = list(set(user_ids))
|
||||
@@ -511,7 +514,9 @@ class ChannelTable:
|
||||
db.commit()
|
||||
return result # number of rows deleted
|
||||
|
||||
def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
def is_user_channel_manager(
|
||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
# Check if the user is the creator of the channel
|
||||
# or has a 'manager' role in ChannelMember
|
||||
@@ -569,7 +574,9 @@ class ChannelTable:
|
||||
db.commit()
|
||||
return channel_member
|
||||
|
||||
def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
def leave_channel(
|
||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
@@ -604,7 +611,9 @@ class ChannelTable:
|
||||
)
|
||||
return ChannelMemberModel.model_validate(membership) if membership else None
|
||||
|
||||
def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]:
|
||||
def get_members_by_channel_id(
|
||||
self, channel_id: str, db: Optional[Session] = None
|
||||
) -> list[ChannelMemberModel]:
|
||||
with get_db_context(db) as db:
|
||||
memberships = (
|
||||
db.query(ChannelMember)
|
||||
@@ -616,7 +625,13 @@ class ChannelTable:
|
||||
for membership in memberships
|
||||
]
|
||||
|
||||
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool, db: Optional[Session] = None) -> bool:
|
||||
def pin_channel(
|
||||
self,
|
||||
channel_id: str,
|
||||
user_id: str,
|
||||
is_pinned: bool,
|
||||
db: Optional[Session] = None,
|
||||
) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
@@ -635,7 +650,9 @@ class ChannelTable:
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
def update_member_last_read_at(
|
||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
@@ -655,7 +672,11 @@ class ChannelTable:
|
||||
return True
|
||||
|
||||
def update_member_active_status(
|
||||
self, channel_id: str, user_id: str, is_active: bool, db: Optional[Session] = None
|
||||
self,
|
||||
channel_id: str,
|
||||
user_id: str,
|
||||
is_active: bool,
|
||||
db: Optional[Session] = None,
|
||||
) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
@@ -675,7 +696,9 @@ class ChannelTable:
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
def is_user_channel_member(
|
||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
@@ -687,7 +710,9 @@ class ChannelTable:
|
||||
)
|
||||
return membership is not None
|
||||
|
||||
def get_channel_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChannelModel]:
|
||||
def get_channel_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[ChannelModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
channel = db.query(Channel).filter(Channel.id == id).first()
|
||||
@@ -695,7 +720,9 @@ class ChannelTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||
def get_channels_by_file_id(
|
||||
self, file_id: str, db: Optional[Session] = None
|
||||
) -> list[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
channel_files = (
|
||||
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
||||
@@ -731,7 +758,9 @@ class ChannelTable:
|
||||
return []
|
||||
|
||||
# Preload user's group membership
|
||||
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)]
|
||||
user_group_ids = [
|
||||
g.id for g in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
]
|
||||
|
||||
allowed_channels = []
|
||||
|
||||
@@ -825,9 +854,9 @@ class ChannelTable:
|
||||
)
|
||||
|
||||
def update_channel_by_id(
|
||||
self, id: str, form_data: ChannelForm
|
||||
self, id: str, form_data: ChannelForm, db: Optional[Session] = None
|
||||
) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
channel = db.query(Channel).filter(Channel.id == id).first()
|
||||
if not channel:
|
||||
return None
|
||||
@@ -846,9 +875,9 @@ class ChannelTable:
|
||||
return ChannelModel.model_validate(channel) if channel else None
|
||||
|
||||
def add_file_to_channel_by_id(
|
||||
self, channel_id: str, file_id: str, user_id: str
|
||||
self, channel_id: str, file_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[ChannelFileModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
channel_file = ChannelFileModel(
|
||||
**{
|
||||
"id": str(uuid.uuid4()),
|
||||
@@ -873,10 +902,14 @@ class ChannelTable:
|
||||
return None
|
||||
|
||||
def set_file_message_id_in_channel_by_id(
|
||||
self, channel_id: str, file_id: str, message_id: str
|
||||
self,
|
||||
channel_id: str,
|
||||
file_id: str,
|
||||
message_id: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
channel_file = (
|
||||
db.query(ChannelFile)
|
||||
.filter_by(channel_id=channel_id, file_id=file_id)
|
||||
@@ -893,9 +926,11 @@ class ChannelTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def remove_file_from_channel_by_id(self, channel_id: str, file_id: str) -> bool:
|
||||
def remove_file_from_channel_by_id(
|
||||
self, channel_id: str, file_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
db.query(ChannelFile).filter_by(
|
||||
channel_id=channel_id, file_id=file_id
|
||||
).delete()
|
||||
@@ -904,8 +939,8 @@ class ChannelTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_channel_by_id(self, id: str):
|
||||
with get_db() as db:
|
||||
def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Channel).filter(Channel.id == id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user