Mirror of https://github.com/modelscope/modelscope.git (synced 2026-02-24 12:10:09 +01:00).
Commit: "Compatibility for huggingface transformers (#391)".
This commit contains the following changes:
@@ -26,6 +26,9 @@ if TYPE_CHECKING:
|
||||
from .pipelines import Pipeline, pipeline
|
||||
from .utils.hub import read_config, create_model_if_not_exist
|
||||
from .utils.logger import get_logger
|
||||
from .utils.hf_util import AutoConfig, GenerationConfig
|
||||
from .utils.hf_util import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
||||
from .utils.hf_util import AutoTokenizer
|
||||
from .msdatasets import MsDataset
|
||||
|
||||
else:
|
||||
@@ -65,6 +68,10 @@ else:
|
||||
'pipelines': ['Pipeline', 'pipeline'],
|
||||
'utils.hub': ['read_config', 'create_model_if_not_exist'],
|
||||
'utils.logger': ['get_logger'],
|
||||
'utils.hf_util': [
|
||||
'AutoConfig', 'GenerationConfig', 'AutoModel',
|
||||
'AutoModelForCausalLM', 'AutoModelForSeq2SeqLM', 'AutoTokenizer'
|
||||
],
|
||||
'msdatasets': ['MsDataset']
|
||||
}
|
||||
|
||||
|
||||
@@ -88,6 +88,8 @@ class Model(ABC):
|
||||
equal to the model saved.
|
||||
For example, load a `backbone` into a `text-classification` model.
|
||||
Other kwargs will be directly fed into the `model` key, to replace the default configs.
|
||||
use_hf(bool): If set True, will use AutoModel in hf to initialize the model to keep compatibility
|
||||
with huggingface transformers.
|
||||
Returns:
|
||||
A model instance.
|
||||
|
||||
@@ -116,6 +118,11 @@ class Model(ABC):
|
||||
local_model_dir = snapshot_download(
|
||||
model_name_or_path, revision, user_agent=invoked_by)
|
||||
logger.info(f'initialize model from {local_model_dir}')
|
||||
|
||||
if kwargs.pop('use_hf', False):
|
||||
from modelscope import AutoModel
|
||||
return AutoModel.from_pretrained(local_model_dir)
|
||||
|
||||
if cfg_dict is not None:
|
||||
cfg = cfg_dict
|
||||
else:
|
||||
|
||||
178
modelscope/utils/hf_util.py
Normal file
178
modelscope/utils/hf_util.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from transformers import AutoConfig as AutoConfigHF
|
||||
from transformers import AutoModel as AutoModelHF
|
||||
from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
|
||||
from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
|
||||
from transformers import AutoTokenizer as AutoTokenizerHF
|
||||
from transformers import GenerationConfig as GenerationConfigHF
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.utils.constant import Invoke
|
||||
|
||||
|
||||
def user_agent(invoked_by=None):
    """Build the user-agent string reported to the ModelScope hub.

    Args:
        invoked_by: Optional tag identifying the caller; defaults to
            ``Invoke.PRETRAINED`` when not given.

    Returns:
        A string of the form ``"<Invoke.KEY>/<invoked_by>"``.
    """
    caller = Invoke.PRETRAINED if invoked_by is None else invoked_by
    return f'{Invoke.KEY}/{caller}'
|
||||
|
||||
|
||||
def patch_tokenizer_base():
    """Monkey patch ``PreTrainedTokenizerBase.from_pretrained`` to adapt to the
    ModelScope hub.

    When the given identifier is not an existing local path, the tokenizer
    files are first fetched via ``snapshot_download`` (weight files skipped)
    and the original ``from_pretrained`` is then applied to the local
    snapshot directory.
    """
    hf_from_pretrained = PreTrainedTokenizerBase.from_pretrained.__func__

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        # A tokenizer never needs the model weights, so skip them on download.
        skip_patterns = [r'\w+\.bin', r'\w+\.safetensors']
        if os.path.exists(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path
        else:
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=kwargs.pop('revision', None),
                ignore_file_pattern=skip_patterns)
        return hf_from_pretrained(cls, model_dir, *model_args, **kwargs)

    PreTrainedTokenizerBase.from_pretrained = from_pretrained
