remove the class method rewrite of save_pretrained from patch & fix lint

suluyan
2025-02-11 14:56:55 +08:00
parent 021e912a38
commit 05ce8a7dc7
2 changed files with 25 additions and 56 deletions


@@ -143,37 +143,6 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                 **kwargs)
         return kwargs.pop('ori_func')(model_dir, **kwargs)

-    def save_pretrained(save_directory: Union[str, os.PathLike],
-                        safe_serialization: bool = True,
-                        **kwargs):
-        obj = kwargs.pop('obj')
-        push_to_hub = kwargs.pop('push_to_hub', False)
-        obj._save_pretrained_origin(
-            obj,
-            save_directory=save_directory,
-            safe_serialization=safe_serialization,
-            push_to_hub=False,
-            **kwargs)
-        # Class members may be unpatched, so push_to_hub is done separately here
-        if push_to_hub:
-            from modelscope.hub.push_to_hub import push_to_hub
-            from modelscope.hub.api import HubApi
-            api = HubApi()
-            token = kwargs.get('token')
-            commit_message = kwargs.pop('commit_message', None)
-            repo_name = kwargs.pop('repo_id',
-                                   save_directory.split(os.path.sep)[-1])
-            api.create_repo(repo_name, **kwargs)
-            push_to_hub(
-                repo_name=repo_name,
-                output_dir=save_directory,
-                commit_message=commit_message,
-                token=token)
-
     def get_wrapped_class(
             module_class: 'PreTrainedModel',
             ignore_file_pattern: Optional[Union[str, List[str]]] = None,
@@ -254,6 +223,7 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
            if push_to_hub:
                from modelscope.hub.push_to_hub import push_to_hub
                from modelscope.hub.api import HubApi
+               from modelscope.hub.repository import Repository
                token = kwargs.get('token')
                commit_message = kwargs.pop('commit_message', None)
@@ -264,6 +234,8 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                api = HubApi()
                api.login(token)
                api.create_repo(repo_name)
+               # clone the repo
+               Repository(save_directory, repo_name)

            super().save_pretrained(
                save_directory=save_directory,
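
Note: a minimal, self-contained sketch of the push flow these two hunks set up inside the wrapped class's save_pretrained. The function name save_and_push and its arguments are illustrative; only HubApi, login, create_repo, and Repository come from the diff above, and the commit/push step itself is outside the lines shown here.

    from modelscope.hub.api import HubApi
    from modelscope.hub.repository import Repository

    def save_and_push(model, save_directory, repo_name, token):
        # log in and make sure the remote repo exists (as in the hunk above)
        api = HubApi()
        api.login(token)
        api.create_repo(repo_name)
        # clone the repo into the local save directory, then save on top of it
        Repository(save_directory, repo_name)
        model.save_pretrained(save_directory=save_directory)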
@@ -367,15 +339,6 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                    ori_func=var._get_config_dict_origin,
                    **ignore_file_pattern_kwargs)
-           if has_save_pretrained and not hasattr(var,
-                                                  '_save_pretrained_origin'):
-               var._save_pretrained_origin = var.save_pretrained
-               var.save_pretrained = partial(
-                   save_pretrained,
-                   ori_func=var._save_pretrained_origin,
-                   obj=var,
-                   **ignore_file_pattern_kwargs)
-
            all_available_modules.append(var)
    return all_available_modules
@@ -389,7 +352,6 @@ def _unpatch_pretrained_class(all_imported_modules):
            has_from_pretrained = hasattr(var, 'from_pretrained')
            has_get_peft_type = hasattr(var, '_get_peft_type')
            has_get_config_dict = hasattr(var, 'get_config_dict')
-           has_save_pretrained = hasattr(var, 'save_pretrained')
        except ImportError:
            continue
        if has_from_pretrained and hasattr(var, '_from_pretrained_origin'):
@@ -401,9 +363,6 @@ def _unpatch_pretrained_class(all_imported_modules):
        if has_get_config_dict and hasattr(var, '_get_config_dict_origin'):
            var.get_config_dict = var._get_config_dict_origin
            delattr(var, '_get_config_dict_origin')
-       if has_save_pretrained and hasattr(var, '_save_pretrained_origin'):
-           var.save_pretrained = var._save_pretrained_origin
-           delattr(var, '_save_pretrained_origin')


def _patch_hub():
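
Note: a rough sketch of the patching idiom this commit keeps for the remaining methods, reconstructed from the hunks above. patched_from_pretrained and resolve_model_dir are illustrative stand-ins (the real resolver is not shown in this diff); the _*_origin attributes, the functools.partial redirection, and the delattr-based unpatching are taken from the diff.

    from functools import partial

    def resolve_model_dir(model_id, **kwargs):
        # hypothetical helper: locate/download the ModelScope snapshot for model_id
        return model_id

    def patched_from_pretrained(pretrained_model_name_or_path, **kwargs):
        model_dir = resolve_model_dir(pretrained_model_name_or_path)
        # delegate to the original classmethod stashed by the patcher
        return kwargs.pop('ori_func')(model_dir, **kwargs)

    def patch_class(var):
        if hasattr(var, 'from_pretrained') and not hasattr(var, '_from_pretrained_origin'):
            var._from_pretrained_origin = var.from_pretrained
            var.from_pretrained = partial(
                patched_from_pretrained, ori_func=var._from_pretrained_origin)

    def unpatch_class(var):
        if hasattr(var, '_from_pretrained_origin'):
            var.from_pretrained = var._from_pretrained_origin
            delattr(var, '_from_pretrained_origin')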


@@ -41,11 +41,12 @@ class HFUtilTest(unittest.TestCase):
            f.write('{}')
        self.pipeline_qa_context = r"""
-Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a
-question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would like to fine-tune
-a model on a SQuAD task, you may leverage the examples/pytorch/question-answering/run_squad.py script.
+Extractive Question Answering is the task of extracting an answer from a text given a question. An example
+of a question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would
+like to fine-tune a model on a SQuAD task, you may leverage the
+examples/pytorch/question-answering/run_squad.py script.
        """
-        self.pipeline_qa_question = "What is a good example of a question answering dataset?"
+        self.pipeline_qa_question = 'What is a good example of a question answering dataset?'

    def tearDown(self):
        logger.info('TearDown')
@@ -246,8 +247,10 @@ class HFUtilTest(unittest.TestCase):
    def test_pipeline_model_id(self):
        from modelscope import pipeline
        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
-       qa = pipeline("question-answering", model=model_id)
-       assert qa(question=self.pipeline_qa_question, context=self.pipeline_qa_context)
+       qa = pipeline('question-answering', model=model_id)
+       assert qa(
+           question=self.pipeline_qa_question,
+           context=self.pipeline_qa_context)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_pipeline_auto_model(self):
@@ -255,17 +258,21 @@ class HFUtilTest(unittest.TestCase):
        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
        model = AutoModelForQuestionAnswering.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
-       qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
-       assert qa(question=self.pipeline_qa_question, context=self.pipeline_qa_context)
+       qa = pipeline('question-answering', model=model, tokenizer=tokenizer)
+       assert qa(
+           question=self.pipeline_qa_question,
+           context=self.pipeline_qa_context)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_pipeline_save_pretrained(self):
        from modelscope import pipeline
        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
-       pipe_ori = pipeline("question-answering", model=model_id)
-       result_ori = pipe_ori(question=self.pipeline_qa_question, context=self.pipeline_qa_context)
+       pipe_ori = pipeline('question-answering', model=model_id)
+       result_ori = pipe_ori(
+           question=self.pipeline_qa_question,
+           context=self.pipeline_qa_context)

        # save_pretrained
        repo_id = 'damotestx/tst_push5'
@@ -282,10 +289,13 @@ class HFUtilTest(unittest.TestCase):
        pipe_ori.save_pretrained(save_dir, push_to_hub=True, repo_id=repo_id)

        # load from saved
-       pipe_new = pipeline("question-answering", model=repo_id)
-       result_new = pipe_new(question=self.pipeline_qa_question, context=self.pipeline_qa_context)
+       pipe_new = pipeline('question-answering', model=repo_id)
+       result_new = pipe_new(
+           question=self.pipeline_qa_question,
+           context=self.pipeline_qa_context)
        assert result_new == result_ori


if __name__ == '__main__':
    unittest.main()
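
Note: a hedged usage sketch of the behaviour test_pipeline_save_pretrained exercises. The model id and repo id are the test fixtures above; './qa_output' is an illustrative path, and the push only succeeds with a valid ModelScope login.

    from modelscope import pipeline

    qa = pipeline(
        'question-answering',
        model='damotestx/distilbert-base-cased-distilled-squad')
    # save locally and push the result to the hub in one call
    qa.save_pretrained(
        './qa_output', push_to_hub=True, repo_id='damotestx/tst_push5')
    # reload from the pushed repo and use it the same way
    qa_new = pipeline('question-answering', model='damotestx/tst_push5')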