Compatibility for huggingface transformers (#391)

This commit is contained in:
wenmeng zhou
2023-07-24 20:53:27 +08:00
committed by GitHub
parent ba4b9fc43f
commit 64203e89ee
7 changed files with 351 additions and 6 deletions

View File

@@ -26,6 +26,9 @@ if TYPE_CHECKING:
from .pipelines import Pipeline, pipeline
from .utils.hub import read_config, create_model_if_not_exist
from .utils.logger import get_logger
from .utils.hf_util import AutoConfig, GenerationConfig
from .utils.hf_util import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .utils.hf_util import AutoTokenizer
from .msdatasets import MsDataset
else:
@@ -65,6 +68,10 @@ else:
'pipelines': ['Pipeline', 'pipeline'],
'utils.hub': ['read_config', 'create_model_if_not_exist'],
'utils.logger': ['get_logger'],
'utils.hf_util': [
'AutoConfig', 'GenerationConfig', 'AutoModel',
'AutoModelForCausalLM', 'AutoModelForSeq2SeqLM', 'AutoTokenizer'
],
'msdatasets': ['MsDataset']
}

View File

@@ -88,6 +88,8 @@ class Model(ABC):
equal to the model saved.
For example, load a `backbone` into a `text-classification` model.
Other kwargs will be directly fed into the `model` key, to replace the default configs.
use_hf(bool): If set True, will use AutoModel in hf to initialize the model to keep compatibility
with huggingface transformers.
Returns:
A model instance.
@@ -116,6 +118,11 @@ class Model(ABC):
local_model_dir = snapshot_download(
model_name_or_path, revision, user_agent=invoked_by)
logger.info(f'initialize model from {local_model_dir}')
if kwargs.pop('use_hf', False):
from modelscope import AutoModel
return AutoModel.from_pretrained(local_model_dir)
if cfg_dict is not None:
cfg = cfg_dict
else:

178
modelscope/utils/hf_util.py Normal file
View File

@@ -0,0 +1,178 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
from transformers import AutoConfig as AutoConfigHF
from transformers import AutoModel as AutoModelHF
from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
from transformers import AutoTokenizer as AutoTokenizerHF
from transformers import GenerationConfig as GenerationConfigHF
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from modelscope import snapshot_download
from modelscope.utils.constant import Invoke
def user_agent(invoked_by=None):
    """Build the user-agent string reported to the modelscope hub.

    Args:
        invoked_by: Identifier of the calling entry point; defaults to
            `Invoke.PRETRAINED` when not provided.

    Returns:
        A ``"<key>/<invoked_by>"`` formatted string.
    """
    caller = Invoke.PRETRAINED if invoked_by is None else invoked_by
    return f'{Invoke.KEY}/{caller}'
def patch_tokenizer_base():
    """Monkey patch PreTrainedTokenizerBase.from_pretrained to adapt to modelscope hub."""
    hf_from_pretrained = PreTrainedTokenizerBase.from_pretrained.__func__

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        # Tokenizers only need vocab/config assets, so skip weight files
        # when downloading from the hub.
        skip_patterns = [r'\w+\.bin', r'\w+\.safetensors']
        if os.path.exists(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path
        else:
            # Treat the argument as a modelscope model id and fetch a
            # local snapshot of it.
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=kwargs.pop('revision', None),
                ignore_file_pattern=skip_patterns)
        return hf_from_pretrained(cls, model_dir, *model_args, **kwargs)

    PreTrainedTokenizerBase.from_pretrained = from_pretrained
def patch_model_base():
    """Monkey patch PreTrainedModel.from_pretrained to adapt to modelscope hub."""
    hf_from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        # Download *.bin weights but skip safetensors duplicates.
        skip_patterns = [r'\w+\.safetensors']
        if os.path.exists(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path
        else:
            # Treat the argument as a modelscope model id and fetch a
            # local snapshot of it.
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=kwargs.pop('revision', None),
                ignore_file_pattern=skip_patterns)
        return hf_from_pretrained(cls, model_dir, *model_args, **kwargs)

    PreTrainedModel.from_pretrained = from_pretrained
# Apply the hub patches at import time so that plain `transformers`
# classes also resolve model ids against the modelscope hub.
patch_tokenizer_base()
patch_model_base()
class AutoModel(AutoModelHF):
    """Drop-in replacement for `transformers.AutoModel` that resolves
    modelscope model ids to a local snapshot before delegating to HF."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        if os.path.exists(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path
        else:
            # Skip safetensors duplicates; *.bin weights are still fetched.
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=kwargs.pop('revision', None),
                ignore_file_pattern=[r'\w+\.safetensors'],
                user_agent=user_agent())
        return super().from_pretrained(model_dir, *model_args, **kwargs)
class AutoModelForCausalLM(AutoModelForCausalLMHF):
    """Drop-in replacement for `transformers.AutoModelForCausalLM` that
    resolves modelscope model ids to a local snapshot first."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        if os.path.exists(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path
        else:
            # Skip safetensors duplicates; *.bin weights are still fetched.
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=kwargs.pop('revision', None),
                ignore_file_pattern=[r'\w+\.safetensors'],
                user_agent=user_agent())
        return super().from_pretrained(model_dir, *model_args, **kwargs)
