remove the class method rewrite of save_pretrained from patch & fix lint
@@ -143,37 +143,6 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                 **kwargs)
         return kwargs.pop('ori_func')(model_dir, **kwargs)

-    def save_pretrained(save_directory: Union[str, os.PathLike],
-                        safe_serialization: bool = True,
-                        **kwargs):
-        obj = kwargs.pop('obj')
-        push_to_hub = kwargs.pop('push_to_hub', False)
-
-        obj._save_pretrained_origin(
-            obj,
-            save_directory=save_directory,
-            safe_serialization=safe_serialization,
-            push_to_hub=False,
-            **kwargs)
-
-        # Class members may be unpatched, so push_to_hub is done separately here
-        if push_to_hub:
-            from modelscope.hub.push_to_hub import push_to_hub
-            from modelscope.hub.api import HubApi
-            api = HubApi()
-
-            token = kwargs.get('token')
-            commit_message = kwargs.pop('commit_message', None)
-            repo_name = kwargs.pop('repo_id',
-                                   save_directory.split(os.path.sep)[-1])
-            api.create_repo(repo_name, **kwargs)
-
-            push_to_hub(
-                repo_name=repo_name,
-                output_dir=save_directory,
-                commit_message=commit_message,
-                token=token)
-
     def get_wrapped_class(
             module_class: 'PreTrainedModel',
             ignore_file_pattern: Optional[Union[str, List[str]]] = None,
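The block removed by this hunk implemented a two-step flow: call the stashed original with push_to_hub=False, then create the repo and push the saved directory separately. Here is a minimal, self-contained sketch of that save-then-push pattern; create_repo_stub and upload_folder_stub are hypothetical stand-ins (assumptions, not modelscope APIs) for HubApi().create_repo and modelscope.hub.push_to_hub.push_to_hub:

import os


def create_repo_stub(repo_name):
    # Hypothetical stand-in for modelscope's HubApi().create_repo().
    print(f'create repo: {repo_name}')


def upload_folder_stub(repo_name, output_dir):
    # Hypothetical stand-in for modelscope.hub.push_to_hub.push_to_hub().
    print(f'push {output_dir} -> {repo_name}')


def save_then_push(obj, save_directory, push_to_hub=False, **kwargs):
    # Step 1: save locally via the stashed original, with pushing disabled
    # so the original implementation never touches the hub.
    obj._save_pretrained_origin(
        obj, save_directory=save_directory, push_to_hub=False, **kwargs)
    # Step 2: push separately; the repo name defaults to the last path
    # component of save_directory, mirroring the removed block.
    if push_to_hub:
        repo_name = kwargs.pop('repo_id',
                               save_directory.split(os.path.sep)[-1])
        create_repo_stub(repo_name)
        upload_folder_stub(repo_name, save_directory)


class _Demo:

    def _save_pretrained_origin(self, save_directory, **kwargs):
        # Plays the role of the stashed original save_pretrained.
        print(f'save to {save_directory}')


# The patcher operated on classes, passing the class itself as obj.
save_then_push(_Demo, '/tmp/demo_repo', push_to_hub=True, repo_id='demo')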
@@ -254,6 +223,7 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
             if push_to_hub:
                 from modelscope.hub.push_to_hub import push_to_hub
                 from modelscope.hub.api import HubApi
+                from modelscope.hub.repository import Repository

                 token = kwargs.get('token')
                 commit_message = kwargs.pop('commit_message', None)
@@ -264,6 +234,8 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                 api = HubApi()
                 api.login(token)
                 api.create_repo(repo_name)
+                # clone the repo
+                Repository(save_directory, repo_name)

             super().save_pretrained(
                 save_directory=save_directory,
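These two hunks keep pushing inside the wrapped subclass instead: the repo is created, the token is logged in, and the repo is cloned into save_directory before super().save_pretrained() runs, so the files are written straight into a working copy of the hub repo. A minimal sketch of that clone-before-save ordering, assuming a hypothetical clone_repo helper in place of modelscope.hub.repository.Repository:

import os


def clone_repo(save_directory: str, repo_name: str) -> None:
    # Hypothetical stand-in for Repository(save_directory, repo_name),
    # which clones the hub repo into save_directory.
    os.makedirs(save_directory, exist_ok=True)
    print(f'clone {repo_name} -> {save_directory}')


def save_wrapped(model, save_directory: str,
                 push_to_hub: bool = False, repo_name: str = '') -> None:
    if push_to_hub:
        # Clone first, so the save below writes into a git working copy
        # that can afterwards be committed and pushed.
        clone_repo(save_directory, repo_name)
    model.save_pretrained(save_directory)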
@@ -367,15 +339,6 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                     ori_func=var._get_config_dict_origin,
                     **ignore_file_pattern_kwargs)

-            if has_save_pretrained and not hasattr(var,
-                                                   '_save_pretrained_origin'):
-                var._save_pretrained_origin = var.save_pretrained
-                var.save_pretrained = partial(
-                    save_pretrained,
-                    ori_func=var._save_pretrained_origin,
-                    obj=var,
-                    **ignore_file_pattern_kwargs)
-
             all_available_modules.append(var)
     return all_available_modules

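The registration removed here used the same idiom the patcher applies to from_pretrained and get_config_dict: stash the original under a _*_origin attribute, then rebind the name to a functools.partial that carries the original (ori_func) and the class itself (obj). A generic sketch of the idiom (Widget and patched_describe are illustrative names, not modelscope code):

from functools import partial


class Widget:

    def describe(self, suffix=''):
        return f'Widget{suffix}'


def patched_describe(*args, ori_func, obj, **kwargs):
    # Extra behaviour first, then delegate to the stashed original.
    print(f'patched call on {obj.__name__}')
    return ori_func(*args, **kwargs)


if not hasattr(Widget, '_describe_origin'):
    Widget._describe_origin = Widget.describe
    Widget.describe = partial(
        patched_describe, ori_func=Widget._describe_origin, obj=Widget)

# Called through the class; a plain partial is not a descriptor, so it
# never binds to instances -- hence the explicit obj= in the patcher.
print(Widget.describe(Widget(), suffix='!'))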
@@ -389,7 +352,6 @@ def _unpatch_pretrained_class(all_imported_modules):
             has_from_pretrained = hasattr(var, 'from_pretrained')
             has_get_peft_type = hasattr(var, '_get_peft_type')
             has_get_config_dict = hasattr(var, 'get_config_dict')
-            has_save_pretrained = hasattr(var, 'save_pretrained')
         except ImportError:
             continue
         if has_from_pretrained and hasattr(var, '_from_pretrained_origin'):
@@ -401,9 +363,6 @@ def _unpatch_pretrained_class(all_imported_modules):
         if has_get_config_dict and hasattr(var, '_get_config_dict_origin'):
             var.get_config_dict = var._get_config_dict_origin
             delattr(var, '_get_config_dict_origin')
-        if has_save_pretrained and hasattr(var, '_save_pretrained_origin'):
-            var.save_pretrained = var._save_pretrained_origin
-            delattr(var, '_save_pretrained_origin')


 def _patch_hub():
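_unpatch_pretrained_class mirrors the patch step: each method is restored from its _*_origin stash and the marker attribute is deleted; after this commit, save_pretrained simply no longer takes part. A tiny sketch of the restore step (hypothetical names):

class Widget:

    def ping(self):
        return 'pong'


Widget._ping_origin = Widget.ping  # simulate a previously applied patch
Widget.ping = lambda self: 'patched'


def unpatch(cls, name):
    # Restore the original attribute when the patch marker is present,
    # then drop the marker so a later patch pass can re-stash cleanly.
    marker = f'_{name}_origin'
    if hasattr(cls, name) and hasattr(cls, marker):
        setattr(cls, name, getattr(cls, marker))
        delattr(cls, marker)


unpatch(Widget, 'ping')
assert Widget().ping() == 'pong'

The remaining hunks update the HFUtilTest test case for the new behaviour and apply the lint fixes from the commit title (single quotes, wrapped long calls).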
@@ -41,11 +41,12 @@ class HFUtilTest(unittest.TestCase):
             f.write('{}')

         self.pipeline_qa_context = r"""
-Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a
-question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would like to fine-tune
-a model on a SQuAD task, you may leverage the examples/pytorch/question-answering/run_squad.py script.
+Extractive Question Answering is the task of extracting an answer from a text given a question. An example
+of a question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would
+like to fine-tune a model on a SQuAD task, you may leverage the
+examples/pytorch/question-answering/run_squad.py script.
 """
-        self.pipeline_qa_question = "What is a good example of a question answering dataset?"
+        self.pipeline_qa_question = 'What is a good example of a question answering dataset?'

     def tearDown(self):
         logger.info('TearDown')
@@ -246,8 +247,10 @@ class HFUtilTest(unittest.TestCase):
     def test_pipeline_model_id(self):
         from modelscope import pipeline
         model_id = 'damotestx/distilbert-base-cased-distilled-squad'
-        qa = pipeline("question-answering", model=model_id)
-        assert qa(question=self.pipeline_qa_question, context=self.pipeline_qa_context)
+        qa = pipeline('question-answering', model=model_id)
+        assert qa(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_pipeline_auto_model(self):
@@ -255,17 +258,21 @@ class HFUtilTest(unittest.TestCase):
         model_id = 'damotestx/distilbert-base-cased-distilled-squad'
         model = AutoModelForQuestionAnswering.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
-        qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
-        assert qa(question=self.pipeline_qa_question, context=self.pipeline_qa_context)
+        qa = pipeline('question-answering', model=model, tokenizer=tokenizer)
+        assert qa(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_pipeline_save_pretrained(self):
         from modelscope import pipeline
         model_id = 'damotestx/distilbert-base-cased-distilled-squad'

-        pipe_ori = pipeline("question-answering", model=model_id)
+        pipe_ori = pipeline('question-answering', model=model_id)

-        result_ori = pipe_ori(question=self.pipeline_qa_question, context=self.pipeline_qa_context)
+        result_ori = pipe_ori(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)

         # save_pretrained
         repo_id = 'damotestx/tst_push5'
@@ -282,10 +289,13 @@ class HFUtilTest(unittest.TestCase):
         pipe_ori.save_pretrained(save_dir, push_to_hub=True, repo_id=repo_id)

         # load from saved
-        pipe_new = pipeline("question-answering", model=repo_id)
-        result_new = pipe_new(question=self.pipeline_qa_question, context=self.pipeline_qa_context)
+        pipe_new = pipeline('question-answering', model=repo_id)
+        result_new = pipe_new(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)

         assert result_new == result_ori


 if __name__ == '__main__':
     unittest.main()