From 57044b9c88a38b6216687084e26f78a7017e1efb Mon Sep 17 00:00:00 2001
From: suluyana <110878454+suluyana@users.noreply.github.com>
Date: Fri, 21 Feb 2025 15:49:39 +0800
Subject: [PATCH] feat: compatible with hf_pipeline (#1221)

compatible with hf_pipeline
---
 modelscope/__init__.py                       |  2 +-
 modelscope/hub/check_model.py                |  4 ++
 modelscope/hub/git.py                        |  1 +
 modelscope/pipelines/builder.py              | 40 +++++++++++--
 modelscope/utils/hf_util/__init__.py         |  1 +
 modelscope/utils/hf_util/patcher.py          | 52 ++++++++++++++++-
 modelscope/utils/hf_util/pipeline_builder.py | 54 +++++++++++++++++
 tests/utils/test_hf_util.py                  | 61 ++++++++++++++++++++
 8 files changed, 205 insertions(+), 10 deletions(-)
 create mode 100644 modelscope/utils/hf_util/pipeline_builder.py

diff --git a/modelscope/__init__.py b/modelscope/__init__.py
index a1fbf444..2579ca71 100644
--- a/modelscope/__init__.py
+++ b/modelscope/__init__.py
@@ -56,7 +56,7 @@ if TYPE_CHECKING:
         AutoModelForPreTraining, AutoModelForTextEncoding, AutoImageProcessor,
         BatchFeature, Qwen2VLForConditionalGeneration, T5EncoderModel,
         Qwen2_5_VLForConditionalGeneration, LlamaModel,
-        LlamaPreTrainedModel, LlamaForCausalLM)
+        LlamaPreTrainedModel, LlamaForCausalLM, hf_pipeline)
 else:
     print(
         'transformer is not installed, please install it if you want to use related modules'
diff --git a/modelscope/hub/check_model.py b/modelscope/hub/check_model.py
index e41a0a17..6d39c275 100644
--- a/modelscope/hub/check_model.py
+++ b/modelscope/hub/check_model.py
@@ -71,6 +71,10 @@ def check_local_model_is_latest(
             headers=snapshot_header,
             use_cookies=cookies,
         )
+        model_cache = None
+        # download via non-git method
+        if not os.path.exists(os.path.join(model_root_path, '.git')):
+            model_cache = ModelFileSystemCache(model_root_path)
         for model_file in model_files:
             if model_file['Type'] == 'tree':
                 continue
diff --git a/modelscope/hub/git.py b/modelscope/hub/git.py
index 144d9d69..d03ca773 100644
--- a/modelscope/hub/git.py
+++ b/modelscope/hub/git.py
@@ -46,6 +46,7 @@ class GitCommandWrapper(metaclass=Singleton):
         git_env = os.environ.copy()
         git_env['GIT_TERMINAL_PROMPT'] = '0'
         command = [self.git_path, *args]
+        command = [item for item in command if item]
         response = subprocess.run(
             command,
             stdout=subprocess.PIPE,
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 596d6d22..7e5cc6b5 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -10,6 +10,8 @@ from modelscope.utils.config import ConfigDict, check_config
 from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, Tasks,
                                        ThirdParty)
 from modelscope.utils.hub import read_config
+from modelscope.utils.import_utils import is_transformers_available
+from modelscope.utils.logger import get_logger
 from modelscope.utils.plugins import (register_modelhub_repo,
                                       register_plugins_repo)
 from modelscope.utils.registry import Registry, build_from_cfg
@@ -17,6 +19,7 @@ from .base import Pipeline
 from .util import is_official_hub_path
 
 PIPELINES = Registry('pipelines')
+logger = get_logger()
 
 
 def normalize_model_input(model,
@@ -72,7 +75,7 @@ def pipeline(task: str = None,
              config_file: str = None,
              pipeline_name: str = None,
              framework: str = None,
-             device: str = 'gpu',
+             device: str = None,
              model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
              ignore_file_pattern: List[str] = None,
              **kwargs) -> Pipeline:
@@ -109,6 +112,7 @@ def pipeline(task: str = None,
     if task is None and pipeline_name is None:
         raise ValueError('task or pipeline_name is required')
 
+    pipeline_props = None
     if pipeline_name is None:
         # get default pipeline for this task
         if isinstance(model, str) \
@@ -157,8 +161,11 @@ def pipeline(task: str = None,
         if pipeline_name:
             pipeline_props = {'type': pipeline_name}
         else:
-            check_config(cfg)
-            pipeline_props = cfg.pipeline
+            try:
+                check_config(cfg)
+                pipeline_props = cfg.pipeline
+            except AssertionError as e:
+                logger.info(str(e))
 
     elif model is not None:
         # get pipeline info from Model object
@@ -166,9 +173,13 @@ def pipeline(task: str = None,
         if not hasattr(first_model, 'pipeline'):
             # model is instantiated by user, we should parse config again
             cfg = read_config(first_model.model_dir)
-            check_config(cfg)
-            first_model.pipeline = cfg.pipeline
-        pipeline_props = first_model.pipeline
+            try:
+                check_config(cfg)
+                first_model.pipeline = cfg.pipeline
+            except AssertionError as e:
+                logger.info(str(e))
+        if first_model.__dict__.get('pipeline'):
+            pipeline_props = first_model.pipeline
     else:
         pipeline_name, default_model_repo = get_default_pipeline_info(task)
         model = normalize_model_input(default_model_repo, model_revision)
@@ -176,6 +187,23 @@ def pipeline(task: str = None,
     else:
         pipeline_props = {'type': pipeline_name}
 
+    if not pipeline_props and is_transformers_available():
+        try:
+            from modelscope.utils.hf_util import hf_pipeline
+            return hf_pipeline(
+                task=task,
+                model=model,
+                framework=framework,
+                device=device,
+                **kwargs)
+        except Exception as e:
+            logger.error(
+                'We couldn\'t find a suitable pipeline from ms, so we tried to load it using the transformers pipeline,'
+                ' but that also failed.')
+            raise e
+
+    if not device:
+        device = 'gpu'
     pipeline_props['model'] = model
     pipeline_props['device'] = device
     cfg = ConfigDict(pipeline_props)
diff --git a/modelscope/utils/hf_util/__init__.py b/modelscope/utils/hf_util/__init__.py
index a138ff7a..ac8349c9 100644
--- a/modelscope/utils/hf_util/__init__.py
+++ b/modelscope/utils/hf_util/__init__.py
@@ -1,2 +1,3 @@
 from .auto_class import *
 from .patcher import patch_context, patch_hub, unpatch_hub
+from .pipeline_builder import hf_pipeline
diff --git a/modelscope/utils/hf_util/patcher.py b/modelscope/utils/hf_util/patcher.py
index 28f8eeb5..6a41a5ce 100644
--- a/modelscope/utils/hf_util/patcher.py
+++ b/modelscope/utils/hf_util/patcher.py
@@ -27,7 +27,8 @@ def get_all_imported_modules():
     transformers_include_names = [
         'Auto.*', 'T5.*', 'BitsAndBytesConfig', 'GenerationConfig', 'Awq.*',
         'GPTQ.*', 'BatchFeature', 'Qwen.*', 'Llama.*', 'PretrainedConfig',
-        'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast'
+        'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast',
+        'Pipeline'
     ]
     peft_include_names = ['.*PeftModel.*', '.*Config']
     diffusers_include_names = ['^(?!TF|Flax).*Pipeline$']
@@ -252,6 +253,44 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                     model_dir, *model_args, **kwargs)
                 return module_obj
 
+            def save_pretrained(
+                self,
+                save_directory: Union[str, os.PathLike],
+                safe_serialization: bool = True,
+                **kwargs,
+            ):
+                push_to_hub = kwargs.pop('push_to_hub', False)
+                if push_to_hub:
+                    from modelscope.hub.push_to_hub import push_to_hub
+                    from modelscope.hub.api import HubApi
+                    from modelscope.hub.repository import Repository
+
+                    token = kwargs.get('token')
+                    commit_message = kwargs.pop('commit_message', None)
+                    repo_name = kwargs.pop(
+                        'repo_id',
+                        save_directory.split(os.path.sep)[-1])
+
+                    api = HubApi()
+                    api.login(token)
+                    api.create_repo(repo_name)
+                    # clone the repo
+                    Repository(save_directory, repo_name)
+
+                super().save_pretrained(
+                    save_directory=save_directory,
+                    safe_serialization=safe_serialization,
+                    push_to_hub=False,
+                    **kwargs)
+
+                # Class members may be unpatched, so push_to_hub is done separately here
+                if push_to_hub:
+                    push_to_hub(
+                        repo_name=repo_name,
+                        output_dir=save_directory,
+                        commit_message=commit_message,
+                        token=token)
+
         if not hasattr(module_class, 'from_pretrained'):
             del ClassWrapper.from_pretrained
         else:
@@ -266,6 +305,9 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
         if not hasattr(module_class, 'get_config_dict'):
             del ClassWrapper.get_config_dict
 
+        if not hasattr(module_class, 'save_pretrained'):
+            del ClassWrapper.save_pretrained
+
         ClassWrapper.__name__ = module_class.__name__
         ClassWrapper.__qualname__ = module_class.__qualname__
         return ClassWrapper
@@ -289,12 +331,16 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
             has_from_pretrained = hasattr(var, 'from_pretrained')
             has_get_peft_type = hasattr(var, '_get_peft_type')
             has_get_config_dict = hasattr(var, 'get_config_dict')
+            has_save_pretrained = hasattr(var, 'save_pretrained')
         except:  # noqa
             continue
 
-        if wrap:
+        # save_pretrained is not a classmethod and cannot be overridden by replacing
+        # the class method. It requires replacing the class object method.
+        if wrap or ('pipeline' in name.lower() and has_save_pretrained):
             try:
-                if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type:
+                if (not has_from_pretrained and not has_get_config_dict
+                        and not has_get_peft_type and not has_save_pretrained):
                     all_available_modules.append(var)
                 else:
                     all_available_modules.append(
diff --git a/modelscope/utils/hf_util/pipeline_builder.py b/modelscope/utils/hf_util/pipeline_builder.py
new file mode 100644
index 00000000..5386bead
--- /dev/null
+++ b/modelscope/utils/hf_util/pipeline_builder.py
@@ -0,0 +1,54 @@
+import os
+from typing import Optional, Union
+
+import torch
+from transformers import Pipeline as PipelineHF
+from transformers import PreTrainedModel, TFPreTrainedModel, pipeline
+from transformers.pipelines import check_task, get_task
+
+from modelscope.hub import snapshot_download
+from modelscope.utils.hf_util.patcher import _patch_pretrained_class, patch_hub
+
+
+def _get_hf_device(device):
+    if isinstance(device, str):
+        device_name = device.lower()
+        eles = device_name.split(':')
+        if eles[0] == 'gpu':
+            eles = ['cuda'] + eles[1:]
+        device = ''.join(eles)
+    return device
+
+
+def _get_hf_pipeline_class(task, model):
+    if not task:
+        task = get_task(model)
+    normalized_task, targeted_task, task_options = check_task(task)
+    pipeline_class = targeted_task['impl']
+    pipeline_class = _patch_pretrained_class([pipeline_class])[0]
+    return pipeline_class
+
+
+def hf_pipeline(
+    task: str = None,
+    model: Optional[Union[str, 'PreTrainedModel', 'TFPreTrainedModel']] = None,
+    framework: Optional[str] = None,
+    device: Optional[Union[int, str, 'torch.device']] = None,
+    **kwargs,
+) -> PipelineHF:
+    if isinstance(model, str):
+        if not os.path.exists(model):
+            model = snapshot_download(model)
+
+    framework = 'pt' if framework == 'pytorch' else framework
+
+    device = _get_hf_device(device)
+    pipeline_class = _get_hf_pipeline_class(task, model)
+
+    return pipeline(
+        task=task,
+        model=model,
+        framework=framework,
+        device=device,
+        pipeline_class=pipeline_class,
+        **kwargs)
diff --git a/tests/utils/test_hf_util.py b/tests/utils/test_hf_util.py
index 9826d991..058e92c7 100644
--- a/tests/utils/test_hf_util.py
+++ b/tests/utils/test_hf_util.py
@@ -40,6 +40,14 @@ class HFUtilTest(unittest.TestCase):
         with open(self.test_file2, 'w') as f:
             f.write('{}')
 
+        self.pipeline_qa_context = r"""
+        Extractive Question Answering is the task of extracting an answer from a text given a question. An example
+        of a question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would
+        like to fine-tune a model on a SQuAD task, you may leverage the
+        examples/pytorch/question-answering/run_squad.py script.
+        """
+        self.pipeline_qa_question = 'What is a good example of a question answering dataset?'
+
     def tearDown(self):
         logger.info('TearDown')
         shutil.rmtree(self.model_dir, ignore_errors=True)
@@ -235,6 +243,59 @@ class HFUtilTest(unittest.TestCase):
             'Qwen/Qwen1.5-0.5B-Chat', trust_remote_code=True)
         model.push_to_hub(self.create_model_name)
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_pipeline_model_id(self):
+        from modelscope import pipeline
+        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
+        qa = pipeline('question-answering', model=model_id)
+        assert qa(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_pipeline_auto_model(self):
+        from modelscope import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
+        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
+        model = AutoModelForQuestionAnswering.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        qa = pipeline('question-answering', model=model, tokenizer=tokenizer)
+        assert qa(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_pipeline_save_pretrained(self):
+        from modelscope import pipeline
+        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
+
+        pipe_ori = pipeline('question-answering', model=model_id)
+
+        result_ori = pipe_ori(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+        # save_pretrained
+        repo_id = self.create_model_name
+        save_dir = './tmp_test_hf_pipeline'
+        try:
+            os.system(f'rm -rf {save_dir}')
+            self.api.delete_model(repo_id)
+            # wait for delete repo
+            import time
+            time.sleep(5)
+        except Exception:
+            # if repo not exists
+            pass
+        pipe_ori.save_pretrained(save_dir, push_to_hub=True, repo_id=repo_id)
+
+        # load from saved
+        pipe_new = pipeline('question-answering', model=repo_id)
+        result_new = pipe_new(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+        assert result_new == result_ori
+
 
 if __name__ == '__main__':
     unittest.main()
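
Usage sketch of the fallback path this patch introduces (assumes transformers is installed; the model ID and save directory mirror the new tests, and the target repo id below is hypothetical): when pipeline() cannot resolve a ModelScope pipeline config for a repo, the builder now defers to hf_pipeline(), which resolves the model via snapshot_download, patches the transformers pipeline class, and builds a plain transformers pipeline; the patched save_pretrained() can then push the result back to the ModelScope hub.

    # Sketch only: a configured ModelScope hub token is assumed for the push step.
    from modelscope import pipeline

    # No ModelScope pipeline config exists for this repo, so builder.pipeline()
    # falls back to hf_pipeline(); device='gpu' is mapped to a CUDA device string
    # by _get_hf_device().
    qa = pipeline(
        'question-answering',
        model='damotestx/distilbert-base-cased-distilled-squad',
        device='gpu')

    print(qa(
        question='What is a good example of a question answering dataset?',
        context='The SQuAD dataset is entirely based on extractive question answering.'))

    # The patched save_pretrained() clones the target repo into the save directory,
    # writes the pipeline files, and pushes them when push_to_hub=True.
    qa.save_pretrained(
        './tmp_test_hf_pipeline',
        push_to_hub=True,
        repo_id='your-namespace/distilbert-qa-copy')  # hypothetical repo id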