fix: ongoing chat stop issue

This commit is contained in:
Timothy Jaeryang Baek
2025-04-12 20:51:02 -07:00
parent fa61065c1e
commit f3fe82da80
6 changed files with 179 additions and 104 deletions

View File

@@ -5,16 +5,23 @@ from uuid import uuid4
# A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {}
chat_tasks = {}
def cleanup_task(task_id: str):
def cleanup_task(task_id: str, id=None):
"""
Remove a completed or canceled task from the global `tasks` dictionary.
"""
tasks.pop(task_id, None) # Remove the task if it exists
# If an ID is provided, remove the task from the chat_tasks dictionary
if id and task_id in chat_tasks.get(id, []):
chat_tasks[id].remove(task_id)
if not chat_tasks[id]: # If no tasks left for this ID, remove the entry
chat_tasks.pop(id, None)
def create_task(coroutine):
def create_task(coroutine, id=None):
"""
Create a new asyncio task and add it to the global task dictionary.
"""
@@ -22,9 +29,15 @@ def create_task(coroutine):
task = asyncio.create_task(coroutine) # Create the task
# Add a done callback for cleanup
task.add_done_callback(lambda t: cleanup_task(task_id))
task.add_done_callback(lambda t: cleanup_task(task_id, id))
tasks[task_id] = task
# If an ID is provided, associate the task with that ID
if chat_tasks.get(id):
chat_tasks[id].append(task_id)
else:
chat_tasks[id] = [task_id]
return task_id, task
@@ -42,6 +55,13 @@ def list_tasks():
return list(tasks.keys())
def list_task_ids_by_chat_id(id):
"""
List all tasks associated with a specific ID.
"""
return chat_tasks.get(id, [])
async def stop_task(task_id: str):
"""
Cancel a running task and remove it from the global task list.