class AutoModelForSeq2SeqLM(AutoModelForSeq2SeqLMHF):
    """Drop-in replacement for `transformers.AutoModelForSeq2SeqLM` that
    resolves modelscope model ids to a local snapshot first."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        if os.path.exists(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path
        else:
            # Skip safetensors duplicates; *.bin weights are still fetched.
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=kwargs.pop('revision', None),
                ignore_file_pattern=[r'\w+\.safetensors'],
                user_agent=user_agent())
        return super().from_pretrained(model_dir, *model_args, **kwargs)
class AutoTokenizer(AutoTokenizerHF):
    """Drop-in replacement for `transformers.AutoTokenizer` that resolves
    modelscope model ids to a local snapshot before delegating to HF."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        """Load a tokenizer from a local path or a modelscope model id.

        Weight files (*.bin, *.safetensors) are excluded from the
        download since a tokenizer does not need them.
        """
        ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
        if not os.path.exists(pretrained_model_name_or_path):
            revision = kwargs.pop('revision', None)
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=revision,
                ignore_file_pattern=ignore_file_pattern,
                # Consistency fix: report the same user agent as the
                # AutoModel* wrappers do for their downloads.
                user_agent=user_agent())
        else:
            model_dir = pretrained_model_name_or_path
        return super().from_pretrained(model_dir, *model_args, **kwargs)
class AutoConfig(AutoConfigHF):
    """Drop-in replacement for `transformers.AutoConfig` that resolves
    modelscope model ids to a local snapshot before delegating to HF."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        """Load a model config from a local path or a modelscope model id.

        Weight (*.bin) and python (*.py) files are excluded from the
        download per the original ignore pattern.
        """
        ignore_file_pattern = [r'\w+\.bin', r'\w+\.py']
        if not os.path.exists(pretrained_model_name_or_path):
            revision = kwargs.pop('revision', None)
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=revision,
                ignore_file_pattern=ignore_file_pattern,
                # Consistency fix: report the same user agent as the
                # AutoModel* wrappers do for their downloads.
                user_agent=user_agent())
        else:
            model_dir = pretrained_model_name_or_path
        return super().from_pretrained(model_dir, *model_args, **kwargs)
class GenerationConfig(GenerationConfigHF):
    """Drop-in replacement for `transformers.GenerationConfig` that
    resolves modelscope model ids to a local snapshot first."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        """Load a generation config from a local path or a modelscope id.

        Weight (*.bin) and python (*.py) files are excluded from the
        download per the original ignore pattern.
        """
        ignore_file_pattern = [r'\w+\.bin', r'\w+\.py']
        if not os.path.exists(pretrained_model_name_or_path):
            revision = kwargs.pop('revision', None)
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=revision,
                ignore_file_pattern=ignore_file_pattern,
                # Consistency fix: report the same user agent as the
                # AutoModel* wrappers do for their downloads.
                user_agent=user_agent())
        else:
            model_dir = pretrained_model_name_or_path
        return super().from_pretrained(model_dir, *model_args, **kwargs)

View File

@@ -3,6 +3,7 @@
# Part of the implementation is borrowed from wimglenn/johnnydep
import copy
import filecmp
import importlib
import os
import pkgutil
@@ -28,6 +29,9 @@ logger = get_logger()
storage = LocalStorage()
MODELSCOPE_FILE_DIR = get_default_cache_dir()
MODELSCOPE_DYNAMIC_MODULE = 'modelscope_modules'
BASE_MODULE_DIR = os.path.join(MODELSCOPE_FILE_DIR, MODELSCOPE_DYNAMIC_MODULE)
PLUGINS_FILENAME = '.modelscope_plugins'
OFFICIAL_PLUGINS = [
{
@@ -322,6 +326,41 @@ def import_module_from_file(module_name, file_path):
return module
def create_module_from_files(file_list, file_prefix, module_name):
    """
    Create a python module from a list of files by copying them to the destination directory.

    Args:
        file_list (List[str]): List of relative file paths to be copied.
        file_prefix (str): Path prefix for each file in file_list.
        module_name (str): Name of the module.

    Returns:
        None
    """

    def create_empty_file(file_path):
        with open(file_path, 'w') as _:
            pass

    def ensure_init_files(root_dir, rel_dir):
        # Create an __init__.py in root_dir and in every package
        # directory down to rel_dir, so nested files are importable as a
        # regular package (previously only the leaf directory got one).
        current = root_dir
        init_file = os.path.join(current, '__init__.py')
        if not os.path.exists(init_file):
            create_empty_file(init_file)
        for part in (rel_dir.split(os.sep) if rel_dir else []):
            current = os.path.join(current, part)
            init_file = os.path.join(current, '__init__.py')
            if not os.path.exists(init_file):
                create_empty_file(init_file)

    dest_dir = os.path.join(BASE_MODULE_DIR, module_name)
    for file_path in file_list:
        file_dir = os.path.dirname(file_path)
        target_dir = os.path.join(dest_dir, file_dir)
        os.makedirs(target_dir, exist_ok=True)
        ensure_init_files(dest_dir, file_dir)
        # Bug fix: join the *relative* file path against dest_dir, not
        # target_dir — the old `os.path.join(target_dir, file_path)`
        # duplicated the directory component for nested files
        # (dest/a + a/c.py -> dest/a/a/c.py) and then failed to copy.
        target_file = os.path.join(dest_dir, file_path)
        src_file = os.path.join(file_prefix, file_path)
        # Copy only when missing or the content differs.
        if not os.path.exists(target_file) or not filecmp.cmp(
                src_file, target_file):
            shutil.copyfile(src_file, target_file)
    # Make the freshly copied files visible to the import machinery.
    importlib.invalidate_caches()
def import_module_from_model_dir(model_dir):
""" import all the necessary module from a model dir
@@ -340,12 +379,26 @@ def import_module_from_model_dir(model_dir):
# install the requirements firstly
install_requirements_by_files(requirements)
# then import the modules
import sys
sys.path.insert(0, model_dir)
for file in file_dirs:
module_name = Path(file).stem
import_module_from_file(module_name, file)
if BASE_MODULE_DIR not in sys.path:
sys.path.append(BASE_MODULE_DIR)
module_name = Path(model_dir).stem
# in order to keep forward compatibility, we add module path to
# sys.path so that submodule can be imported directly as before
MODULE_PATH = os.path.join(BASE_MODULE_DIR, module_name)
if MODULE_PATH not in sys.path:
sys.path.append(MODULE_PATH)
relative_file_dirs = [
file.replace(model_dir.rstrip(os.sep) + os.sep, '')
for file in file_dirs
]
create_module_from_files(relative_file_dirs, model_dir, module_name)
for file in relative_file_dirs:
submodule = module_name + '.' + file.replace(os.sep, '.').replace(
'.py', '')
importlib.import_module(submodule)
def install_requirements_by_names(plugins: List[str]):