|
||||
|
||||
|
||||
def patch_model_base():
    """Monkey patch ``PreTrainedModel.from_pretrained`` to adapt to the
    ModelScope hub.

    Non-local identifiers are resolved through ``snapshot_download``
    (safetensors weights skipped) before delegating to the original
    ``from_pretrained`` with the downloaded directory.
    """
    hf_from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        # Only the `.bin` weights are loaded by this path; skip safetensors.
        skip_patterns = [r'\w+\.safetensors']
        if os.path.exists(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path
        else:
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=kwargs.pop('revision', None),
                ignore_file_pattern=skip_patterns)
        return hf_from_pretrained(cls, model_dir, *model_args, **kwargs)

    PreTrainedModel.from_pretrained = from_pretrained
|
||||
|
||||
|
||||
# Apply the monkey patches once at import time, so that the plain
# `transformers` entry points also resolve ModelScope model ids.
patch_tokenizer_base()
patch_model_base()
|
||||
|
||||
|
||||
class AutoModel(AutoModelHF):
    """Drop-in replacement for ``transformers.AutoModel`` that resolves
    non-local model ids against the ModelScope hub before loading."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        if os.path.exists(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path
        else:
            # Download everything except safetensors weights, and report
            # the invoker to the hub via the user-agent header.
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=kwargs.pop('revision', None),
                ignore_file_pattern=[r'\w+\.safetensors'],
                user_agent=user_agent())
        return super().from_pretrained(model_dir, *model_args, **kwargs)
|
||||
|
||||
|
||||
class AutoModelForCausalLM(AutoModelForCausalLMHF):
    """``transformers.AutoModelForCausalLM`` with ModelScope hub resolution
    for non-local model ids."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        if os.path.exists(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path
        else:
            # Skip safetensors weights; identify the caller to the hub.
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=kwargs.pop('revision', None),
                ignore_file_pattern=[r'\w+\.safetensors'],
                user_agent=user_agent())
        return super().from_pretrained(model_dir, *model_args, **kwargs)
|
||||
|
||||
|
||||
class AutoModelForSeq2SeqLM(AutoModelForSeq2SeqLMHF):
    """``transformers.AutoModelForSeq2SeqLM`` with ModelScope hub resolution
    for non-local model ids."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        if os.path.exists(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path
        else:
            # Skip safetensors weights; identify the caller to the hub.
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=kwargs.pop('revision', None),
                ignore_file_pattern=[r'\w+\.safetensors'],
                user_agent=user_agent())
        return super().from_pretrained(model_dir, *model_args, **kwargs)
|
||||
|
||||
|
||||
class AutoTokenizer(AutoTokenizerHF):
    """``transformers.AutoTokenizer`` that resolves non-local model ids
    against the ModelScope hub.

    Non-local identifiers are downloaded via ``snapshot_download`` (weight
    files skipped, since a tokenizer never needs them) and the HF
    implementation is then invoked on the local snapshot directory.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        # Tokenizers never load model weights; skip them on download.
        ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
        if not os.path.exists(pretrained_model_name_or_path):
            revision = kwargs.pop('revision', None)
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=revision,
                ignore_file_pattern=ignore_file_pattern,
                # Consistency fix: the Auto*Model classes already report the
                # invoker to the hub via the user-agent; do the same here.
                user_agent=user_agent())
        else:
            model_dir = pretrained_model_name_or_path
        return super().from_pretrained(model_dir, *model_args, **kwargs)
