mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
remove the class method rewrite of save_pretrained from patch & fix lint
This commit is contained in:
@@ -143,37 +143,6 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
**kwargs)
|
||||
return kwargs.pop('ori_func')(model_dir, **kwargs)
|
||||
|
||||
def save_pretrained(save_directory: Union[str, os.PathLike],
|
||||
safe_serialization: bool = True,
|
||||
**kwargs):
|
||||
obj = kwargs.pop('obj')
|
||||
push_to_hub = kwargs.pop('push_to_hub', False)
|
||||
|
||||
obj._save_pretrained_origin(
|
||||
obj,
|
||||
save_directory=save_directory,
|
||||
safe_serialization=safe_serialization,
|
||||
push_to_hub=False,
|
||||
**kwargs)
|
||||
|
||||
# Class members may be unpatched, so push_to_hub is done separately here
|
||||
if push_to_hub:
|
||||
from modelscope.hub.push_to_hub import push_to_hub
|
||||
from modelscope.hub.api import HubApi
|
||||
api = HubApi()
|
||||
|
||||
token = kwargs.get('token')
|
||||
commit_message = kwargs.pop('commit_message', None)
|
||||
repo_name = kwargs.pop('repo_id',
|
||||
save_directory.split(os.path.sep)[-1])
|
||||
api.create_repo(repo_name, **kwargs)
|
||||
|
||||
push_to_hub(
|
||||
repo_name=repo_name,
|
||||
output_dir=save_directory,
|
||||
commit_message=commit_message,
|
||||
token=token)
|
||||
|
||||
def get_wrapped_class(
|
||||
module_class: 'PreTrainedModel',
|
||||
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
@@ -254,6 +223,7 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
if push_to_hub:
|
||||
from modelscope.hub.push_to_hub import push_to_hub
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.repository import Repository
|
||||
|
||||
token = kwargs.get('token')
|
||||
commit_message = kwargs.pop('commit_message', None)
|
||||
@@ -264,6 +234,8 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
api = HubApi()
|
||||
api.login(token)
|
||||
api.create_repo(repo_name)
|
||||
# clone the repo
|
||||
Repository(save_directory, repo_name)
|
||||
|
||||
super().save_pretrained(
|
||||
save_directory=save_directory,
|
||||
@@ -367,15 +339,6 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
ori_func=var._get_config_dict_origin,
|
||||
**ignore_file_pattern_kwargs)
|
||||
|
||||
if has_save_pretrained and not hasattr(var,
|
||||
'_save_pretrained_origin'):
|
||||
var._save_pretrained_origin = var.save_pretrained
|
||||
var.save_pretrained = partial(
|
||||
save_pretrained,
|
||||
ori_func=var._save_pretrained_origin,
|
||||
obj=var,
|
||||
**ignore_file_pattern_kwargs)
|
||||
|
||||
all_available_modules.append(var)
|
||||
return all_available_modules
|
||||
|
||||
@@ -389,7 +352,6 @@ def _unpatch_pretrained_class(all_imported_modules):
|
||||
has_from_pretrained = hasattr(var, 'from_pretrained')
|
||||
has_get_peft_type = hasattr(var, '_get_peft_type')
|
||||
has_get_config_dict = hasattr(var, 'get_config_dict')
|
||||
has_save_pretrained = hasattr(var, 'save_pretrained')
|
||||
except ImportError:
|
||||
continue
|
||||
if has_from_pretrained and hasattr(var, '_from_pretrained_origin'):
|
||||
@@ -401,9 +363,6 @@ def _unpatch_pretrained_class(all_imported_modules):
|
||||
if has_get_config_dict and hasattr(var, '_get_config_dict_origin'):
|
||||
var.get_config_dict = var._get_config_dict_origin
|
||||
delattr(var, '_get_config_dict_origin')
|
||||
if has_save_pretrained and hasattr(var, '_save_pretrained_origin'):
|
||||
var.save_pretrained = var._save_pretrained_origin
|
||||
delattr(var, '_save_pretrained_origin')
|
||||
|
||||
|
||||
def _patch_hub():
|
||||
|
||||
@@ -41,11 +41,12 @@ class HFUtilTest(unittest.TestCase):
|
||||
f.write('{}')
|
||||
|
||||
self.pipeline_qa_context = r"""
|
||||
Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a
|
||||
question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would like to fine-tune
|
||||
a model on a SQuAD task, you may leverage the examples/pytorch/question-answering/run_squad.py script.
|
||||
Extractive Question Answering is the task of extracting an answer from a text given a question. An example
|
||||
of a question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would
|
||||
like to fine-tune a model on a SQuAD task, you may leverage the
|
||||
examples/pytorch/question-answering/run_squad.py script.
|
||||
"""
|
||||
self.pipeline_qa_question = "What is a good example of a question answering dataset?"
|
||||
self.pipeline_qa_question = 'What is a good example of a question answering dataset?'
|
||||
|
||||
def tearDown(self):
|
||||
logger.info('TearDown')
|
||||
@@ -246,8 +247,10 @@ class HFUtilTest(unittest.TestCase):
|
||||
def test_pipeline_model_id(self):
|
||||
from modelscope import pipeline
|
||||
model_id = 'damotestx/distilbert-base-cased-distilled-squad'
|
||||
qa = pipeline("question-answering", model=model_id)
|
||||
assert qa(question=self.pipeline_qa_question, context=self.pipeline_qa_context)
|
||||
qa = pipeline('question-answering', model=model_id)
|
||||
assert qa(
|
||||
question=self.pipeline_qa_question,
|
||||
context=self.pipeline_qa_context)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_pipeline_auto_model(self):
|
||||
@@ -255,17 +258,21 @@ class HFUtilTest(unittest.TestCase):
|
||||
model_id = 'damotestx/distilbert-base-cased-distilled-squad'
|
||||
model = AutoModelForQuestionAnswering.from_pretrained(model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
|
||||
assert qa(question=self.pipeline_qa_question, context=self.pipeline_qa_context)
|
||||
qa = pipeline('question-answering', model=model, tokenizer=tokenizer)
|
||||
assert qa(
|
||||
question=self.pipeline_qa_question,
|
||||
context=self.pipeline_qa_context)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_pipeline_save_pretrained(self):
|
||||
from modelscope import pipeline
|
||||
model_id = 'damotestx/distilbert-base-cased-distilled-squad'
|
||||
|
||||
pipe_ori = pipeline("question-answering", model=model_id)
|
||||
pipe_ori = pipeline('question-answering', model=model_id)
|
||||
|
||||
result_ori = pipe_ori(question=self.pipeline_qa_question, context=self.pipeline_qa_context)
|
||||
result_ori = pipe_ori(
|
||||
question=self.pipeline_qa_question,
|
||||
context=self.pipeline_qa_context)
|
||||
|
||||
# save_pretrained
|
||||
repo_id = 'damotestx/tst_push5'
|
||||
@@ -282,10 +289,13 @@ class HFUtilTest(unittest.TestCase):
|
||||
pipe_ori.save_pretrained(save_dir, push_to_hub=True, repo_id=repo_id)
|
||||
|
||||
# load from saved
|
||||
pipe_new = pipeline("question-answering", model=repo_id)
|
||||
result_new = pipe_new(question=self.pipeline_qa_question, context=self.pipeline_qa_context)
|
||||
pipe_new = pipeline('question-answering', model=repo_id)
|
||||
result_new = pipe_new(
|
||||
question=self.pipeline_qa_question,
|
||||
context=self.pipeline_qa_context)
|
||||
|
||||
assert result_new == result_ori
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user