add command line tool

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11661608

* add command line tool

* add unittest

* change absolute import to relative import for test case

* mv test_util to package
This commit is contained in:
wenmeng.zwm
2023-02-20 21:21:54 +08:00
parent 35d43e9945
commit 285208912b
16 changed files with 208 additions and 36 deletions

View File

20
modelscope/cli/base.py Normal file
View File

@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod
from argparse import ArgumentParser
class CLICommand(ABC):
"""
Base class for command line tool.
"""
@staticmethod
@abstractmethod
def define_args(parsers: ArgumentParser):
raise NotImplementedError()
@abstractmethod
def execute(self):
raise NotImplementedError()

26
modelscope/cli/cli.py Normal file
View File

@@ -0,0 +1,26 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
from modelscope.cli.download import DownloadCMD
def run_cmd():
parser = argparse.ArgumentParser(
'ModelScope Command Line tool', usage='modelscope <command> [<args>]')
subparsers = parser.add_subparsers(help='modelscope commands helpers')
DownloadCMD.define_args(subparsers)
args = parser.parse_args()
if not hasattr(args, 'func'):
parser.print_help()
exit(1)
cmd = args.func(args)
cmd.execute()
if __name__ == '__main__':
run_cmd()

View File

@@ -0,0 +1,44 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from argparse import ArgumentParser
from modelscope.cli.base import CLICommand
from modelscope.hub.snapshot_download import snapshot_download
def subparser_func(args):
""" Fuction which will be called for a specific sub parser.
"""
return DownloadCMD(args)
class DownloadCMD(CLICommand):
name = 'download'
def __init__(self, args):
self.args = args
@staticmethod
def define_args(parsers: ArgumentParser):
""" define args for download command.
"""
parser = parsers.add_parser(DownloadCMD.name)
parser.add_argument(
'model', type=str, help='Name of the model to be downloaded.')
parser.add_argument(
'--revision',
type=str,
default=None,
help='Revision of the model.')
parser.add_argument(
'--cache_dir',
type=str,
default=None,
help='Cache directory to save model.')
parser.set_defaults(func=subparser_func)
def execute(self):
snapshot_download(
self.args.model,
cache_dir=self.args.cache_dir,
revision=self.args.revision)

View File

@@ -13,15 +13,29 @@ import tempfile
import unittest
from collections import OrderedDict
from collections.abc import Mapping
from os.path import expanduser
import numpy as np
import requests
from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH
from modelscope.utils.import_utils import is_tf_available, is_torch_available
TEST_LEVEL = 2
TEST_LEVEL_STR = 'TEST_LEVEL'
# for user citest and sdkdev
TEST_ACCESS_TOKEN1 = os.environ['TEST_ACCESS_TOKEN_CITEST']
TEST_ACCESS_TOKEN2 = os.environ['TEST_ACCESS_TOKEN_SDKDEV']
TEST_MODEL_CHINESE_NAME = '内部测试模型'
TEST_MODEL_ORG = 'citest'
def delete_credential():
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
shutil.rmtree(path_credential, ignore_errors=True)
def test_level():
global TEST_LEVEL

View File

@@ -223,5 +223,8 @@ if __name__ == '__main__':
tests_require=parse_requirements('requirements/tests.txt'),
install_requires=install_requires,
extras_require=extra_requires,
entry_points={
'console_scripts': ['modelscope=modelscope.cli.cli:run_cmd']
},
dependency_links=deps_link,
zip_safe=False)

0
tests/cli/__init__.py Normal file
View File

View File

@@ -0,0 +1,79 @@
import os
import os.path as osp
import shutil
import subprocess
import tempfile
import unittest
import uuid
from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.repository import Repository
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)
DEFAULT_GIT_PATH = 'git'
download_model_file_name = 'test.bin'
class DownloadCMDTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
self.api = HubApi()
self.api.login(TEST_ACCESS_TOKEN1)
self.model_name = 'op-%s' % (uuid.uuid4().hex)
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
self.revision = 'v0.1_test_revision'
self.api.create_model(
model_id=self.model_id,
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2,
chinese_name=TEST_MODEL_CHINESE_NAME,
)
self.prepare_case()
def prepare_case(self):
temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name)
repo = Repository(self.model_dir, clone_from=self.model_id)
os.system("echo 'testtest'>%s"
% os.path.join(self.model_dir, download_model_file_name))
repo.push('add model')
repo.tag_and_push(self.revision, 'Test revision')
def tearDown(self):
self.api.delete_model(model_id=self.model_id)
shutil.rmtree(self.tmp_dir)
super().tearDown()
def test_download(self):
cmd = f'python -m modelscope.cli.cli download {self.model_id}'
stat, output = subprocess.getstatusoutput(cmd)
self.assertEqual(stat, 0)
def test_download_with_cache(self):
cmd = f'python -m modelscope.cli.cli download {self.model_id} --cache_dir {self.tmp_dir}'
stat, output = subprocess.getstatusoutput(cmd)
if stat != 0:
print(output)
self.assertEqual(stat, 0)
self.assertTrue(
osp.exists(
f'{self.tmp_dir}/{self.model_id}/{download_model_file_name}'))
def test_download_with_revision(self):
cmd = f'python -m modelscope.cli.cli download {self.model_id} --revision {self.revision}'
stat, output = subprocess.getstatusoutput(cmd)
if stat != 0:
print(output)
self.assertEqual(stat, 0)
if __name__ == '__main__':
unittest.main()

View File

@@ -13,8 +13,9 @@ from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModelFile
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)
DEFAULT_GIT_PATH = 'git'

View File

@@ -13,9 +13,10 @@ from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModelFile
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2,
TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG,
delete_credential)
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_ACCESS_TOKEN2,
TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG, delete_credential)
download_model_file_name = 'test.bin'

View File

@@ -9,9 +9,10 @@ from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import GitError
from modelscope.hub.repository import Repository
from modelscope.utils.constant import ModelFile
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2,
TEST_MODEL_CHINESE_NAME, TEST_MODEL_ORG,
delete_credential)
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_ACCESS_TOKEN2,
TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG, delete_credential)
DEFAULT_GIT_PATH = 'git'

View File

@@ -16,8 +16,9 @@ from modelscope.hub.git import GitCommandWrapper
from modelscope.hub.repository import Repository
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG, delete_credential)
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG, delete_credential)
logger = get_logger()
logger.setLevel('DEBUG')

View File

@@ -13,8 +13,9 @@ from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)
logger = get_logger()
logger.setLevel('DEBUG')

View File

@@ -16,8 +16,9 @@ from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.logger import get_logger
from .test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)
logger = get_logger()
logger.setLevel('DEBUG')

View File

@@ -11,8 +11,8 @@ from modelscope.hub.errors import GitError, HTTPError, NotLoginException
from modelscope.hub.repository import Repository
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level
from .test_utils import TEST_ACCESS_TOKEN1, TEST_MODEL_ORG, delete_credential
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_ORG,
delete_credential, test_level)
logger = get_logger()

View File

@@ -1,20 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
from codecs import ignore_errors
from os.path import expanduser
from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH
# for user citest and sdkdev
TEST_ACCESS_TOKEN1 = os.environ['TEST_ACCESS_TOKEN_CITEST']
TEST_ACCESS_TOKEN2 = os.environ['TEST_ACCESS_TOKEN_SDKDEV']
TEST_MODEL_CHINESE_NAME = '内部测试模型'
TEST_MODEL_ORG = 'citest'
def delete_credential():
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
shutil.rmtree(path_credential, ignore_errors=True)