Support flex train feature

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12737668
This commit is contained in:
yuze.zyz
2023-05-25 19:39:24 +08:00
committed by xingjun.wxj
parent c08b924968
commit a7a3eb5dc5
5 changed files with 112 additions and 32 deletions

View File

@@ -2,6 +2,8 @@
import concurrent.futures
import os
import shutil
from multiprocessing import Manager, Process, Value
from modelscope.hub.api import HubApi
from modelscope.hub.constants import ModelVisibility
@@ -11,6 +13,10 @@ from modelscope.utils.logger import get_logger
logger = get_logger()
_executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)
_queues = dict()
_flags = dict()
_tasks = dict()
_manager = None
def _api_push_to_hub(repo_name,
@@ -131,3 +137,64 @@ def push_to_hub_async(repo_name,
return _executor.submit(_api_push_to_hub, repo_name, output_dir, token,
private, commit_message, tag, source_repo,
ignore_file_pattern, revision)
def submit_task(q, b):
while True:
b.value = False
item = q.get()
logger.info(item)
b.value = True
if not item.pop('done', False):
delete_dir = item.pop('delete_dir', False)
output_dir = item.get('output_dir')
try:
push_to_hub(**item)
if delete_dir and os.path.exists(output_dir):
shutil.rmtree(output_dir)
except Exception as e:
logger.error(e)
else:
break
class UploadStrategy:
cancel = 'cancel'
wait = 'wait'
def push_to_hub_in_queue(queue_name, strategy=UploadStrategy.cancel, **kwargs):
assert queue_name is not None and len(
queue_name) > 0, 'Please specify a valid queue name!'
global _manager
if _manager is None:
_manager = Manager()
if queue_name not in _queues:
_queues[queue_name] = _manager.Queue()
_flags[queue_name] = Value('b', False)
process = Process(
target=submit_task, args=(_queues[queue_name], _flags[queue_name]))
process.start()
_tasks[queue_name] = process
queue = _queues[queue_name]
flag: Value = _flags[queue_name]
if kwargs.get('done', False):
queue.put(kwargs)
elif flag.value and strategy == UploadStrategy.cancel:
logger.error(
f'Another uploading is running, '
f'this uploading with message {kwargs.get("commit_message")} will be canceled.'
)
else:
queue.put(kwargs)
def wait_for_done(queue_name):
process: Process = _tasks.pop(queue_name, None)
if process is None:
return
process.join()
_queues.pop(queue_name)
_flags.pop(queue_name)

View File

@@ -57,7 +57,7 @@ def update_cfg(cfg: Config) -> Config:
key_chain_map[_HOOK_KEY_CHAIN_MAP[key]] = value
hook.clear()
cfg.train.hooks = list(filter(bool, cfg.train.hooks))
cfg.merge_from_dict(key_chain_map)
cfg.merge_from_dict(key_chain_map, force=False)
return cfg

View File

@@ -1,14 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
import time
import shutil
from typing import Optional
import json
import numpy as np
import torch
from modelscope.hub.check_model import check_model_is_id
from modelscope.hub.push_to_hub import push_to_hub_async
from modelscope.hub.push_to_hub import (UploadStrategy, push_to_hub_in_queue,
wait_for_done)
from modelscope.metainfo import Hooks
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \
@@ -45,7 +47,9 @@ class CheckpointHook(Hook):
hub_repo_id (str): The hub repo id.
hub_token (str): The token of the modelhub. You can also set the environment variable `MODELSCOPE_API_TOKEN`.
private_hub (bool): Whether push to a private hub, default True.
hub_revision (str): Which branch to push the model to, default is `master`
hub_revision (str): Which branch to push the model to, default is `master`.
upload_strategy (str): The action adopted when the previous uploading is not done
and the next one is coming, can be `cancel` or `wait`.
kwargs:
by_epoch (bool): Same with `save_strategy`, but has a higher priority, legacy argument.
output_sub_dir (str): The folder under the `save_dir` to save the output checkpoint for inference.
@@ -56,6 +60,8 @@ class CheckpointHook(Hook):
EVAL_RESULT_FILE = 'eval_result.txt'
PUSH_TO_HUB_QUEUE_NAME = 'train.checkpoint'
def __init__(self,
save_strategy: Optional[str] = CheckpointStrategy.by_epoch,
interval: Optional[int] = 0,
@@ -68,6 +74,7 @@ class CheckpointHook(Hook):
hub_token: Optional[str] = None,
private_hub: Optional[bool] = True,
hub_revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
upload_strategy: Optional[str] = UploadStrategy.cancel,
**kwargs):
self.interval = interval
self.save_dir = save_dir
@@ -89,9 +96,9 @@ class CheckpointHook(Hook):
self.hub_token = hub_token
self.private_hub = private_hub
self.hub_revision = hub_revision
self.upload_strategy = upload_strategy
self.tag = -1
self.is_model_id = None
self.push_to_hub_future = None
self.max_checkpoint_num = None
if max_checkpoint_num is not None:
self.max_checkpoint_num = max(int(max_checkpoint_num), 1)
@@ -149,13 +156,15 @@ class CheckpointHook(Hook):
f'Saving checkpoint at {trainer.iter + 1} iter')
self._save_checkpoint(trainer, prefix)
if is_master() and self.push_to_hub:
if self.push_to_hub_future is not None and not self.push_to_hub_future.done(
):
self.logger.error(
f'Another uploading is running, '
f'this uploading with message {prefix} will be canceled.')
return
self.push_to_hub_future = self._push_to_hub(trainer, prefix)
if self.upload_strategy == UploadStrategy.cancel:
output_dir = self.output_dir
delete_dir = False
else:
output_dir = self.output_dir + '_upload_' + prefix
shutil.copytree(
self.output_dir, output_dir, dirs_exist_ok=True)
delete_dir = True
self._push_to_hub(trainer, prefix, output_dir, delete_dir)
def after_train_epoch(self, trainer):
if self.save_strategy != CheckpointStrategy.by_epoch:
@@ -172,32 +181,36 @@ class CheckpointHook(Hook):
self._do_save(trainer, CheckpointStrategy.by_step)
def after_run(self, trainer):
if self.push_to_hub_future is not None and not self.push_to_hub_future.done(
):
self.logger.info('Train finished. Uploading models, waiting...')
while not self.push_to_hub_future.done():
time.sleep(1)
self.logger.info('Uploading models done.')
self.logger.info('Train finished. Uploading models, waiting...')
push_to_hub_in_queue(
self.PUSH_TO_HUB_QUEUE_NAME,
strategy=self.upload_strategy,
done=True)
wait_for_done(self.PUSH_TO_HUB_QUEUE_NAME)
self.logger.info('Uploading models done.')
def _push_to_hub(self, trainer, prefix):
def _push_to_hub(self, trainer, prefix, output_dir, delete_dir=False):
if self.is_model_id is None:
self.is_model_id = check_model_is_id(trainer.input_model_id,
self.hub_token)
self.tag += 1
return push_to_hub_async(
self.hub_repo_id,
self.output_dir,
return push_to_hub_in_queue(
self.PUSH_TO_HUB_QUEUE_NAME,
strategy=self.upload_strategy,
repo_name=self.hub_repo_id,
output_dir=output_dir,
token=self.hub_token,
private=self.private_hub,
commit_message=prefix,
tag=f'v1.{self.tag}',
revision=self.hub_revision,
source_repo=trainer.input_model_id if self.is_model_id else '')
source_repo=trainer.input_model_id if self.is_model_id else '',
delete_dir=delete_dir)
def save_evaluate_results(self, trainer):
with open(os.path.join(self.output_dir, self.EVAL_RESULT_FILE),
'w') as f:
f.write(str(trainer.metric_values))
f.write(json.dumps(trainer.metric_values))
def _save_checkpoint(self, trainer, prefix):
"""Save checkpoint files and remove obsolete ones

View File

@@ -155,10 +155,10 @@ class EpochBasedTrainer(BaseTrainer):
self.cfg_modify_fn = cfg_modify_fn
# add default config
merge_cfg(self.cfg)
self.cfg = self.rebuild_config(self.cfg)
if 'cfg_options' in kwargs:
self.cfg.merge_from_dict(kwargs['cfg_options'])
self.cfg = update_cfg(self.cfg)
self.cfg = self.rebuild_config(self.cfg)
if isinstance(model, (TorchModel, nn.Module)):
self.model = model

View File

@@ -507,7 +507,7 @@ def build_dataset_from_file(filename):
"text2": "sequence2",
"label": "label",
}
"split": 0.8,
"usage": 0.8,
}
]
"""
@@ -541,16 +541,16 @@ def build_dataset_from_file(filename):
lambda x: x,
remove_columns=remove_columns,
features=new_features).rename_columns(ds['column_mapping'])
split = ds['split']
if isinstance(split, str):
assert split in ('train', 'val')
if split == 'train':
usage = ds['usage']
if isinstance(usage, str):
assert usage in ('train', 'val')
if usage == 'train':
train_set.append(dataset)
else:
eval_set.append(dataset)
else:
assert isinstance(split, float) and 0 < split < 1
ds_dict = dataset.train_test_split(train_size=split)
assert isinstance(usage, float) and 0 < usage < 1
ds_dict = dataset.train_test_split(train_size=usage)
train_set.append(ds_dict['train'])
eval_set.append(ds_dict['test'])