Add lsf_suffix arg for api.push_model (#545)

This commit is contained in:
aresnow1
2023-09-18 16:27:03 +08:00
committed by GitHub
parent 8e42de3ebc
commit 94b3a9eed7
2 changed files with 8 additions and 2 deletions

View File

@@ -243,7 +243,8 @@ class HubApi:
tag: Optional[str] = None,
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
original_model_id: Optional[str] = None,
ignore_file_pattern: Optional[Union[List[str], str]] = None):
ignore_file_pattern: Optional[Union[List[str], str]] = None,
lfs_suffix: Optional[Union[str, List[str]]] = None):
"""Upload model from a given directory to given repository. A valid model directory
must contain a configuration.json file.
@@ -281,6 +282,7 @@ class HubApi:
branch and push to it.
original_model_id (str, optional): The base model id which this model is trained from
ignore_file_pattern (`Union[List[str], str]`, optional): The file pattern to ignore uploading
lfs_suffix (`List[str]`, optional): File types to use LFS to manage. examples: '*.safetensors'.
Raises:
InvalidParameter: Parameter invalid.
@@ -349,6 +351,10 @@ class HubApi:
date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
commit_message = '[automsg] push model %s to hub at %s' % (
model_id, date)
if lfs_suffix is not None:
lfs_suffix_list = [lfs_suffix] if isinstance(lfs_suffix, str) else lfs_suffix
for suffix in lfs_suffix_list:
repo.add_lfs_type(suffix)
repo.push(
commit_message=commit_message,
local_branch=revision,

View File

@@ -105,7 +105,7 @@ class Repository:
examples '*.safetensors'
"""
os.system(
"printf '%s filter=lfs diff=lfs merge=lfs -text\n'>>%s" %
"printf '\n%s filter=lfs diff=lfs merge=lfs -text\n'>>%s" %
(file_name_suffix, os.path.join(self.model_dir, '.gitattributes')))
def push(self,