mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
[to #47939677] load only backbone with weights
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11679242 * load backbone with weights directly
This commit is contained in:
committed by
wenmeng.zwm
parent
ea9bd3cbdf
commit
2a6c5fdd55
@@ -6,7 +6,8 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from modelscope.hub.check_model import check_local_model_is_latest
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models.builder import build_model
|
||||
from modelscope.metainfo import Tasks
|
||||
from modelscope.models.builder import build_backbone, build_model
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
|
||||
from modelscope.utils.device import verify_device
|
||||
@@ -129,7 +130,9 @@ class Model(ABC):
|
||||
model_cfg[k] = v
|
||||
if device is not None:
|
||||
model_cfg.device = device
|
||||
model = build_model(model_cfg, task_name=task_name)
|
||||
if task_name is Tasks.backbone:
|
||||
model_cfg.init_backbone = True
|
||||
model = build_backbone(model_cfg)
|
||||
else:
|
||||
model = build_model(model_cfg, task_name=task_name)
|
||||
|
||||
|
||||
@@ -54,8 +54,12 @@ def build_backbone(cfg: ConfigDict, default_args: dict = None):
|
||||
cfg (:obj:`ConfigDict`): config dict for backbone object.
|
||||
default_args (dict, optional): Default initialization arguments.
|
||||
"""
|
||||
try:
|
||||
if not cfg.get('init_backbone', False):
|
||||
model_dir = cfg.pop('model_dir', None)
|
||||
else:
|
||||
model_dir = cfg.get('model_dir', None)
|
||||
|
||||
try:
|
||||
model = build_from_cfg(
|
||||
cfg,
|
||||
BACKBONES,
|
||||
@@ -65,12 +69,11 @@ def build_backbone(cfg: ConfigDict, default_args: dict = None):
|
||||
# Handle backbone that is not in the register group by using transformers AutoModel.
|
||||
# AutoModel are mostly using in NLP and part of Multi-Modal, while the number of backbone in CV、Audio and MM
|
||||
# is limited, thus could be added and registered in Modelscope directly
|
||||
logger.WARNING(
|
||||
logger.warning(
|
||||
f'The backbone {cfg.type} is not registered in modelscope, try to import the backbone from hf transformers.'
|
||||
)
|
||||
cfg['type'] = Models.transformers
|
||||
if model_dir is not None:
|
||||
cfg['model_dir'] = model_dir
|
||||
cfg['model_dir'] = model_dir
|
||||
model = build_from_cfg(
|
||||
cfg,
|
||||
BACKBONES,
|
||||
|
||||
@@ -91,11 +91,20 @@ class TransformersModel(TorchModel, PreTrainedModel):
|
||||
|
||||
@classmethod
|
||||
def _instantiate(cls, model_dir=None, **config):
|
||||
init_backbone = config.pop('init_backbone', False)
|
||||
|
||||
# return the model with pretrained weights
|
||||
if init_backbone:
|
||||
model = AutoModel.from_pretrained(model_dir)
|
||||
return model
|
||||
|
||||
# return the model only
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
model_dir,
|
||||
return_unused_kwargs=True,
|
||||
trust_remote_code=False,
|
||||
**config)
|
||||
|
||||
model_mapping = AutoModel._model_mapping
|
||||
if type(config) in model_mapping.keys():
|
||||
model_class = _get_model_class(config, model_mapping)
|
||||
|
||||
68
tests/models/test_backbone.py
Normal file
68
tests/models/test_backbone.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models import Model
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class BackboneTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = Tasks.backbone
|
||||
self.model_id = 'damo/nlp_structbert_backbone_tiny_std'
|
||||
self.transformer_model = 'bert'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_load_backbone_model_with_ms_backbone(self):
|
||||
model = Model.from_pretrained(
|
||||
task=self.task, model_name_or_path=self.model_id)
|
||||
self.assertEqual(model.__class__.__name__, 'SbertModel')
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_load_backbone_model_with_hf_automodel(self):
|
||||
local_model_dir = snapshot_download(self.model_id)
|
||||
cfg = Config.from_file(
|
||||
osp.join(local_model_dir, ModelFile.CONFIGURATION))
|
||||
cfg.model = {'type': 'transformers'}
|
||||
|
||||
import json
|
||||
with open(osp.join(local_model_dir, ModelFile.CONFIG), 'r') as f:
|
||||
hf_config = json.load(f)
|
||||
|
||||
hf_config['model_type'] = self.transformer_model
|
||||
|
||||
with open(osp.join(local_model_dir, ModelFile.CONFIG), 'w') as f:
|
||||
json.dump(hf_config, f)
|
||||
|
||||
model = Model.from_pretrained(
|
||||
task=self.task, model_name_or_path=self.model_id, cfg_dict=cfg)
|
||||
self.assertEqual(model.__class__.__name__, 'BertModel')
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_load_backbone_model_with_hf_automodel_specific_model(self):
|
||||
self.transformer_model = 'roberta'
|
||||
|
||||
local_model_dir = snapshot_download(self.model_id)
|
||||
cfg = Config.from_file(
|
||||
osp.join(local_model_dir, ModelFile.CONFIGURATION))
|
||||
cfg.model = {'type': self.transformer_model}
|
||||
import json
|
||||
with open(osp.join(local_model_dir, ModelFile.CONFIG), 'r') as f:
|
||||
hf_config = json.load(f)
|
||||
|
||||
hf_config['model_type'] = self.transformer_model
|
||||
|
||||
with open(osp.join(local_model_dir, ModelFile.CONFIG), 'w') as f:
|
||||
json.dump(hf_config, f)
|
||||
|
||||
model = Model.from_pretrained(
|
||||
task=self.task, model_name_or_path=self.model_id, cfg_dict=cfg)
|
||||
self.assertEqual(model.__class__.__name__, 'RobertaModel')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user