feat: compatible with hf_pipeline (#1221)

suluyana
2025-02-21 15:49:39 +08:00
committed by GitHub
parent e733458746
commit 57044b9c88
8 changed files with 205 additions and 10 deletions

View File

@@ -56,7 +56,7 @@ if TYPE_CHECKING:
         AutoModelForPreTraining, AutoModelForTextEncoding,
         AutoImageProcessor, BatchFeature, Qwen2VLForConditionalGeneration,
         T5EncoderModel, Qwen2_5_VLForConditionalGeneration, LlamaModel,
-        LlamaPreTrainedModel, LlamaForCausalLM)
+        LlamaPreTrainedModel, LlamaForCausalLM, hf_pipeline)
 else:
     print(
         'transformers is not installed, please install it if you want to use related modules'
     )
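With hf_pipeline added to the lazily resolved exports, it becomes importable from the hf_util package. A minimal usage sketch, reusing the question-answering model id that appears in the tests below:

    from modelscope.utils.hf_util import hf_pipeline

    # Builds a transformers pipeline from a ModelScope model id.
    qa = hf_pipeline(
        task='question-answering',
        model='damotestx/distilbert-base-cased-distilled-squad')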

View File

@@ -71,6 +71,10 @@ def check_local_model_is_latest(
         headers=snapshot_header,
         use_cookies=cookies,
     )
+    model_cache = None
+    # download via non-git method
+    if not os.path.exists(os.path.join(model_root_path, '.git')):
+        model_cache = ModelFileSystemCache(model_root_path)
     for model_file in model_files:
         if model_file['Type'] == 'tree':
             continue
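The new branch covers models fetched without git: a git clone leaves a .git directory behind, while the HTTP snapshot downloader does not, so only the latter carries a ModelFileSystemCache describing which files are already on disk. A small sketch of that dispatch, assuming model_root_path is a local model directory:

    import os

    def was_downloaded_via_git(model_root_path: str) -> bool:
        # Only a git clone leaves a .git directory at the model root.
        return os.path.exists(os.path.join(model_root_path, '.git'))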

View File

@@ -46,6 +46,7 @@ class GitCommandWrapper(metaclass=Singleton):
         git_env = os.environ.copy()
         git_env['GIT_TERMINAL_PROMPT'] = '0'
         command = [self.git_path, *args]
+        command = [item for item in command if item]
         response = subprocess.run(
             command,
             stdout=subprocess.PIPE,
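The filter drops falsy items before the command runs, so an optional argument that resolved to an empty string is no longer passed to git as a literal '' token. An isolated illustration (the URL is a placeholder):

    args = ['clone', '', 'https://example.com/repo.git']  # '' stands for an unset option
    command = ['git', *args]
    command = [item for item in command if item]
    # command is now ['git', 'clone', 'https://example.com/repo.git']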

View File

@@ -10,6 +10,8 @@ from modelscope.utils.config import ConfigDict, check_config
 from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, Tasks,
                                        ThirdParty)
 from modelscope.utils.hub import read_config
+from modelscope.utils.import_utils import is_transformers_available
+from modelscope.utils.logger import get_logger
 from modelscope.utils.plugins import (register_modelhub_repo,
                                       register_plugins_repo)
 from modelscope.utils.registry import Registry, build_from_cfg
@@ -17,6 +19,7 @@ from .base import Pipeline
 from .util import is_official_hub_path

 PIPELINES = Registry('pipelines')
+logger = get_logger()


 def normalize_model_input(model,
@@ -72,7 +75,7 @@ def pipeline(task: str = None,
              config_file: str = None,
              pipeline_name: str = None,
              framework: str = None,
-             device: str = 'gpu',
+             device: str = None,
              model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
              ignore_file_pattern: List[str] = None,
              **kwargs) -> Pipeline:
@@ -109,6 +112,7 @@ def pipeline(task: str = None,
     if task is None and pipeline_name is None:
         raise ValueError('task or pipeline_name is required')

+    pipeline_props = None
     if pipeline_name is None:
         # get default pipeline for this task
         if isinstance(model, str) \
@@ -157,8 +161,11 @@ def pipeline(task: str = None,
         if pipeline_name:
             pipeline_props = {'type': pipeline_name}
         else:
-            check_config(cfg)
-            pipeline_props = cfg.pipeline
+            try:
+                check_config(cfg)
+                pipeline_props = cfg.pipeline
+            except AssertionError as e:
+                logger.info(str(e))
     elif model is not None:
         # get pipeline info from Model object
@@ -166,9 +173,13 @@ def pipeline(task: str = None,
         if not hasattr(first_model, 'pipeline'):
             # model is instantiated by user, we should parse config again
             cfg = read_config(first_model.model_dir)
-            check_config(cfg)
-            first_model.pipeline = cfg.pipeline
-        pipeline_props = first_model.pipeline
+            try:
+                check_config(cfg)
+                first_model.pipeline = cfg.pipeline
+            except AssertionError as e:
+                logger.info(str(e))
+        if first_model.__dict__.get('pipeline'):
+            pipeline_props = first_model.pipeline
     else:
         pipeline_name, default_model_repo = get_default_pipeline_info(task)
         model = normalize_model_input(default_model_repo, model_revision)
@@ -176,6 +187,23 @@ def pipeline(task: str = None,
     else:
         pipeline_props = {'type': pipeline_name}

+    if not pipeline_props and is_transformers_available():
+        try:
+            from modelscope.utils.hf_util import hf_pipeline
+            return hf_pipeline(
+                task=task,
+                model=model,
+                framework=framework,
+                device=device,
+                **kwargs)
+        except Exception as e:
+            logger.error(
+                'We couldn\'t find a suitable pipeline from ms, so we tried to load it using the transformers pipeline,'
+                ' but that also failed.')
+            raise e
+
+    if not device:
+        device = 'gpu'
     pipeline_props['model'] = model
     pipeline_props['device'] = device
     cfg = ConfigDict(pipeline_props)
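The net effect on the builder: the device default is deferred until a ModelScope pipeline is actually chosen, and when no pipeline_props could be resolved the call is delegated to transformers. A hedged sketch of the new behavior from the caller's side:

    from modelscope import pipeline

    # A plain Hugging Face model with no ModelScope configuration.json should
    # now fall through to hf_pipeline instead of raising (model id from the
    # tests below).
    qa = pipeline(
        'question-answering',
        model='damotestx/distilbert-base-cased-distilled-squad')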

View File

@@ -1,2 +1,3 @@
 from .auto_class import *
 from .patcher import patch_context, patch_hub, unpatch_hub
+from .pipeline_builder import hf_pipeline

View File

@@ -27,7 +27,8 @@ def get_all_imported_modules():
     transformers_include_names = [
         'Auto.*', 'T5.*', 'BitsAndBytesConfig', 'GenerationConfig', 'Awq.*',
         'GPTQ.*', 'BatchFeature', 'Qwen.*', 'Llama.*', 'PretrainedConfig',
-        'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast'
+        'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast',
+        'Pipeline'
     ]
     peft_include_names = ['.*PeftModel.*', '.*Config']
     diffusers_include_names = ['^(?!TF|Flax).*Pipeline$']
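For reference, the diffusers pattern selects native pipeline classes while the negative lookahead excludes the TF/Flax variants; for example:

    import re

    pattern = re.compile('^(?!TF|Flax).*Pipeline$')
    assert pattern.match('StableDiffusionPipeline')
    assert not pattern.match('FlaxStableDiffusionPipeline')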
@@ -252,6 +253,44 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                     model_dir, *model_args, **kwargs)
                 return module_obj

+            def save_pretrained(
+                    self,
+                    save_directory: Union[str, os.PathLike],
+                    safe_serialization: bool = True,
+                    **kwargs,
+            ):
+                push_to_hub = kwargs.pop('push_to_hub', False)
+                if push_to_hub:
+                    from modelscope.hub.push_to_hub import push_to_hub
+                    from modelscope.hub.api import HubApi
+                    from modelscope.hub.repository import Repository
+                    token = kwargs.get('token')
+                    commit_message = kwargs.pop('commit_message', None)
+                    repo_name = kwargs.pop(
+                        'repo_id', save_directory.split(os.path.sep)[-1])
+                    api = HubApi()
+                    api.login(token)
+                    api.create_repo(repo_name)
+                    # clone the repo
+                    Repository(save_directory, repo_name)
+                super().save_pretrained(
+                    save_directory=save_directory,
+                    safe_serialization=safe_serialization,
+                    push_to_hub=False,
+                    **kwargs)
+                # Class members may be unpatched, so push_to_hub is done
+                # separately here
+                if push_to_hub:
+                    push_to_hub(
+                        repo_name=repo_name,
+                        output_dir=save_directory,
+                        commit_message=commit_message,
+                        token=token)
+
         if not hasattr(module_class, 'from_pretrained'):
             del ClassWrapper.from_pretrained
         else:
@@ -266,6 +305,9 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
         if not hasattr(module_class, 'get_config_dict'):
             del ClassWrapper.get_config_dict
+        if not hasattr(module_class, 'save_pretrained'):
+            del ClassWrapper.save_pretrained

         ClassWrapper.__name__ = module_class.__name__
         ClassWrapper.__qualname__ = module_class.__qualname__
         return ClassWrapper
@@ -289,12 +331,16 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
             has_from_pretrained = hasattr(var, 'from_pretrained')
             has_get_peft_type = hasattr(var, '_get_peft_type')
             has_get_config_dict = hasattr(var, 'get_config_dict')
+            has_save_pretrained = hasattr(var, 'save_pretrained')
         except:  # noqa
             continue
-        if wrap:
+        # save_pretrained is not a classmethod and cannot be overridden by
+        # replacing the class method. It requires replacing the class object
+        # method.
+        if wrap or ('pipeline' in name.lower() and has_save_pretrained):
             try:
-                if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type:
+                if (not has_from_pretrained and not has_get_config_dict
+                        and not has_get_peft_type and not has_save_pretrained):
                     all_available_modules.append(var)
                 else:
                     all_available_modules.append(
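With the wrapper in place, saving a patched pipeline can push to the ModelScope hub in one call: the repo is created and cloned first, the original save_pretrained runs with push_to_hub=False, and the upload happens afterwards. A hedged usage sketch (the repo id and token are placeholders you must supply):

    from modelscope import pipeline

    pipe = pipeline(
        'question-answering',
        model='damotestx/distilbert-base-cased-distilled-squad')
    pipe.save_pretrained(
        './qa_pipeline',
        push_to_hub=True,
        repo_id='your_namespace/your_repo',  # placeholder
        token='<modelscope-sdk-token>')  # placeholder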

View File

@@ -0,0 +1,54 @@
+import os
+from typing import Optional, Union
+
+import torch
+from transformers import Pipeline as PipelineHF
+from transformers import PreTrainedModel, TFPreTrainedModel, pipeline
+from transformers.pipelines import check_task, get_task
+
+from modelscope.hub import snapshot_download
+from modelscope.utils.hf_util.patcher import _patch_pretrained_class, patch_hub
+
+
+def _get_hf_device(device):
+    if isinstance(device, str):
+        device_name = device.lower()
+        eles = device_name.split(':')
+        if eles[0] == 'gpu':
+            eles = ['cuda'] + eles[1:]
+        device = ':'.join(eles)
+    return device
+
+
+def _get_hf_pipeline_class(task, model):
+    if not task:
+        task = get_task(model)
+    normalized_task, targeted_task, task_options = check_task(task)
+    pipeline_class = targeted_task['impl']
+    pipeline_class = _patch_pretrained_class([pipeline_class])[0]
+    return pipeline_class
+
+
+def hf_pipeline(
+    task: str = None,
+    model: Optional[Union[str, 'PreTrainedModel', 'TFPreTrainedModel']] = None,
+    framework: Optional[str] = None,
+    device: Optional[Union[int, str, 'torch.device']] = None,
+    **kwargs,
+) -> PipelineHF:
+    if isinstance(model, str):
+        if not os.path.exists(model):
+            model = snapshot_download(model)
+
+    framework = 'pt' if framework == 'pytorch' else framework
+    device = _get_hf_device(device)
+    pipeline_class = _get_hf_pipeline_class(task, model)
+
+    return pipeline(
+        task=task,
+        model=model,
+        framework=framework,
+        device=device,
+        pipeline_class=pipeline_class,
+        **kwargs)
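Given the implementation above, _get_hf_device maps ModelScope's 'gpu' device naming onto the 'cuda' strings transformers expects (note the ':' join, so 'gpu:0' keeps its index), and passes non-string values through untouched:

    assert _get_hf_device('gpu') == 'cuda'
    assert _get_hf_device('gpu:0') == 'cuda:0'
    assert _get_hf_device('cpu') == 'cpu'
    assert _get_hf_device(0) == 0  # ints and torch.device objects pass through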

View File

@@ -40,6 +40,14 @@ class HFUtilTest(unittest.TestCase):
         with open(self.test_file2, 'w') as f:
             f.write('{}')
+        self.pipeline_qa_context = r"""
+Extractive Question Answering is the task of extracting an answer from a text given a question. An example
+of a question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would
+like to fine-tune a model on a SQuAD task, you may leverage the
+examples/pytorch/question-answering/run_squad.py script.
+"""
+        self.pipeline_qa_question = 'What is a good example of a question answering dataset?'

     def tearDown(self):
         logger.info('TearDown')
         shutil.rmtree(self.model_dir, ignore_errors=True)
@@ -235,6 +243,59 @@ class HFUtilTest(unittest.TestCase):
             'Qwen/Qwen1.5-0.5B-Chat', trust_remote_code=True)
         model.push_to_hub(self.create_model_name)

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_pipeline_model_id(self):
+        from modelscope import pipeline
+        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
+        qa = pipeline('question-answering', model=model_id)
+        assert qa(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_pipeline_auto_model(self):
+        from modelscope import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
+        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
+        model = AutoModelForQuestionAnswering.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        qa = pipeline('question-answering', model=model, tokenizer=tokenizer)
+        assert qa(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_pipeline_save_pretrained(self):
+        from modelscope import pipeline
+        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
+        pipe_ori = pipeline('question-answering', model=model_id)
+        result_ori = pipe_ori(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+        # save_pretrained
+        repo_id = self.create_model_name
+        save_dir = './tmp_test_hf_pipeline'
+        try:
+            os.system(f'rm -rf {save_dir}')
+            self.api.delete_model(repo_id)
+            # wait for the repo deletion to propagate
+            import time
+            time.sleep(5)
+        except Exception:
+            # the repo may not exist yet
+            pass
+        pipe_ori.save_pretrained(save_dir, push_to_hub=True, repo_id=repo_id)
+
+        # load from the pushed repo
+        pipe_new = pipeline('question-answering', model=repo_id)
+        result_new = pipe_new(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+        assert result_new == result_ori

 if __name__ == '__main__':
     unittest.main()