|
||||
|
||||
|
||||
class AutoConfig(AutoConfigHF):
    """``transformers.AutoConfig`` that resolves non-local model ids
    against the ModelScope hub before loading the configuration.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        # Only the config files are needed; skip weights and python sources.
        # NOTE(review): ignoring '*.py' means custom configuration code for
        # `trust_remote_code` models is not fetched by this call — presumably
        # downloaded elsewhere; confirm against the hub download path.
        ignore_file_pattern = [r'\w+\.bin', r'\w+\.py']
        if not os.path.exists(pretrained_model_name_or_path):
            revision = kwargs.pop('revision', None)
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=revision,
                ignore_file_pattern=ignore_file_pattern,
                # Consistency fix: report the invoker to the hub, as the
                # Auto*Model classes already do.
                user_agent=user_agent())
        else:
            model_dir = pretrained_model_name_or_path
        return super().from_pretrained(model_dir, *model_args, **kwargs)
|
||||
|
||||
|
||||
class GenerationConfig(GenerationConfigHF):
    """``transformers.GenerationConfig`` that resolves non-local model ids
    against the ModelScope hub before loading the generation config.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        # Only the generation config is needed; skip weights and sources.
        ignore_file_pattern = [r'\w+\.bin', r'\w+\.py']
        if not os.path.exists(pretrained_model_name_or_path):
            revision = kwargs.pop('revision', None)
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=revision,
                ignore_file_pattern=ignore_file_pattern,
                # Consistency fix: report the invoker to the hub, as the
                # Auto*Model classes already do.
                user_agent=user_agent())
        else:
            model_dir = pretrained_model_name_or_path
        return super().from_pretrained(model_dir, *model_args, **kwargs)
