[to #47939677] load only backbone with weights

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11679242

* load backbone with weights directly
This commit is contained in:
zhangzhicheng.zzc
2023-02-21 22:41:14 +08:00
committed by wenmeng.zwm
parent ea9bd3cbdf
commit 2a6c5fdd55
4 changed files with 89 additions and 6 deletions

View File

@@ -6,7 +6,8 @@ from typing import Any, Dict, List, Optional, Union
from modelscope.hub.check_model import check_local_model_is_latest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.builder import build_model
from modelscope.metainfo import Tasks
from modelscope.models.builder import build_backbone, build_model
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
from modelscope.utils.device import verify_device
@@ -129,7 +130,9 @@ class Model(ABC):
model_cfg[k] = v
if device is not None:
model_cfg.device = device
model = build_model(model_cfg, task_name=task_name)
if task_name is Tasks.backbone:
model_cfg.init_backbone = True
model = build_backbone(model_cfg)
else:
model = build_model(model_cfg, task_name=task_name)

View File

@@ -54,8 +54,12 @@ def build_backbone(cfg: ConfigDict, default_args: dict = None):
cfg (:obj:`ConfigDict`): config dict for backbone object.
default_args (dict, optional): Default initialization arguments.
"""
try:
if not cfg.get('init_backbone', False):
model_dir = cfg.pop('model_dir', None)
else:
model_dir = cfg.get('model_dir', None)
try:
model = build_from_cfg(
cfg,
BACKBONES,
@@ -65,12 +69,11 @@ def build_backbone(cfg: ConfigDict, default_args: dict = None):
# Handle backbone that is not in the register group by using transformers AutoModel.
# AutoModel is mostly used in NLP and part of Multi-Modal, while the number of backbones in CV, Audio and MM
# is limited, thus they could be added and registered in Modelscope directly
logger.WARNING(
logger.warning(
f'The backbone {cfg.type} is not registered in modelscope, try to import the backbone from hf transformers.'
)
cfg['type'] = Models.transformers
if model_dir is not None:
cfg['model_dir'] = model_dir
cfg['model_dir'] = model_dir
model = build_from_cfg(
cfg,
BACKBONES,

View File

@@ -91,11 +91,20 @@ class TransformersModel(TorchModel, PreTrainedModel):
@classmethod
def _instantiate(cls, model_dir=None, **config):
init_backbone = config.pop('init_backbone', False)
# return the model with pretrained weights
if init_backbone:
model = AutoModel.from_pretrained(model_dir)
return model
# return the model only
config, kwargs = AutoConfig.from_pretrained(
model_dir,
return_unused_kwargs=True,
trust_remote_code=False,
**config)
model_mapping = AutoModel._model_mapping
if type(config) in model_mapping.keys():
model_class = _get_model_class(config, model_mapping)

View File

@@ -0,0 +1,68 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level
class BackboneTest(unittest.TestCase):
    """Tests for loading backbone models through ``Model.from_pretrained``.

    Covers three resolution paths: a backbone registered in modelscope,
    a backbone resolved via hf transformers ``AutoModel``, and a specific
    transformers model type requested by name.
    """

    def setUp(self) -> None:
        # Shared fixtures: the backbone task, the model to download, and
        # the default hf transformers model type used by the tests.
        self.task = Tasks.backbone
        self.model_id = 'damo/nlp_structbert_backbone_tiny_std'
        self.transformer_model = 'bert'

    @staticmethod
    def _restore_file(path, content):
        # addCleanup helper: write the original text back to ``path``.
        with open(path, 'w') as f:
            f.write(content)

    def _load_with_patched_hf_config(self, model_type, cfg_model_type):
        """Download the model, patch its configs, and load it as a backbone.

        Sets the modelscope model config type to ``cfg_model_type`` and the
        hf ``config.json`` ``model_type`` to ``model_type``, then loads via
        ``Model.from_pretrained``. The on-disk hf config is restored after
        the test so the shared snapshot cache is not left mutated for other
        tests (the original code modified it permanently).
        """
        import json
        local_model_dir = snapshot_download(self.model_id)
        cfg = Config.from_file(
            osp.join(local_model_dir, ModelFile.CONFIGURATION))
        cfg.model = {'type': cfg_model_type}
        config_path = osp.join(local_model_dir, ModelFile.CONFIG)
        with open(config_path, 'r') as f:
            original_config = f.read()
        # Restore the cached config even if loading below raises.
        self.addCleanup(self._restore_file, config_path, original_config)
        hf_config = json.loads(original_config)
        hf_config['model_type'] = model_type
        with open(config_path, 'w') as f:
            json.dump(hf_config, f)
        return Model.from_pretrained(
            task=self.task, model_name_or_path=self.model_id, cfg_dict=cfg)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_load_backbone_model_with_ms_backbone(self):
        # A backbone registered in modelscope resolves to SbertModel.
        model = Model.from_pretrained(
            task=self.task, model_name_or_path=self.model_id)
        self.assertEqual(model.__class__.__name__, 'SbertModel')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_load_backbone_model_with_hf_automodel(self):
        # The generic 'transformers' type falls through to hf AutoModel,
        # which maps model_type 'bert' to BertModel.
        model = self._load_with_patched_hf_config(
            self.transformer_model, 'transformers')
        self.assertEqual(model.__class__.__name__, 'BertModel')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_load_backbone_model_with_hf_automodel_specific_model(self):
        # Requesting a specific transformers model type ('roberta') loads
        # the corresponding RobertaModel class.
        self.transformer_model = 'roberta'
        model = self._load_with_patched_hf_config(
            self.transformer_model, self.transformer_model)
        self.assertEqual(model.__class__.__name__, 'RobertaModel')
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()