[to #47860410]plugin with cli tool

1. 支持 plugin方式接入外部 repo、github repo,本地repo,并进行外部插件管理
2. 支持allow_remote方式接入modelhub repo,该类型属于model 范畴不做额外插件管理
3. 支持cli 安装plugin相关

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11775456
This commit is contained in:
zhangzhicheng.zzc
2023-03-09 23:07:13 +08:00
committed by wenmeng.zwm
parent 2b1af959d5
commit 8a19e9645d
13 changed files with 873 additions and 37 deletions

View File

@@ -5,6 +5,7 @@ import argparse
from modelscope.cli.download import DownloadCMD
from modelscope.cli.modelcard import ModelCardCMD
from modelscope.cli.pipeline import PipelineCMD
from modelscope.cli.plugins import PluginsCMD
def run_cmd():
@@ -13,6 +14,7 @@ def run_cmd():
subparsers = parser.add_subparsers(help='modelscope commands helpers')
DownloadCMD.define_args(subparsers)
PluginsCMD.define_args(subparsers)
PipelineCMD.define_args(subparsers)
ModelCardCMD.define_args(subparsers)

118
modelscope/cli/plugins.py Normal file
View File

@@ -0,0 +1,118 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from argparse import ArgumentParser
from modelscope.cli.base import CLICommand
from modelscope.utils.plugins import PluginsManager
plugins_manager = PluginsManager()
def subparser_func(args):
""" Fuction which will be called for a specific sub parser.
"""
return PluginsCMD(args)
class PluginsCMD(CLICommand):
name = 'plugin'
def __init__(self, args):
self.args = args
@staticmethod
def define_args(parsers: ArgumentParser):
""" define args for install command.
"""
parser = parsers.add_parser(PluginsCMD.name)
subparsers = parser.add_subparsers(dest='command')
PluginsInstallCMD.define_args(subparsers)
PluginsUninstallCMD.define_args(subparsers)
PluginsListCMD.define_args(subparsers)
parser.set_defaults(func=subparser_func)
def execute(self):
print(self.args)
if self.args.command == PluginsInstallCMD.name:
PluginsInstallCMD.execute(self.args)
if self.args.command == PluginsUninstallCMD.name:
PluginsUninstallCMD.execute(self.args)
if self.args.command == PluginsListCMD.name:
PluginsListCMD.execute(self.args)
class PluginsInstallCMD(PluginsCMD):
name = 'install'
@staticmethod
def define_args(parsers: ArgumentParser):
install = parsers.add_parser(PluginsInstallCMD.name)
install.add_argument(
'package',
type=str,
nargs='+',
default=None,
help='Name of the package to be installed.')
install.add_argument(
'--index_url',
'-i',
type=str,
default=None,
help='Base URL of the Python Package Index.')
install.add_argument(
'--force_update',
'-f',
type=str,
default=False,
help='If force update the package')
@staticmethod
def execute(args):
plugins_manager.install_plugins(
list(args.package),
index_url=args.index_url,
force_update=args.force_update)
class PluginsUninstallCMD(PluginsCMD):
name = 'uninstall'
@staticmethod
def define_args(parsers: ArgumentParser):
install = parsers.add_parser(PluginsUninstallCMD.name)
install.add_argument(
'package',
type=str,
nargs='+',
default=None,
help='Name of the package to be installed.')
install.add_argument(
'--yes',
'-y',
type=str,
default=False,
help='Base URL of the Python Package Index.')
@staticmethod
def execute(args):
plugins_manager.uninstall_plugins(list(args.package), is_yes=args.yes)
class PluginsListCMD(PluginsCMD):
name = 'list'
@staticmethod
def define_args(parsers: ArgumentParser):
install = parsers.add_parser(PluginsListCMD.name)
install.add_argument(
'--all',
'-a',
type=str,
default=None,
help='Show all of the plugins including those not installed.')
@staticmethod
def execute(args):
plugins_manager.list_plugins(show_all=all)

View File

@@ -12,6 +12,8 @@ from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
from modelscope.utils.device import verify_device
from modelscope.utils.logger import get_logger
from modelscope.utils.plugins import (register_modelhub_repo,
register_plugins_repo)
logger = get_logger()
@@ -126,6 +128,11 @@ class Model(ABC):
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
model_cfg.type = model_cfg.model_type
model_cfg.model_dir = local_model_dir
# install and import remote repos before build
register_plugins_repo(cfg.safe_get('plugins'))
register_modelhub_repo(local_model_dir, cfg.get('allow_remote', False))
for k, v in kwargs.items():
model_cfg[k] = v
if device is not None:

View File

@@ -9,6 +9,8 @@ from modelscope.models.base import Model
from modelscope.utils.config import ConfigDict, check_config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
from modelscope.utils.hub import read_config
from modelscope.utils.plugins import (register_modelhub_repo,
register_plugins_repo)
from modelscope.utils.registry import Registry, build_from_cfg
from .base import Pipeline
from .util import is_official_hub_path
@@ -63,7 +65,6 @@ def pipeline(task: str = None,
framework: str = None,
device: str = 'gpu',
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
plugins: List[str] = None,
**kwargs) -> Pipeline:
""" Factory method to build an obj:`Pipeline`.
@@ -96,8 +97,6 @@ def pipeline(task: str = None,
if task is None and pipeline_name is None:
raise ValueError('task or pipeline_name is required')
try_import_plugins(plugins)
model = normalize_model_input(model, model_revision)
pipeline_props = {'type': pipeline_name}
if pipeline_name is None:
@@ -111,7 +110,8 @@ def pipeline(task: str = None,
model, str) else read_config(
model[0], revision=model_revision)
check_config(cfg)
try_import_plugins(cfg.safe_get('plugins'))
register_plugins_repo(cfg.safe_get('plugins'))
register_modelhub_repo(model, cfg.get('allow_remote', False))
pipeline_props = cfg.pipeline
elif model is not None:
# get pipeline info from Model object
@@ -120,7 +120,6 @@ def pipeline(task: str = None,
# model is instantiated by user, we should parse config again
cfg = read_config(first_model.model_dir)
check_config(cfg)
try_import_plugins(cfg.safe_get('plugins'))
first_model.pipeline = cfg.pipeline
pipeline_props = first_model.pipeline
else:
@@ -178,10 +177,3 @@ def get_default_pipeline_info(task):
else:
pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
return pipeline_name, default_model
def try_import_plugins(plugins: List[str]) -> None:
""" Try to import plugins """
if plugins is not None:
from modelscope.utils.plugins import import_plugins
import_plugins(plugins)

View File

@@ -48,7 +48,7 @@ class NamedEntityRecognitionPipeline(TokenClassificationPipeline):
>>> input = '这与温岭市新河镇的一个神秘的传说有关。'
>>> print(pipeline_ins(input))
To view other examples plese check the tests/pipelines/test_named_entity_recognition.py.
To view other examples plese check the tests/pipelines/test_plugin_model.py.
"""
super().__init__(
model=model,

View File

@@ -376,6 +376,7 @@ class FilesAstScanning(object):
def __init__(self) -> None:
self.astScaner = AstScanning()
self.file_dirs = []
self.requirement_dirs = []
def _parse_import_path(self,
import_package: str,
@@ -436,15 +437,15 @@ class FilesAstScanning(object):
ignored.add(item)
return list(set(output) - set(ignored))
def traversal_files(self, path, check_sub_dir):
def traversal_files(self, path, check_sub_dir=None):
self.file_dirs = []
if check_sub_dir is None or len(check_sub_dir) == 0:
self._traversal_files(path)
for item in check_sub_dir:
sub_dir = os.path.join(path, item)
if os.path.isdir(sub_dir):
self._traversal_files(sub_dir)
else:
for item in check_sub_dir:
sub_dir = os.path.join(path, item)
if os.path.isdir(sub_dir):
self._traversal_files(sub_dir)
def _traversal_files(self, path):
dir_list = os.scandir(path)
@@ -455,6 +456,8 @@ class FilesAstScanning(object):
self._traversal_files(item.path)
elif item.is_file() and item.name.endswith('.py'):
self.file_dirs.append(item.path)
elif item.is_file() and 'requirement' in item.name:
self.requirement_dirs.append(item.path)
def _get_single_file_scan_result(self, file):
try:

View File

@@ -1,17 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
import copy
import importlib
import os
import pkgutil
import sys
import venv
from contextlib import contextmanager
from fnmatch import fnmatch
from pathlib import Path
from typing import Iterable, List, Optional, Set
from typing import Any, Iterable, List, Optional, Set, Union
import json
import pkg_resources
from modelscope.fileio.file import LocalStorage
from modelscope.utils.ast_utils import FilesAstScanning
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_default_cache_dir
from modelscope.utils.hub import read_config, snapshot_download
from modelscope.utils.logger import get_logger
logger = get_logger()
storage = LocalStorage()
MODELSCOPE_FILE_DIR = get_default_cache_dir()
PLUGINS_FILENAME = '.modelscope_plugins'
OFFICIAL_PLUGINS = [
{
'name': 'adaseq',
'desc':
'Provide hundreds of additions NERs algorithms, check: https://github.com/modelscope/AdaSeq',
'version': '',
'url': ''
},
]
LOCAL_PLUGINS_FILENAME = '.modelscope_plugins'
GLOBAL_PLUGINS_FILENAME = os.path.join(Path.home(), '.modelscope', 'plugins')
@@ -65,24 +88,36 @@ def discover_file_plugins(
yield module_name
def discover_plugins() -> Iterable[str]:
def discover_plugins(requirement_path=None) -> Iterable[str]:
"""
Discover plugins
Args:
requirement_path: The file path of requirement
"""
plugins: Set[str] = set()
if os.path.isfile(LOCAL_PLUGINS_FILENAME):
with push_python_path('.'):
for plugin in discover_file_plugins(LOCAL_PLUGINS_FILENAME):
if requirement_path is None:
if os.path.isfile(LOCAL_PLUGINS_FILENAME):
with push_python_path('.'):
for plugin in discover_file_plugins(LOCAL_PLUGINS_FILENAME):
if plugin in plugins:
continue
yield plugin
plugins.add(plugin)
if os.path.isfile(GLOBAL_PLUGINS_FILENAME):
for plugin in discover_file_plugins(GLOBAL_PLUGINS_FILENAME):
if plugin in plugins:
continue
yield plugin
plugins.add(plugin)
else:
if os.path.isfile(requirement_path):
for plugin in discover_file_plugins(requirement_path):
if plugin in plugins:
continue
yield plugin
plugins.add(plugin)
if os.path.isfile(GLOBAL_PLUGINS_FILENAME):
for plugin in discover_file_plugins(GLOBAL_PLUGINS_FILENAME):
if plugin in plugins:
continue
yield plugin
plugins.add(plugin)
def import_all_plugins(plugins: List[str] = None) -> List[str]:
@@ -142,9 +177,13 @@ def import_plugins(plugins: List[str] = None) -> List[str]:
return imported_plugins
def import_file_plugins() -> List[str]:
def import_file_plugins(requirement_path=None) -> List[str]:
"""
Imports the plugins found with `discover_plugins()`.
Args:
requirement_path: The file path of requirement
"""
imported_plugins: List[str] = []
@@ -153,7 +192,7 @@ def import_file_plugins() -> List[str]:
if cwd not in sys.path:
sys.path.append(cwd)
for module_name in discover_plugins():
for module_name in discover_plugins(requirement_path):
try:
importlib.import_module(module_name)
logger.info('Plugin %s available', module_name)
@@ -174,7 +213,7 @@ def import_module_and_submodules(package_name: str,
include = include if include else set()
exclude = exclude if exclude else set()
def fn_in(packge_name: str, pattern_set: Set[str]) -> bool:
def fn_in(package_name: str, pattern_set: Set[str]) -> bool:
for pattern in pattern_set:
if fnmatch(package_name, pattern):
return True
@@ -213,3 +252,473 @@ def import_module_and_submodules(package_name: str,
logger.warning(f'{package_name} not imported: {str(e)}')
if len(package_name.split('.')) == 1:
raise ModuleNotFoundError('Package not installed')
def install_module_from_requirements(requirement_path, ):
"""
Args:
requirement_path: The path of requirement file
Returns:
"""
install_args = ['-r', requirement_path]
status_code, _, args = PluginsManager.pip_command(
'install',
install_args,
)
if status_code != 0:
raise ImportError(
f'Failed to install requirements from {requirement_path}')
def import_module_from_file(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def import_module_from_model_dir(model_dir):
from pathlib import Path
file_scanner = FilesAstScanning()
file_scanner.traversal_files(model_dir)
file_dirs = file_scanner.file_dirs
requirements = file_scanner.requirement_dirs
# install the requirements firstly
install_requirements_by_files(requirements)
# then import the modules
import sys
sys.path.insert(0, model_dir)
for file in file_dirs:
module_name = Path(file).stem
import_module_from_file(module_name, file)
def install_modelscope_if_need():
plugin_installed, version = PluginsManager.check_plugin_installed(
'modelscope')
if not plugin_installed:
status_code, _, args = PluginsManager.pip_command(
'install',
['modelscope'],
)
if status_code != 0:
raise ImportError('Failed to install package modelscope')
def install_requirements_by_names(plugins: List[str]):
plugins_manager = PluginsManager()
uninstalled_plugins = []
for plugin in plugins:
plugin_installed, version = plugins_manager.check_plugin_installed(
plugin)
if not plugin_installed:
uninstalled_plugins.append(plugin)
status, _ = plugins_manager.install_plugins(uninstalled_plugins)
if status != 0:
raise EnvironmentError(
f'The required packages {",".join(uninstalled_plugins)} are not installed.',
f'Please run the command `modelscope plugin install {" ".join(uninstalled_plugins)}` to install them.'
)
install_modelscope_if_need()
def install_requirements_by_files(requirements: List[str]):
for requirement in requirements:
install_module_from_requirements(requirement)
install_modelscope_if_need()
def register_plugins_repo(plugins: List[str]) -> None:
""" Try to install and import plugins from repo"""
if plugins is not None:
install_requirements_by_names(plugins)
import_plugins(plugins)
def register_modelhub_repo(model_dir, allow_remote=False) -> None:
""" Try to install and import remote model from modelhub"""
if allow_remote:
try:
import_module_from_model_dir(model_dir)
except KeyError:
logger.warning(
'Multi component keys in the hub are registered in same file')
pass
class PluginsManager(object):
def __init__(self,
cache_dir=MODELSCOPE_FILE_DIR,
plugins_file=PLUGINS_FILENAME):
cache_dir = os.getenv('MODELSCOPE_CACHE', cache_dir)
plugins_file = os.getenv('MODELSCOPE_PLUGINS_FILE', plugins_file)
self._file_path = os.path.join(cache_dir, plugins_file)
@property
def file_path(self):
return self._file_path
@file_path.setter
def file_path(self, value):
self._file_path = value
@staticmethod
def check_plugin_installed(package):
try:
importlib.reload(pkg_resources)
package_meta_info = pkg_resources.working_set.by_key[package]
version = package_meta_info.version
installed = True
except KeyError:
version = ''
installed = False
return installed, version
@staticmethod
def pip_command(
command,
command_args: List[str],
):
"""
Args:
command: install, uninstall command
command_args: the args to be used with command, should be in list
such as ['-r', 'requirements']
Returns:
"""
from pip._internal.commands import create_command
importlib.reload(pkg_resources)
command = create_command(command)
options, args = command.parse_args(command_args)
status_code = command.main(command_args)
return status_code, options, args
def install_plugins(self,
install_args: List[str],
index_url: Optional[str] = None,
force_update=False) -> Any:
"""Install packages via pip
Args:
install_args (list): List of arguments passed to `pip install`.
index_url (str, optional): The pypi index url.
"""
if len(install_args) == 0:
return 0, []
if index_url is not None:
install_args += ['-i', index_url]
if force_update is not False:
install_args += ['-f']
status_code, options, args = PluginsManager.pip_command(
'install',
install_args,
)
if status_code == 0:
logger.info(f'The plugins {",".join(args)} is installed')
# TODO Add Ast index for ast update record
# Add the plugins info to the local record
installed_package = self.parse_args_info(args, options)
self.update_plugins_file(installed_package)
return status_code, install_args
def parse_args_info(self, args: List[str], options):
installed_package = []
# the case of install with requirements
if len(args) == 0:
src_dir = options.src_dir
requirements = options.requirments
for requirement in requirements:
package_info = {
'name': requirement,
'url': os.path.join(src_dir, requirement),
'desc': '',
'version': ''
}
installed_package.append(package_info)
def get_package_info(package_name):
from pathlib import Path
package_info = {
'name': package_name,
'url': options.index_url,
'desc': ''
}
# the case with git + http
if package_name.split('.')[-1] == 'git':
package_name = Path(package_name).stem
plugin_installed, version = self.check_plugin_installed(
package_name)
if plugin_installed:
package_info['version'] = version
package_info['name'] = package_name
else:
logger.warning(
f'The package {package_name} is not in the lib, this might be happened'
f' when installing the package with git+https method, should be ignored'
)
package_info['version'] = ''
return package_info
for package in args:
package_info = get_package_info(package)
installed_package.append(package_info)
return installed_package
def uninstall_plugins(self,
uninstall_args: Union[str, List],
is_yes=False):
if is_yes is not None:
uninstall_args += ['-y']
status_code, options, args = PluginsManager.pip_command(
'uninstall',
uninstall_args,
)
if status_code == 0:
logger.info(f'The plugins {",".join(args)} is uninstalled')
# TODO Add Ast index for ast update record
# Add to the local record
self.remove_plugins_from_file(args)
return status_code, uninstall_args
def _get_plugins_from_file(self):
""" get plugins from file
"""
logger.info(f'Loading plugins information from {self.file_path}')
if os.path.exists(self.file_path):
local_plugins_info_bytes = storage.read(self.file_path)
local_plugins_info = json.loads(local_plugins_info_bytes)
else:
local_plugins_info = {}
return local_plugins_info
def _update_plugins(
self,
new_plugins_list,
local_plugins_info,
override=False,
):
for item in new_plugins_list:
package_name = item.pop('name')
# update package information if existed
if package_name in local_plugins_info and not override:
original_item = local_plugins_info[package_name]
from pkg_resources import parse_version
item_version = parse_version(
item['version'] if item['version'] != '' else '0.0.0')
origin_version = parse_version(
original_item['version']
if original_item['version'] != '' else '0.0.0')
desc = item['desc']
if original_item['desc'] != '' and desc == '':
desc = original_item['desc']
item = item if item_version > origin_version else original_item
item['desc'] = desc
# Double-check if the item is installed with the version number
if item['version'] == '':
plugin_installed, version = self.check_plugin_installed(
package_name)
item['version'] = version
local_plugins_info[package_name] = item
return local_plugins_info
def _print_plugins_info(self, local_plugins_info):
print('{:<15} |{:<10} |{:<100}'.format('NAME', 'VERSION',
'DESCRIPTION'))
print('')
for k, v in local_plugins_info.items():
print('{:<15} |{:<10} |{:<100}'.format(k, v['version'], v['desc']))
def list_plugins(
self,
show_all=False,
):
"""
Args:
show_all: show installed and official supported if True, else only those installed
Returns:
"""
local_plugins_info = self._get_plugins_from_file()
# update plugins with default
local_official_plugins = copy.deepcopy(OFFICIAL_PLUGINS)
local_plugins_info = self._update_plugins(local_official_plugins,
local_plugins_info)
if show_all is True:
self._print_plugins_info(local_plugins_info)
return local_plugins_info
# Consider those package with version is installed
not_installed_list = []
for item in local_plugins_info:
if local_plugins_info[item]['version'] == '':
not_installed_list.append(item)
for item in not_installed_list:
local_plugins_info.pop(item)
self._print_plugins_info(local_plugins_info)
return local_plugins_info
def update_plugins_file(
self,
plugins_list,
override=False,
):
"""update the plugins file in order to maintain the latest plugins information
Args:
plugins_list: The plugins list contain the information of plugins
name, version, introduction, install url and the status of delete or update
override: Override the file by the list if True, else only update.
Returns:
"""
local_plugins_info = self._get_plugins_from_file()
# local_plugins_info is empty if first time loading, should add OFFICIAL_PLUGINS information
if local_plugins_info == {}:
plugins_list.extend(copy.deepcopy(OFFICIAL_PLUGINS))
local_plugins_info = self._update_plugins(plugins_list,
local_plugins_info, override)
local_plugins_info_json = json.dumps(local_plugins_info)
storage.write(local_plugins_info_json.encode(), self.file_path)
return local_plugins_info_json
def remove_plugins_from_file(
self,
package_names: Union[str, list],
):
"""
Args:
package_names: package name
Returns:
"""
local_plugins_info = self._get_plugins_from_file()
if type(package_names) is str:
package_names = list(package_names)
for item in package_names:
if item in local_plugins_info:
local_plugins_info.pop(item)
local_plugins_info_json = json.dumps(local_plugins_info)
storage.write(local_plugins_info_json.encode(), self.file_path)
return local_plugins_info_json
class EnvsManager(object):
name = 'envs'
def __init__(self,
model_id,
model_revision=DEFAULT_MODEL_REVISION,
cache_dir=MODELSCOPE_FILE_DIR):
"""
Args:
model_id: id of the model, not dir
model_revision: revision of the model, default as master
cache_dir: the system modelscope cache dir
"""
cache_dir = os.getenv('MODELSCOPE_CACHE', cache_dir)
self.env_dir = os.path.join(cache_dir, EnvsManager.name, model_id)
model_dir = snapshot_download(model_id, revision=model_revision)
cfg = read_config(model_dir)
self.plugins = cfg.get('plugins', [])
self.allow_remote = cfg.get('allow_remote', False)
self.env_builder = venv.EnvBuilder(
system_site_packages=True,
clear=False,
symlinks=True,
with_pip=False)
def get_env_dir(self):
return self.env_dir
def get_activate_dir(self):
return os.path.join(self.env_dir, 'bin', 'activate')
def check_if_need_env(self):
if len(self.plugins) or self.allow_remote:
return True
else:
return False
def create_env(self):
if not os.path.exists(self.env_dir):
os.makedirs(self.env_dir)
try:
self.env_builder.create(self.env_dir)
except Exception as e:
self.clean_env()
raise EnvironmentError(
f'Failed to create virtual env at {self.env_dir} with error: {e}'
)
def clean_env(self):
if os.path.exists(self.env_dir):
self.env_builder.clear_directory(self.env_dir)
@staticmethod
def run_process(cmd):
import subprocess
status, result = subprocess.getstatusoutput(cmd)
logger.debug('The status and the results are: {}, {}'.format(
status, result))
if status != 0:
raise Exception(
'running the cmd: {} failed, with message: {}'.format(
cmd, result))
return result
if __name__ == '__main__':
install_requirements_by_files(['adaseq'])

View File

@@ -0,0 +1,50 @@
import subprocess
import unittest
from modelscope.utils.plugins import PluginsManager
class PluginsCMDTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.package = 'adaseq'
self.plugins_manager = PluginsManager()
def tearDown(self):
super().tearDown()
def test_plugins_install(self):
cmd = f'python -m modelscope.cli.cli plugin install {self.package}'
stat, output = subprocess.getstatusoutput(cmd)
self.assertEqual(stat, 0)
# move this from tear down to avoid unexpected uninstall
uninstall_args = [self.package, '-y']
self.plugins_manager.uninstall_plugins(uninstall_args)
def test_plugins_uninstall(self):
# move this from tear down to avoid unexpected uninstall
uninstall_args = [self.package, '-y']
self.plugins_manager.uninstall_plugins(uninstall_args)
cmd = f'python -m modelscope.cli.cli plugin install {self.package}'
stat, output = subprocess.getstatusoutput(cmd)
self.assertEqual(stat, 0)
cmd = f'python -m modelscope.cli.cli plugin uninstall {self.package}'
stat, output = subprocess.getstatusoutput(cmd)
self.assertEqual(stat, 0)
# move this from tear down to avoid unexpected uninstall
uninstall_args = [self.package, '-y']
self.plugins_manager.uninstall_plugins(uninstall_args)
def test_plugins_list(self):
cmd = 'python -m modelscope.cli.cli plugin list'
stat, output = subprocess.getstatusoutput(cmd)
self.assertEqual(stat, 0)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.pipelines import pipeline
from modelscope.utils.plugins import PluginsManager
from modelscope.utils.test_utils import test_level
class AllowRemoteModelTest(unittest.TestCase):
def setUp(self):
self.package = 'moviepy'
def tearDown(self):
# make sure uninstalled after installing
uninstall_args = [self.package, '-y']
PluginsManager.pip_command('uninstall', uninstall_args)
super().tearDown()
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_bilibili_image(self):
model_path = snapshot_download(
'bilibili/cv_bilibili_image-super-resolution', revision='v1.0.5')
file_path = f'{model_path}/demos/title-compare1.png'
weight_path = f'{model_path}/weights_v3/up2x-latest-denoise3x.pth'
inference = pipeline(
'image-super-resolution',
model='bilibili/cv_bilibili_image-super-resolution',
weight_path=weight_path,
device='cpu',
half=False) # GPU环境可以设置为True
output = inference(file_path, tile_mode=0, cache_mode=1, alpha=1)
print(output)

View File

@@ -4,12 +4,20 @@ import unittest
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.plugins import PluginsManager
from modelscope.utils.test_utils import test_level
class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
class PluginModelTest(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self):
self.package = 'adaseq'
def tearDown(self):
# make sure uninstalled after installing
uninstall_args = [self.package, '-y']
PluginsManager.pip_command('uninstall', uninstall_args)
super().tearDown()
import subprocess
result = subprocess.run(
['pip', 'install', 'adaseq>=0.6.2', '--no-deps'],

25
tests/utils/test_envs.py Normal file
View File

@@ -0,0 +1,25 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from modelscope.utils.plugins import EnvsManager
class PluginTest(unittest.TestCase):
def setUp(self):
self.model_id = 'damo/nlp_nested-ner_named-entity-recognition_chinese-base-med'
self.env_manager = EnvsManager(self.model_id)
def tearDown(self):
self.env_manager.clean_env()
super().tearDown()
def test_create_env(self):
need_env = self.env_manager.check_if_need_env()
self.assertEqual(need_env, True)
activate_dir = self.env_manager.create_env()
remote = 'source {}'.format(activate_dir)
cmd = f'{remote};'
print(cmd)
# EnvsManager.run_process(cmd) no sh in ci env, so skip

View File

@@ -1,16 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
from modelscope.models.builder import MODELS
from modelscope.utils.plugins import (discover_plugins, import_all_plugins,
import_file_plugins, import_plugins,
pushd)
from modelscope.utils.plugins import (PluginsManager, discover_plugins,
import_all_plugins, import_file_plugins,
import_plugins, pushd)
from modelscope.utils.test_utils import test_level
class PluginTest(unittest.TestCase):
def setUp(self):
self.plugins_root = 'tests/utils/plugins/'
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
self.package = 'adaseq'
self.plugins_manager = PluginsManager()
def tearDown(self):
shutil.rmtree(self.tmp_dir)
super().tearDown()
def test_no_plugins(self):
available_plugins = set(discover_plugins())
@@ -39,3 +52,75 @@ class PluginTest(unittest.TestCase):
import_all_plugins()
assert MODELS.get('dummy-model', 'dummy-group') is not None
def test_install_plugins(self):
"""
examples for the modelscope install method
> modelscope install adaseq ofasys
> modelscope install git+https://github.com/modelscope/AdaSeq.git
> modelscope install adaseq -i <url> -f <url>
> modelscope install adaseq --extra-index-url <url> --trusted-host <hostname>
"""
install_args = [self.package]
status_code, install_args = self.plugins_manager.install_plugins(
install_args)
self.assertEqual(status_code, 0)
install_args = ['random_blabla']
status_code, install_args = self.plugins_manager.install_plugins(
install_args)
self.assertEqual(status_code, 1)
install_args = [self.package, 'random_blabla']
status_code, install_args = self.plugins_manager.install_plugins(
install_args)
self.assertEqual(status_code, 1)
# move this from tear down to avoid unexpected uninstall
uninstall_args = [self.package, '-y']
self.plugins_manager.uninstall_plugins(uninstall_args)
@unittest.skip
def test_install_plugins_with_git(self):
install_args = ['git+https://github.com/modelscope/AdaSeq.git']
status_code, install_args = self.plugins_manager.install_plugins(
install_args)
self.assertEqual(status_code, 0)
# move this from tear down to avoid unexpected uninstall
uninstall_args = ['git+https://github.com/modelscope/AdaSeq.git', '-y']
self.plugins_manager.uninstall_plugins(uninstall_args)
def test_uninstall_plugins(self):
"""
examples for the modelscope uninstall method
> modelscope uninstall adaseq
> modelscope uninstall -y adaseq
"""
install_args = [self.package]
status_code, install_args = self.plugins_manager.install_plugins(
install_args)
self.assertEqual(status_code, 0)
uninstall_args = [self.package, '-y']
status_code, uninstall_args = self.plugins_manager.uninstall_plugins(
uninstall_args)
self.assertEqual(status_code, 0)
def test_list_plugins(self):
"""
examples for the modelscope list method
> modelscope list
> modelscope list --all
> modelscope list -a
# """
modelscope_plugin = os.path.join(self.tmp_dir, 'modelscope_plugin')
self.plugins_manager.file_path = modelscope_plugin
result = self.plugins_manager.list_plugins()
self.assertEqual(len(result.items()), 0)
from modelscope.utils.plugins import OFFICIAL_PLUGINS
result = self.plugins_manager.list_plugins(show_all=True)
self.assertEqual(len(result.items()), len(OFFICIAL_PLUGINS))