View File

@@ -0,0 +1,41 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.base import Model
class BaseTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
def tearDown(self):
shutil.rmtree(self.tmp_dir)
super().tearDown()
def test_from_pretrained(self):
model = Model.from_pretrained(
'baichuan-inc/baichuan-7B', revision='v1.0.5')
self.assertIsNotNone(model)
def test_from_pretrained_hf(self):
model = Model.from_pretrained(
'damo/nlp_structbert_sentence-similarity_chinese-tiny',
use_hf=True)
self.assertIsNotNone(model)
# Allow invoking this test file directly via `python <file>`.
if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from transformers import LlamaForCausalLM, LlamaTokenizer
from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
AutoTokenizer, GenerationConfig)
class HFUtilTest(unittest.TestCase):
    """Tests for the modelscope wrappers around the huggingface Auto*
    classes and for the transformers monkey patches."""

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_auto_tokenizer(self):
        # Resolves the model id against the modelscope hub, not HF's.
        tokenizer = AutoTokenizer.from_pretrained(
            'baichuan-inc/Baichuan-13B-Chat',
            trust_remote_code=True,
            revision='v1.0.3')
        self.assertEqual(tokenizer.vocab_size, 64000)
        self.assertEqual(tokenizer.model_max_length, 4096)
        self.assertFalse(tokenizer.is_fast)

    def test_auto_model(self):
        model = AutoModelForCausalLM.from_pretrained(
            'baichuan-inc/baichuan-7B', trust_remote_code=True)
        # Idiom fix: assertIsNotNone instead of assertTrue(x is not None),
        # consistent with the other assertions in this file.
        self.assertIsNotNone(model)

    def test_auto_config(self):
        config = AutoConfig.from_pretrained(
            'baichuan-inc/Baichuan-13B-Chat',
            trust_remote_code=True,
            revision='v1.0.3')
        self.assertEqual(config.model_type, 'baichuan')
        gen_config = GenerationConfig.from_pretrained(
            'baichuan-inc/Baichuan-13B-Chat',
            trust_remote_code=True,
            revision='v1.0.3')
        self.assertEqual(gen_config.assistant_token_id, 196)

    def test_transformer_patch(self):
        # Plain transformers classes should also resolve modelscope ids,
        # via the from_pretrained monkey patches in utils.hf_util.
        tokenizer = LlamaTokenizer.from_pretrained(
            'skyline2006/llama-7b', revision='v1.0.1')
        self.assertIsNotNone(tokenizer)
        model = LlamaForCausalLM.from_pretrained(
            'skyline2006/llama-7b', revision='v1.0.1')
        self.assertIsNotNone(model)
# Allow invoking this test file directly via `python <file>`.
if __name__ == '__main__':
    unittest.main()

View File

@@ -124,3 +124,7 @@ class PluginTest(unittest.TestCase):
result = self.plugins_manager.list_plugins(show_all=True)
self.assertEqual(len(result.items()), len(OFFICIAL_PLUGINS))
# Allow invoking this test file directly via `python <file>`.
if __name__ == '__main__':
    unittest.main()