From 2a6c5fdd55f8d3d502cbadcd7dbd59ebdae8bb0c Mon Sep 17 00:00:00 2001 From: "zhangzhicheng.zzc" Date: Tue, 21 Feb 2023 22:41:14 +0800 Subject: [PATCH] [to #47939677] load only backbone with weights Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11679242 * load backbone with weights directly --- modelscope/models/base/base_model.py | 7 +- modelscope/models/builder.py | 11 +-- .../models/nlp/hf_transformers/backbone.py | 9 +++ tests/models/test_backbone.py | 68 +++++++++++++++++++ 4 files changed, 89 insertions(+), 6 deletions(-) create mode 100644 tests/models/test_backbone.py diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index e0cedc34..18855829 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -6,7 +6,8 @@ from typing import Any, Dict, List, Optional, Union from modelscope.hub.check_model import check_local_model_is_latest from modelscope.hub.snapshot_download import snapshot_download -from modelscope.models.builder import build_model +from modelscope.metainfo import Tasks +from modelscope.models.builder import build_backbone, build_model from modelscope.utils.config import Config from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile from modelscope.utils.device import verify_device @@ -129,7 +130,9 @@ class Model(ABC): model_cfg[k] = v if device is not None: model_cfg.device = device - model = build_model(model_cfg, task_name=task_name) + if task_name is Tasks.backbone: + model_cfg.init_backbone = True + model = build_backbone(model_cfg) else: model = build_model(model_cfg, task_name=task_name) diff --git a/modelscope/models/builder.py b/modelscope/models/builder.py index da18edd8..b57fba53 100644 --- a/modelscope/models/builder.py +++ b/modelscope/models/builder.py @@ -54,8 +54,12 @@ def build_backbone(cfg: ConfigDict, default_args: dict = None): cfg (:obj:`ConfigDict`): config dict for backbone object. default_args (dict, optional): Default initialization arguments. """ - try: + if not cfg.get('init_backbone', False): model_dir = cfg.pop('model_dir', None) + else: + model_dir = cfg.get('model_dir', None) + + try: model = build_from_cfg( cfg, BACKBONES, @@ -65,12 +69,11 @@ def build_backbone(cfg: ConfigDict, default_args: dict = None): # Handle backbone that is not in the register group by using transformers AutoModel. # AutoModel are mostly using in NLP and part of Multi-Modal, while the number of backbone in CV、Audio and MM # is limited, thus could be added and registered in Modelscope directly - logger.WARNING( + logger.warning( f'The backbone {cfg.type} is not registered in modelscope, try to import the backbone from hf transformers.' ) cfg['type'] = Models.transformers - if model_dir is not None: - cfg['model_dir'] = model_dir + cfg['model_dir'] = model_dir model = build_from_cfg( cfg, BACKBONES, diff --git a/modelscope/models/nlp/hf_transformers/backbone.py b/modelscope/models/nlp/hf_transformers/backbone.py index 60321ae9..5b9a3965 100644 --- a/modelscope/models/nlp/hf_transformers/backbone.py +++ b/modelscope/models/nlp/hf_transformers/backbone.py @@ -91,11 +91,20 @@ class TransformersModel(TorchModel, PreTrainedModel): @classmethod def _instantiate(cls, model_dir=None, **config): + init_backbone = config.pop('init_backbone', False) + + # return the model with pretrained weights + if init_backbone: + model = AutoModel.from_pretrained(model_dir) + return model + + # return the model only config, kwargs = AutoConfig.from_pretrained( model_dir, return_unused_kwargs=True, trust_remote_code=False, **config) + model_mapping = AutoModel._model_mapping if type(config) in model_mapping.keys(): model_class = _get_model_class(config, model_mapping) diff --git a/tests/models/test_backbone.py b/tests/models/test_backbone.py new file mode 100644 index 00000000..411fe722 --- /dev/null +++ b/tests/models/test_backbone.py @@ -0,0 +1,68 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.test_utils import test_level + + +class BackboneTest(unittest.TestCase): + + def setUp(self) -> None: + self.task = Tasks.backbone + self.model_id = 'damo/nlp_structbert_backbone_tiny_std' + self.transformer_model = 'bert' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_load_backbone_model_with_ms_backbone(self): + model = Model.from_pretrained( + task=self.task, model_name_or_path=self.model_id) + self.assertEqual(model.__class__.__name__, 'SbertModel') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_load_backbone_model_with_hf_automodel(self): + local_model_dir = snapshot_download(self.model_id) + cfg = Config.from_file( + osp.join(local_model_dir, ModelFile.CONFIGURATION)) + cfg.model = {'type': 'transformers'} + + import json + with open(osp.join(local_model_dir, ModelFile.CONFIG), 'r') as f: + hf_config = json.load(f) + + hf_config['model_type'] = self.transformer_model + + with open(osp.join(local_model_dir, ModelFile.CONFIG), 'w') as f: + json.dump(hf_config, f) + + model = Model.from_pretrained( + task=self.task, model_name_or_path=self.model_id, cfg_dict=cfg) + self.assertEqual(model.__class__.__name__, 'BertModel') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_load_backbone_model_with_hf_automodel_specific_model(self): + self.transformer_model = 'roberta' + + local_model_dir = snapshot_download(self.model_id) + cfg = Config.from_file( + osp.join(local_model_dir, ModelFile.CONFIGURATION)) + cfg.model = {'type': self.transformer_model} + import json + with open(osp.join(local_model_dir, ModelFile.CONFIG), 'r') as f: + hf_config = json.load(f) + + hf_config['model_type'] = self.transformer_model + + with open(osp.join(local_model_dir, ModelFile.CONFIG), 'w') as f: + json.dump(hf_config, f) + + model = Model.from_pretrained( + task=self.task, model_name_or_path=self.model_id, cfg_dict=cfg) + self.assertEqual(model.__class__.__name__, 'RobertaModel') + + +if __name__ == '__main__': + unittest.main()