|
||||
@@ -3,6 +3,7 @@
|
||||
# Part of the implementation is borrowed from wimglenn/johnnydep
|
||||
|
||||
import copy
|
||||
import filecmp
|
||||
import importlib
|
||||
import os
|
||||
import pkgutil
|
||||
@@ -28,6 +29,9 @@ logger = get_logger()
|
||||
storage = LocalStorage()
|
||||
|
||||
MODELSCOPE_FILE_DIR = get_default_cache_dir()
|
||||
MODELSCOPE_DYNAMIC_MODULE = 'modelscope_modules'
|
||||
BASE_MODULE_DIR = os.path.join(MODELSCOPE_FILE_DIR, MODELSCOPE_DYNAMIC_MODULE)
|
||||
|
||||
PLUGINS_FILENAME = '.modelscope_plugins'
|
||||
OFFICIAL_PLUGINS = [
|
||||
{
|
||||
@@ -322,6 +326,41 @@ def import_module_from_file(module_name, file_path):
|
||||
return module
|
||||
|
||||
|
||||
def create_module_from_files(file_list, file_prefix, module_name):
    """
    Create a python module from a list of files by copying them into the
    destination directory under ``BASE_MODULE_DIR``.

    Args:
        file_list (List[str]): List of relative file paths to be copied.
        file_prefix (str): Path prefix for each file in file_list.
        module_name (str): Name of the module.

    Returns:
        None
    """

    def create_empty_file(file_path):
        with open(file_path, 'w') as _:
            pass

    dest_dir = os.path.join(BASE_MODULE_DIR, module_name)
    for file_path in file_list:
        file_dir = os.path.dirname(file_path)
        target_dir = os.path.join(dest_dir, file_dir)
        os.makedirs(target_dir, exist_ok=True)
        # Every package directory from dest_dir down to the leaf needs an
        # __init__.py, otherwise importlib cannot traverse nested submodules.
        init_dir = dest_dir
        for part in [''] + (file_dir.split(os.sep) if file_dir else []):
            init_dir = os.path.join(init_dir, part)
            init_file = os.path.join(init_dir, '__init__.py')
            if not os.path.exists(init_file):
                create_empty_file(init_file)

        # BUGFIX: join with dest_dir, not target_dir — file_path already
        # contains file_dir, so joining with target_dir duplicated the
        # sub-directory (dest/a/a/b.py instead of dest/a/b.py).
        target_file = os.path.join(dest_dir, file_path)
        src_file = os.path.join(file_prefix, file_path)
        # Copy only when missing or changed, so repeated imports are cheap.
        if not os.path.exists(target_file) or not filecmp.cmp(
                src_file, target_file):
            shutil.copyfile(src_file, target_file)

    importlib.invalidate_caches()
|
||||
|
||||
|
||||
def import_module_from_model_dir(model_dir):
|
||||
""" import all the necessary module from a model dir
|
||||
|
||||
@@ -340,12 +379,26 @@ def import_module_from_model_dir(model_dir):
|
||||
# install the requirements firstly
|
||||
install_requirements_by_files(requirements)
|
||||
|
||||
# then import the modules
|
||||
import sys
|
||||
sys.path.insert(0, model_dir)
|
||||
for file in file_dirs:
|
||||
module_name = Path(file).stem
|
||||
import_module_from_file(module_name, file)
|
||||
if BASE_MODULE_DIR not in sys.path:
|
||||
sys.path.append(BASE_MODULE_DIR)
|
||||
|
||||
module_name = Path(model_dir).stem
|
||||
|
||||
# in order to keep forward compatibility, we add module path to
|
||||
# sys.path so that submodule can be imported directly as before
|
||||
MODULE_PATH = os.path.join(BASE_MODULE_DIR, module_name)
|
||||
if MODULE_PATH not in sys.path:
|
||||
sys.path.append(MODULE_PATH)
|
||||
|
||||
relative_file_dirs = [
|
||||
file.replace(model_dir.rstrip(os.sep) + os.sep, '')
|
||||
for file in file_dirs
|
||||
]
|
||||
create_module_from_files(relative_file_dirs, model_dir, module_name)
|
||||
for file in relative_file_dirs:
|
||||
submodule = module_name + '.' + file.replace(os.sep, '.').replace(
|
||||
'.py', '')
|
||||
importlib.import_module(submodule)
|
||||
|
||||
|
||||
def install_requirements_by_names(plugins: List[str]):
|
||||
|
||||
41
tests/models/test_model_base.py
Normal file
41
tests/models/test_model_base.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.base import Model
|
||||
|
||||
|
||||
class BaseTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
self.tmp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(self.tmp_dir):
|
||||
os.makedirs(self.tmp_dir)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
super().tearDown()
|
||||
|
||||
def test_from_pretrained(self):
|
||||
model = Model.from_pretrained(
|
||||
'baichuan-inc/baichuan-7B', revision='v1.0.5')
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_from_pretrained_hf(self):
|
||||
model = Model.from_pretrained(
|
||||
'damo/nlp_structbert_sentence-similarity_chinese-tiny',
|
||||
use_hf=True)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
55
tests/utils/test_hf_util.py
Normal file
55
tests/utils/test_hf_util.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
|
||||
AutoTokenizer, GenerationConfig)
|
||||
|
||||
|
||||
class HFUtilTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def test_auto_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
'baichuan-inc/Baichuan-13B-Chat',
|
||||
trust_remote_code=True,
|
||||
revision='v1.0.3')
|
||||
self.assertEqual(tokenizer.vocab_size, 64000)
|
||||
self.assertEqual(tokenizer.model_max_length, 4096)
|
||||
self.assertFalse(tokenizer.is_fast)
|
||||
|
||||
def test_auto_model(self):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
'baichuan-inc/baichuan-7B', trust_remote_code=True)
|
||||
self.assertTrue(model is not None)
|
||||
|
||||
def test_auto_config(self):
|
||||
config = AutoConfig.from_pretrained(
|
||||
'baichuan-inc/Baichuan-13B-Chat',
|
||||
trust_remote_code=True,
|
||||
revision='v1.0.3')
|
||||
self.assertEqual(config.model_type, 'baichuan')
|
||||
gen_config = GenerationConfig.from_pretrained(
|
||||
'baichuan-inc/Baichuan-13B-Chat',
|
||||
trust_remote_code=True,
|
||||
revision='v1.0.3')
|
||||
self.assertEqual(gen_config.assistant_token_id, 196)
|
||||
|
||||
def test_transformer_patch(self):
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
'skyline2006/llama-7b', revision='v1.0.1')
|
||||
self.assertIsNotNone(tokenizer)
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
'skyline2006/llama-7b', revision='v1.0.1')
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -124,3 +124,7 @@ class PluginTest(unittest.TestCase):
|
||||
|
||||
result = self.plugins_manager.list_plugins(show_all=True)
|
||||
self.assertEqual(len(result.items()), len(OFFICIAL_PLUGINS))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user