Support flex train feature
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12737668
@@ -2,6 +2,8 @@
 import concurrent.futures
 import os
+import shutil
+from multiprocessing import Manager, Process, Value
 
 from modelscope.hub.api import HubApi
 from modelscope.hub.constants import ModelVisibility
@@ -11,6 +13,10 @@ from modelscope.utils.logger import get_logger
 logger = get_logger()
 
 _executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)
+_queues = dict()
+_flags = dict()
+_tasks = dict()
+_manager = None
 
 
 def _api_push_to_hub(repo_name,
@@ -131,3 +137,64 @@ def push_to_hub_async(repo_name,
     return _executor.submit(_api_push_to_hub, repo_name, output_dir, token,
                             private, commit_message, tag, source_repo,
                             ignore_file_pattern, revision)
+
+
+def submit_task(q, b):
+    while True:
+        b.value = False
+        item = q.get()
+        logger.info(item)
+        b.value = True
+        if not item.pop('done', False):
+            delete_dir = item.pop('delete_dir', False)
+            output_dir = item.get('output_dir')
+            try:
+                push_to_hub(**item)
+                if delete_dir and os.path.exists(output_dir):
+                    shutil.rmtree(output_dir)
+            except Exception as e:
+                logger.error(e)
+        else:
+            break
+
+
+class UploadStrategy:
+    cancel = 'cancel'
+    wait = 'wait'
+
+
+def push_to_hub_in_queue(queue_name, strategy=UploadStrategy.cancel, **kwargs):
+    assert queue_name is not None and len(
+        queue_name) > 0, 'Please specify a valid queue name!'
+    global _manager
+    if _manager is None:
+        _manager = Manager()
+    if queue_name not in _queues:
+        _queues[queue_name] = _manager.Queue()
+        _flags[queue_name] = Value('b', False)
+        process = Process(
+            target=submit_task, args=(_queues[queue_name], _flags[queue_name]))
+        process.start()
+        _tasks[queue_name] = process
+
+    queue = _queues[queue_name]
+    flag: Value = _flags[queue_name]
+    if kwargs.get('done', False):
+        queue.put(kwargs)
+    elif flag.value and strategy == UploadStrategy.cancel:
+        logger.error(
+            f'Another uploading is running, '
+            f'this uploading with message {kwargs.get("commit_message")} will be canceled.'
+        )
+    else:
+        queue.put(kwargs)
+
+
+def wait_for_done(queue_name):
+    process: Process = _tasks.pop(queue_name, None)
+    if process is None:
+        return
+    process.join()
+
+    _queues.pop(queue_name)
+    _flags.pop(queue_name)
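For reference, a minimal sketch of driving the new queue API end to end. The queue name, repo id, and paths below are hypothetical, and the keyword arguments are assumed to be whatever `push_to_hub` accepts:

    from modelscope.hub.push_to_hub import (UploadStrategy,
                                            push_to_hub_in_queue, wait_for_done)

    QUEUE_NAME = 'demo.checkpoint'  # hypothetical queue name

    # Enqueue an upload. With UploadStrategy.wait the request stays queued even
    # if a previous upload on this queue is still running; with the default
    # UploadStrategy.cancel it would be dropped instead.
    push_to_hub_in_queue(
        QUEUE_NAME,
        strategy=UploadStrategy.wait,
        repo_name='my_org/my_model',     # hypothetical repo id
        output_dir='./work_dir/output',  # hypothetical checkpoint dir
        commit_message='epoch_1')

    # Send the `done` sentinel so the worker loop exits, then join the worker
    # process and discard the queue bookkeeping.
    push_to_hub_in_queue(QUEUE_NAME, done=True)
    wait_for_done(QUEUE_NAME)

Each queue is backed by a dedicated worker process, so uploads on the same queue are serialized while training continues in the parent process.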
@@ -57,7 +57,7 @@ def update_cfg(cfg: Config) -> Config:
             key_chain_map[_HOOK_KEY_CHAIN_MAP[key]] = value
             hook.clear()
     cfg.train.hooks = list(filter(bool, cfg.train.hooks))
-    cfg.merge_from_dict(key_chain_map)
+    cfg.merge_from_dict(key_chain_map, force=False)
     return cfg
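The switch to `force=False` matters when the user has already set one of the mapped keys. A toy illustration, assuming `Config.merge_from_dict` follows the usual convention that `force=False` only fills in keys missing from the config (values hypothetical):

    from modelscope.utils.config import Config

    cfg = Config(dict(train=dict(logging=dict(interval=10))))
    # With force=False the existing train.logging.interval is kept; the value
    # 50 would only apply if the key were absent from cfg.
    cfg.merge_from_dict({'train.logging.interval': 50}, force=False)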
@@ -1,14 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 import random
-import time
+import shutil
 from typing import Optional
 
+import json
 import numpy as np
 import torch
 
 from modelscope.hub.check_model import check_model_is_id
-from modelscope.hub.push_to_hub import push_to_hub_async
+from modelscope.hub.push_to_hub import (UploadStrategy, push_to_hub_in_queue,
+                                        wait_for_done)
 from modelscope.metainfo import Hooks
 from modelscope.trainers.hooks.builder import HOOKS
 from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \
@@ -45,7 +47,9 @@ class CheckpointHook(Hook):
         hub_repo_id (str): The hub repo id.
         hub_token (str): The token of the modelhub. You can also set the environment variable `MODELSCOPE_API_TOKEN`.
         private_hub (bool): Whether to push to a private hub, default True.
-        hub_revision (str): Which branch to push the model to, default is `master`
+        hub_revision (str): Which branch to push the model to, default is `master`.
+        upload_strategy (str): The action to take when the previous upload has not finished
+            and the next one arrives, can be `cancel` or `wait`.
         kwargs:
             by_epoch (bool): Same as `save_strategy`, but with a higher priority, legacy argument.
             output_sub_dir (str): The folder under the `save_dir` to save the output checkpoint for inference.
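As a usage sketch, the new argument can be supplied through the trainer config like any other hook parameter (hook fields below other than `upload_strategy` are illustrative):

    cfg.train.hooks = [
        dict(
            type='CheckpointHook',
            interval=1,
            push_to_hub=True,
            hub_repo_id='my_org/my_model',  # hypothetical repo id
            private_hub=True,
            hub_revision='master',
            upload_strategy='wait')  # or 'cancel' (the default)
    ]

With `wait`, each checkpoint is snapshotted to a per-upload directory and queued, so no upload is lost; with `cancel`, a new upload is dropped while a previous one is still running.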
@@ -56,6 +60,8 @@ class CheckpointHook(Hook):
 
     EVAL_RESULT_FILE = 'eval_result.txt'
 
+    PUSH_TO_HUB_QUEUE_NAME = 'train.checkpoint'
+
     def __init__(self,
                  save_strategy: Optional[str] = CheckpointStrategy.by_epoch,
                  interval: Optional[int] = 0,
@@ -68,6 +74,7 @@ class CheckpointHook(Hook):
                  hub_token: Optional[str] = None,
                  private_hub: Optional[bool] = True,
                  hub_revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
+                 upload_strategy: Optional[str] = UploadStrategy.cancel,
                  **kwargs):
         self.interval = interval
         self.save_dir = save_dir
@@ -89,9 +96,9 @@ class CheckpointHook(Hook):
         self.hub_token = hub_token
         self.private_hub = private_hub
         self.hub_revision = hub_revision
+        self.upload_strategy = upload_strategy
         self.tag = -1
         self.is_model_id = None
-        self.push_to_hub_future = None
         self.max_checkpoint_num = None
         if max_checkpoint_num is not None:
             self.max_checkpoint_num = max(int(max_checkpoint_num), 1)
@@ -149,13 +156,15 @@ class CheckpointHook(Hook):
                 f'Saving checkpoint at {trainer.iter + 1} iter')
         self._save_checkpoint(trainer, prefix)
         if is_master() and self.push_to_hub:
-            if self.push_to_hub_future is not None and not self.push_to_hub_future.done(
-            ):
-                self.logger.error(
-                    f'Another uploading is running, '
-                    f'this uploading with message {prefix} will be canceled.')
-                return
-            self.push_to_hub_future = self._push_to_hub(trainer, prefix)
+            if self.upload_strategy == UploadStrategy.cancel:
+                output_dir = self.output_dir
+                delete_dir = False
+            else:
+                output_dir = self.output_dir + '_upload_' + prefix
+                shutil.copytree(
+                    self.output_dir, output_dir, dirs_exist_ok=True)
+                delete_dir = True
+            self._push_to_hub(trainer, prefix, output_dir, delete_dir)
 
     def after_train_epoch(self, trainer):
         if self.save_strategy != CheckpointStrategy.by_epoch:
@@ -172,32 +181,36 @@
             self._do_save(trainer, CheckpointStrategy.by_step)
 
     def after_run(self, trainer):
-        if self.push_to_hub_future is not None and not self.push_to_hub_future.done(
-        ):
-            self.logger.info('Train finished. Uploading models, waiting...')
-            while not self.push_to_hub_future.done():
-                time.sleep(1)
-            self.logger.info('Uploading models done.')
+        self.logger.info('Train finished. Uploading models, waiting...')
+        push_to_hub_in_queue(
+            self.PUSH_TO_HUB_QUEUE_NAME,
+            strategy=self.upload_strategy,
+            done=True)
+        wait_for_done(self.PUSH_TO_HUB_QUEUE_NAME)
+        self.logger.info('Uploading models done.')
 
-    def _push_to_hub(self, trainer, prefix):
+    def _push_to_hub(self, trainer, prefix, output_dir, delete_dir=False):
         if self.is_model_id is None:
             self.is_model_id = check_model_is_id(trainer.input_model_id,
                                                  self.hub_token)
         self.tag += 1
-        return push_to_hub_async(
-            self.hub_repo_id,
-            self.output_dir,
+        return push_to_hub_in_queue(
+            self.PUSH_TO_HUB_QUEUE_NAME,
+            strategy=self.upload_strategy,
+            repo_name=self.hub_repo_id,
+            output_dir=output_dir,
             token=self.hub_token,
             private=self.private_hub,
             commit_message=prefix,
             tag=f'v1.{self.tag}',
             revision=self.hub_revision,
-            source_repo=trainer.input_model_id if self.is_model_id else '')
+            source_repo=trainer.input_model_id if self.is_model_id else '',
+            delete_dir=delete_dir)
 
     def save_evaluate_results(self, trainer):
         with open(os.path.join(self.output_dir, self.EVAL_RESULT_FILE),
                   'w') as f:
-            f.write(str(trainer.metric_values))
+            f.write(json.dumps(trainer.metric_values))
 
     def _save_checkpoint(self, trainer, prefix):
         """Save checkpoint files and remove obsolete ones
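The switch from `str` to `json.dumps` in `save_evaluate_results` makes `eval_result.txt` machine-readable. A quick illustration with a hypothetical metrics dict:

    import json

    metric_values = {'accuracy': 0.901, 'f1': 0.875}  # hypothetical metrics
    text = json.dumps(metric_values)
    # text == '{"accuracy": 0.901, "f1": 0.875}' and round-trips:
    assert json.loads(text) == metric_values

whereas `str(metric_values)` yields a single-quoted Python repr that `json.loads` rejects.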
@@ -155,10 +155,10 @@ class EpochBasedTrainer(BaseTrainer):
         self.cfg_modify_fn = cfg_modify_fn
         # add default config
         merge_cfg(self.cfg)
-        self.cfg = self.rebuild_config(self.cfg)
         if 'cfg_options' in kwargs:
             self.cfg.merge_from_dict(kwargs['cfg_options'])
+        self.cfg = update_cfg(self.cfg)
+        self.cfg = self.rebuild_config(self.cfg)
 
         if isinstance(model, (TorchModel, nn.Module)):
             self.model = model
@@ -507,7 +507,7 @@ def build_dataset_from_file(filename):
                 "text2": "sequence2",
                 "label": "label",
             }
-            "split": 0.8,
+            "usage": 0.8,
         }
     ]
     """
@@ -541,16 +541,16 @@ def build_dataset_from_file(filename):
             lambda x: x,
             remove_columns=remove_columns,
             features=new_features).rename_columns(ds['column_mapping'])
-        split = ds['split']
-        if isinstance(split, str):
-            assert split in ('train', 'val')
-            if split == 'train':
+        usage = ds['usage']
+        if isinstance(usage, str):
+            assert usage in ('train', 'val')
+            if usage == 'train':
                 train_set.append(dataset)
             else:
                 eval_set.append(dataset)
         else:
-            assert isinstance(split, float) and 0 < split < 1
-            ds_dict = dataset.train_test_split(train_size=split)
+            assert isinstance(usage, float) and 0 < usage < 1
+            ds_dict = dataset.train_test_split(train_size=usage)
             train_set.append(ds_dict['train'])
             eval_set.append(ds_dict['test'])
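To make the rename concrete: a dataset description consumed by `build_dataset_from_file` now uses `usage` in place of `split`, either naming the target split directly or giving a float ratio for a random train/test split. A sketch based on the docstring above (field values and file name are illustrative):

    import json

    datasets = [
        {
            'column_mapping': {
                'text1': 'sequence1',
                'text2': 'sequence2',
                'label': 'label',
            },
            'usage': 'train',  # goes entirely to the train set
        },
        {
            'column_mapping': {'text1': 'sequence1', 'label': 'label'},
            'usage': 0.8,  # 80% of rows to train, 20% to eval
        },
    ]

    with open('dataset.json', 'w') as f:  # hypothetical file name
        json.dump(datasets, f, indent=2)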