mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
add command line tool
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11661608 * add command line tool * add unittest * change absolute import to relative import for test case * mv test_util to package
This commit is contained in:
0
modelscope/cli/__init__.py
Normal file
0
modelscope/cli/__init__.py
Normal file
20
modelscope/cli/base.py
Normal file
20
modelscope/cli/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from argparse import ArgumentParser
|
||||
|
||||
|
||||
class CLICommand(ABC):
|
||||
"""
|
||||
Base class for command line tool.
|
||||
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def define_args(parsers: ArgumentParser):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def execute(self):
|
||||
raise NotImplementedError()
|
||||
26
modelscope/cli/cli.py
Normal file
26
modelscope/cli/cli.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import argparse
|
||||
|
||||
from modelscope.cli.download import DownloadCMD
|
||||
|
||||
|
||||
def run_cmd():
|
||||
parser = argparse.ArgumentParser(
|
||||
'ModelScope Command Line tool', usage='modelscope <command> [<args>]')
|
||||
subparsers = parser.add_subparsers(help='modelscope commands helpers')
|
||||
|
||||
DownloadCMD.define_args(subparsers)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not hasattr(args, 'func'):
|
||||
parser.print_help()
|
||||
exit(1)
|
||||
|
||||
cmd = args.func(args)
|
||||
cmd.execute()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_cmd()
|
||||
44
modelscope/cli/download.py
Normal file
44
modelscope/cli/download.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
|
||||
def subparser_func(args):
|
||||
""" Fuction which will be called for a specific sub parser.
|
||||
"""
|
||||
return DownloadCMD(args)
|
||||
|
||||
|
||||
class DownloadCMD(CLICommand):
|
||||
name = 'download'
|
||||
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
@staticmethod
|
||||
def define_args(parsers: ArgumentParser):
|
||||
""" define args for download command.
|
||||
"""
|
||||
parser = parsers.add_parser(DownloadCMD.name)
|
||||
parser.add_argument(
|
||||
'model', type=str, help='Name of the model to be downloaded.')
|
||||
parser.add_argument(
|
||||
'--revision',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Revision of the model.')
|
||||
parser.add_argument(
|
||||
'--cache_dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Cache directory to save model.')
|
||||
parser.set_defaults(func=subparser_func)
|
||||
|
||||
def execute(self):
|
||||
snapshot_download(
|
||||
self.args.model,
|
||||
cache_dir=self.args.cache_dir,
|
||||
revision=self.args.revision)
|
||||
@@ -13,15 +13,29 @@ import tempfile
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Mapping
|
||||
from os.path import expanduser
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH
|
||||
from modelscope.utils.import_utils import is_tf_available, is_torch_available
|
||||
|
||||
TEST_LEVEL = 2
|
||||
TEST_LEVEL_STR = 'TEST_LEVEL'
|
||||
|
||||
# for user citest and sdkdev
|
||||
TEST_ACCESS_TOKEN1 = os.environ['TEST_ACCESS_TOKEN_CITEST']
|
||||
TEST_ACCESS_TOKEN2 = os.environ['TEST_ACCESS_TOKEN_SDKDEV']
|
||||
|
||||
TEST_MODEL_CHINESE_NAME = '内部测试模型'
|
||||
TEST_MODEL_ORG = 'citest'
|
||||
|
||||
|
||||
def delete_credential():
|
||||
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
|
||||
shutil.rmtree(path_credential, ignore_errors=True)
|
||||
|
||||
|
||||
def test_level():
|
||||
global TEST_LEVEL
|
||||
|
||||
3
setup.py
3
setup.py
@@ -223,5 +223,8 @@ if __name__ == '__main__':
|
||||
tests_require=parse_requirements('requirements/tests.txt'),
|
||||
install_requires=install_requires,
|
||||
extras_require=extra_requires,
|
||||
entry_points={
|
||||
'console_scripts': ['modelscope=modelscope.cli.cli:run_cmd']
|
||||
},
|
||||
dependency_links=deps_link,
|
||||
zip_safe=False)
|
||||
|
||||
0
tests/cli/__init__.py
Normal file
0
tests/cli/__init__.py
Normal file
79
tests/cli/test_download_cmd.py
Normal file
79
tests/cli/test_download_cmd.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.constants import Licenses, ModelVisibility
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
|
||||
TEST_MODEL_CHINESE_NAME,
|
||||
TEST_MODEL_ORG)
|
||||
|
||||
DEFAULT_GIT_PATH = 'git'
|
||||
download_model_file_name = 'test.bin'
|
||||
|
||||
|
||||
class DownloadCMDTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
self.tmp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(self.tmp_dir):
|
||||
os.makedirs(self.tmp_dir)
|
||||
|
||||
self.api = HubApi()
|
||||
self.api.login(TEST_ACCESS_TOKEN1)
|
||||
self.model_name = 'op-%s' % (uuid.uuid4().hex)
|
||||
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
|
||||
self.revision = 'v0.1_test_revision'
|
||||
self.api.create_model(
|
||||
model_id=self.model_id,
|
||||
visibility=ModelVisibility.PUBLIC,
|
||||
license=Licenses.APACHE_V2,
|
||||
chinese_name=TEST_MODEL_CHINESE_NAME,
|
||||
)
|
||||
self.prepare_case()
|
||||
|
||||
def prepare_case(self):
|
||||
temporary_dir = tempfile.mkdtemp()
|
||||
self.model_dir = os.path.join(temporary_dir, self.model_name)
|
||||
repo = Repository(self.model_dir, clone_from=self.model_id)
|
||||
os.system("echo 'testtest'>%s"
|
||||
% os.path.join(self.model_dir, download_model_file_name))
|
||||
repo.push('add model')
|
||||
repo.tag_and_push(self.revision, 'Test revision')
|
||||
|
||||
def tearDown(self):
|
||||
self.api.delete_model(model_id=self.model_id)
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
super().tearDown()
|
||||
|
||||
def test_download(self):
|
||||
cmd = f'python -m modelscope.cli.cli download {self.model_id}'
|
||||
stat, output = subprocess.getstatusoutput(cmd)
|
||||
self.assertEqual(stat, 0)
|
||||
|
||||
def test_download_with_cache(self):
|
||||
cmd = f'python -m modelscope.cli.cli download {self.model_id} --cache_dir {self.tmp_dir}'
|
||||
stat, output = subprocess.getstatusoutput(cmd)
|
||||
if stat != 0:
|
||||
print(output)
|
||||
self.assertEqual(stat, 0)
|
||||
self.assertTrue(
|
||||
osp.exists(
|
||||
f'{self.tmp_dir}/{self.model_id}/{download_model_file_name}'))
|
||||
|
||||
def test_download_with_revision(self):
|
||||
cmd = f'python -m modelscope.cli.cli download {self.model_id} --revision {self.revision}'
|
||||
stat, output = subprocess.getstatusoutput(cmd)
|
||||
if stat != 0:
|
||||
print(output)
|
||||
self.assertEqual(stat, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -13,8 +13,9 @@ from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
|
||||
TEST_MODEL_ORG)
|
||||
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
|
||||
TEST_MODEL_CHINESE_NAME,
|
||||
TEST_MODEL_ORG)
|
||||
|
||||
DEFAULT_GIT_PATH = 'git'
|
||||
|
||||
|
||||
@@ -13,9 +13,10 @@ from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2,
|
||||
TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG,
|
||||
delete_credential)
|
||||
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
|
||||
TEST_ACCESS_TOKEN2,
|
||||
TEST_MODEL_CHINESE_NAME,
|
||||
TEST_MODEL_ORG, delete_credential)
|
||||
|
||||
download_model_file_name = 'test.bin'
|
||||
|
||||
|
||||
@@ -9,9 +9,10 @@ from modelscope.hub.constants import Licenses, ModelVisibility
|
||||
from modelscope.hub.errors import GitError
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2,
|
||||
TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG,
|
||||
delete_credential)
|
||||
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
|
||||
TEST_ACCESS_TOKEN2,
|
||||
TEST_MODEL_CHINESE_NAME,
|
||||
TEST_MODEL_ORG, delete_credential)
|
||||
|
||||
DEFAULT_GIT_PATH = 'git'
|
||||
|
||||
|
||||
@@ -16,8 +16,9 @@ from modelscope.hub.git import GitCommandWrapper
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
|
||||
TEST_MODEL_ORG, delete_credential)
|
||||
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
|
||||
TEST_MODEL_CHINESE_NAME,
|
||||
TEST_MODEL_ORG, delete_credential)
|
||||
|
||||
logger = get_logger()
|
||||
logger.setLevel('DEBUG')
|
||||
|
||||
@@ -13,8 +13,9 @@ from modelscope.hub.repository import Repository
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
|
||||
TEST_MODEL_ORG)
|
||||
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
|
||||
TEST_MODEL_CHINESE_NAME,
|
||||
TEST_MODEL_ORG)
|
||||
|
||||
logger = get_logger()
|
||||
logger.setLevel('DEBUG')
|
||||
|
||||
@@ -16,8 +16,9 @@ from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
|
||||
TEST_MODEL_ORG)
|
||||
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
|
||||
TEST_MODEL_CHINESE_NAME,
|
||||
TEST_MODEL_ORG)
|
||||
|
||||
logger = get_logger()
|
||||
logger.setLevel('DEBUG')
|
||||
|
||||
@@ -11,8 +11,8 @@ from modelscope.hub.errors import GitError, HTTPError, NotLoginException
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.test_utils import test_level
|
||||
from .test_utils import TEST_ACCESS_TOKEN1, TEST_MODEL_ORG, delete_credential
|
||||
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_ORG,
|
||||
delete_credential, test_level)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from codecs import ignore_errors
|
||||
from os.path import expanduser
|
||||
|
||||
from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH
|
||||
|
||||
# for user citest and sdkdev
|
||||
TEST_ACCESS_TOKEN1 = os.environ['TEST_ACCESS_TOKEN_CITEST']
|
||||
TEST_ACCESS_TOKEN2 = os.environ['TEST_ACCESS_TOKEN_SDKDEV']
|
||||
|
||||
TEST_MODEL_CHINESE_NAME = '内部测试模型'
|
||||
TEST_MODEL_ORG = 'citest'
|
||||
|
||||
|
||||
def delete_credential():
|
||||
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
|
||||
shutil.rmtree(path_credential, ignore_errors=True)
|
||||
Reference in New Issue
Block a user