diff --git a/modelscope/cli/cli.py b/modelscope/cli/cli.py index 47c39b99..a25502fd 100644 --- a/modelscope/cli/cli.py +++ b/modelscope/cli/cli.py @@ -5,6 +5,7 @@ import argparse from modelscope.cli.download import DownloadCMD from modelscope.cli.modelcard import ModelCardCMD from modelscope.cli.pipeline import PipelineCMD +from modelscope.cli.plugins import PluginsCMD def run_cmd(): @@ -13,6 +14,7 @@ def run_cmd(): subparsers = parser.add_subparsers(help='modelscope commands helpers') DownloadCMD.define_args(subparsers) + PluginsCMD.define_args(subparsers) PipelineCMD.define_args(subparsers) ModelCardCMD.define_args(subparsers) diff --git a/modelscope/cli/plugins.py b/modelscope/cli/plugins.py new file mode 100644 index 00000000..e40457df --- /dev/null +++ b/modelscope/cli/plugins.py @@ -0,0 +1,118 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from argparse import ArgumentParser + +from modelscope.cli.base import CLICommand +from modelscope.utils.plugins import PluginsManager + +plugins_manager = PluginsManager() + + +def subparser_func(args): + """ Fuction which will be called for a specific sub parser. + """ + return PluginsCMD(args) + + +class PluginsCMD(CLICommand): + name = 'plugin' + + def __init__(self, args): + self.args = args + + @staticmethod + def define_args(parsers: ArgumentParser): + """ define args for install command. + """ + parser = parsers.add_parser(PluginsCMD.name) + subparsers = parser.add_subparsers(dest='command') + + PluginsInstallCMD.define_args(subparsers) + PluginsUninstallCMD.define_args(subparsers) + PluginsListCMD.define_args(subparsers) + + parser.set_defaults(func=subparser_func) + + def execute(self): + print(self.args) + if self.args.command == PluginsInstallCMD.name: + PluginsInstallCMD.execute(self.args) + if self.args.command == PluginsUninstallCMD.name: + PluginsUninstallCMD.execute(self.args) + if self.args.command == PluginsListCMD.name: + PluginsListCMD.execute(self.args) + + +class PluginsInstallCMD(PluginsCMD): + name = 'install' + + @staticmethod + def define_args(parsers: ArgumentParser): + install = parsers.add_parser(PluginsInstallCMD.name) + install.add_argument( + 'package', + type=str, + nargs='+', + default=None, + help='Name of the package to be installed.') + install.add_argument( + '--index_url', + '-i', + type=str, + default=None, + help='Base URL of the Python Package Index.') + install.add_argument( + '--force_update', + '-f', + type=str, + default=False, + help='If force update the package') + + @staticmethod + def execute(args): + plugins_manager.install_plugins( + list(args.package), + index_url=args.index_url, + force_update=args.force_update) + + +class PluginsUninstallCMD(PluginsCMD): + name = 'uninstall' + + @staticmethod + def define_args(parsers: ArgumentParser): + install = parsers.add_parser(PluginsUninstallCMD.name) + install.add_argument( + 'package', + type=str, + nargs='+', + default=None, + help='Name of the package to be installed.') + install.add_argument( + '--yes', + '-y', + type=str, + default=False, + help='Base URL of the Python Package Index.') + + @staticmethod + def execute(args): + plugins_manager.uninstall_plugins(list(args.package), is_yes=args.yes) + + +class PluginsListCMD(PluginsCMD): + name = 'list' + + @staticmethod + def define_args(parsers: ArgumentParser): + install = parsers.add_parser(PluginsListCMD.name) + install.add_argument( + '--all', + '-a', + type=str, + default=None, + help='Show all of the plugins including those not installed.') + + @staticmethod + def execute(args): + plugins_manager.list_plugins(show_all=all) diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index 18855829..0edb740e 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -12,6 +12,8 @@ from modelscope.utils.config import Config from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile from modelscope.utils.device import verify_device from modelscope.utils.logger import get_logger +from modelscope.utils.plugins import (register_modelhub_repo, + register_plugins_repo) logger = get_logger() @@ -126,6 +128,11 @@ class Model(ABC): if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): model_cfg.type = model_cfg.model_type model_cfg.model_dir = local_model_dir + + # install and import remote repos before build + register_plugins_repo(cfg.safe_get('plugins')) + register_modelhub_repo(local_model_dir, cfg.get('allow_remote', False)) + for k, v in kwargs.items(): model_cfg[k] = v if device is not None: diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 4987a3e0..dd39453c 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -9,6 +9,8 @@ from modelscope.models.base import Model from modelscope.utils.config import ConfigDict, check_config from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke from modelscope.utils.hub import read_config +from modelscope.utils.plugins import (register_modelhub_repo, + register_plugins_repo) from modelscope.utils.registry import Registry, build_from_cfg from .base import Pipeline from .util import is_official_hub_path @@ -63,7 +65,6 @@ def pipeline(task: str = None, framework: str = None, device: str = 'gpu', model_revision: Optional[str] = DEFAULT_MODEL_REVISION, - plugins: List[str] = None, **kwargs) -> Pipeline: """ Factory method to build an obj:`Pipeline`. @@ -96,8 +97,6 @@ def pipeline(task: str = None, if task is None and pipeline_name is None: raise ValueError('task or pipeline_name is required') - try_import_plugins(plugins) - model = normalize_model_input(model, model_revision) pipeline_props = {'type': pipeline_name} if pipeline_name is None: @@ -111,7 +110,8 @@ def pipeline(task: str = None, model, str) else read_config( model[0], revision=model_revision) check_config(cfg) - try_import_plugins(cfg.safe_get('plugins')) + register_plugins_repo(cfg.safe_get('plugins')) + register_modelhub_repo(model, cfg.get('allow_remote', False)) pipeline_props = cfg.pipeline elif model is not None: # get pipeline info from Model object @@ -120,7 +120,6 @@ def pipeline(task: str = None, # model is instantiated by user, we should parse config again cfg = read_config(first_model.model_dir) check_config(cfg) - try_import_plugins(cfg.safe_get('plugins')) first_model.pipeline = cfg.pipeline pipeline_props = first_model.pipeline else: @@ -178,10 +177,3 @@ def get_default_pipeline_info(task): else: pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task] return pipeline_name, default_model - - -def try_import_plugins(plugins: List[str]) -> None: - """ Try to import plugins """ - if plugins is not None: - from modelscope.utils.plugins import import_plugins - import_plugins(plugins) diff --git a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py index 8a25c415..ba174bae 100644 --- a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py +++ b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py @@ -48,7 +48,7 @@ class NamedEntityRecognitionPipeline(TokenClassificationPipeline): >>> input = '这与温岭市新河镇的一个神秘的传说有关。' >>> print(pipeline_ins(input)) - To view other examples plese check the tests/pipelines/test_named_entity_recognition.py. + To view other examples plese check the tests/pipelines/test_plugin_model.py. """ super().__init__( model=model, diff --git a/modelscope/utils/ast_utils.py b/modelscope/utils/ast_utils.py index 50b4277e..7f078467 100644 --- a/modelscope/utils/ast_utils.py +++ b/modelscope/utils/ast_utils.py @@ -376,6 +376,7 @@ class FilesAstScanning(object): def __init__(self) -> None: self.astScaner = AstScanning() self.file_dirs = [] + self.requirement_dirs = [] def _parse_import_path(self, import_package: str, @@ -436,15 +437,15 @@ class FilesAstScanning(object): ignored.add(item) return list(set(output) - set(ignored)) - def traversal_files(self, path, check_sub_dir): + def traversal_files(self, path, check_sub_dir=None): self.file_dirs = [] if check_sub_dir is None or len(check_sub_dir) == 0: self._traversal_files(path) - - for item in check_sub_dir: - sub_dir = os.path.join(path, item) - if os.path.isdir(sub_dir): - self._traversal_files(sub_dir) + else: + for item in check_sub_dir: + sub_dir = os.path.join(path, item) + if os.path.isdir(sub_dir): + self._traversal_files(sub_dir) def _traversal_files(self, path): dir_list = os.scandir(path) @@ -455,6 +456,8 @@ class FilesAstScanning(object): self._traversal_files(item.path) elif item.is_file() and item.name.endswith('.py'): self.file_dirs.append(item.path) + elif item.is_file() and 'requirement' in item.name: + self.requirement_dirs.append(item.path) def _get_single_file_scan_result(self, file): try: diff --git a/modelscope/utils/plugins.py b/modelscope/utils/plugins.py index 6c2f2975..e62f775d 100644 --- a/modelscope/utils/plugins.py +++ b/modelscope/utils/plugins.py @@ -1,17 +1,40 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +import copy import importlib import os import pkgutil import sys +import venv from contextlib import contextmanager from fnmatch import fnmatch from pathlib import Path -from typing import Iterable, List, Optional, Set +from typing import Any, Iterable, List, Optional, Set, Union +import json +import pkg_resources + +from modelscope.fileio.file import LocalStorage +from modelscope.utils.ast_utils import FilesAstScanning +from modelscope.utils.constant import DEFAULT_MODEL_REVISION +from modelscope.utils.file_utils import get_default_cache_dir +from modelscope.utils.hub import read_config, snapshot_download from modelscope.utils.logger import get_logger logger = get_logger() +storage = LocalStorage() + +MODELSCOPE_FILE_DIR = get_default_cache_dir() +PLUGINS_FILENAME = '.modelscope_plugins' +OFFICIAL_PLUGINS = [ + { + 'name': 'adaseq', + 'desc': + 'Provide hundreds of additions NERs algorithms, check: https://github.com/modelscope/AdaSeq', + 'version': '', + 'url': '' + }, +] LOCAL_PLUGINS_FILENAME = '.modelscope_plugins' GLOBAL_PLUGINS_FILENAME = os.path.join(Path.home(), '.modelscope', 'plugins') @@ -65,24 +88,36 @@ def discover_file_plugins( yield module_name -def discover_plugins() -> Iterable[str]: +def discover_plugins(requirement_path=None) -> Iterable[str]: """ Discover plugins + + Args: + requirement_path: The file path of requirement + """ plugins: Set[str] = set() - if os.path.isfile(LOCAL_PLUGINS_FILENAME): - with push_python_path('.'): - for plugin in discover_file_plugins(LOCAL_PLUGINS_FILENAME): + if requirement_path is None: + if os.path.isfile(LOCAL_PLUGINS_FILENAME): + with push_python_path('.'): + for plugin in discover_file_plugins(LOCAL_PLUGINS_FILENAME): + if plugin in plugins: + continue + yield plugin + plugins.add(plugin) + if os.path.isfile(GLOBAL_PLUGINS_FILENAME): + for plugin in discover_file_plugins(GLOBAL_PLUGINS_FILENAME): + if plugin in plugins: + continue + yield plugin + plugins.add(plugin) + else: + if os.path.isfile(requirement_path): + for plugin in discover_file_plugins(requirement_path): if plugin in plugins: continue yield plugin plugins.add(plugin) - if os.path.isfile(GLOBAL_PLUGINS_FILENAME): - for plugin in discover_file_plugins(GLOBAL_PLUGINS_FILENAME): - if plugin in plugins: - continue - yield plugin - plugins.add(plugin) def import_all_plugins(plugins: List[str] = None) -> List[str]: @@ -142,9 +177,13 @@ def import_plugins(plugins: List[str] = None) -> List[str]: return imported_plugins -def import_file_plugins() -> List[str]: +def import_file_plugins(requirement_path=None) -> List[str]: """ Imports the plugins found with `discover_plugins()`. + + Args: + requirement_path: The file path of requirement + """ imported_plugins: List[str] = [] @@ -153,7 +192,7 @@ def import_file_plugins() -> List[str]: if cwd not in sys.path: sys.path.append(cwd) - for module_name in discover_plugins(): + for module_name in discover_plugins(requirement_path): try: importlib.import_module(module_name) logger.info('Plugin %s available', module_name) @@ -174,7 +213,7 @@ def import_module_and_submodules(package_name: str, include = include if include else set() exclude = exclude if exclude else set() - def fn_in(packge_name: str, pattern_set: Set[str]) -> bool: + def fn_in(package_name: str, pattern_set: Set[str]) -> bool: for pattern in pattern_set: if fnmatch(package_name, pattern): return True @@ -213,3 +252,473 @@ def import_module_and_submodules(package_name: str, logger.warning(f'{package_name} not imported: {str(e)}') if len(package_name.split('.')) == 1: raise ModuleNotFoundError('Package not installed') + + +def install_module_from_requirements(requirement_path, ): + """ + Args: + requirement_path: The path of requirement file + + Returns: + + """ + + install_args = ['-r', requirement_path] + status_code, _, args = PluginsManager.pip_command( + 'install', + install_args, + ) + if status_code != 0: + raise ImportError( + f'Failed to install requirements from {requirement_path}') + + +def import_module_from_file(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def import_module_from_model_dir(model_dir): + from pathlib import Path + file_scanner = FilesAstScanning() + file_scanner.traversal_files(model_dir) + file_dirs = file_scanner.file_dirs + requirements = file_scanner.requirement_dirs + + # install the requirements firstly + install_requirements_by_files(requirements) + + # then import the modules + import sys + sys.path.insert(0, model_dir) + for file in file_dirs: + module_name = Path(file).stem + import_module_from_file(module_name, file) + + +def install_modelscope_if_need(): + plugin_installed, version = PluginsManager.check_plugin_installed( + 'modelscope') + if not plugin_installed: + status_code, _, args = PluginsManager.pip_command( + 'install', + ['modelscope'], + ) + if status_code != 0: + raise ImportError('Failed to install package modelscope') + + +def install_requirements_by_names(plugins: List[str]): + plugins_manager = PluginsManager() + uninstalled_plugins = [] + for plugin in plugins: + plugin_installed, version = plugins_manager.check_plugin_installed( + plugin) + if not plugin_installed: + uninstalled_plugins.append(plugin) + status, _ = plugins_manager.install_plugins(uninstalled_plugins) + if status != 0: + raise EnvironmentError( + f'The required packages {",".join(uninstalled_plugins)} are not installed.', + f'Please run the command `modelscope plugin install {" ".join(uninstalled_plugins)}` to install them.' + ) + install_modelscope_if_need() + + +def install_requirements_by_files(requirements: List[str]): + for requirement in requirements: + install_module_from_requirements(requirement) + install_modelscope_if_need() + + +def register_plugins_repo(plugins: List[str]) -> None: + """ Try to install and import plugins from repo""" + if plugins is not None: + install_requirements_by_names(plugins) + import_plugins(plugins) + + +def register_modelhub_repo(model_dir, allow_remote=False) -> None: + """ Try to install and import remote model from modelhub""" + if allow_remote: + try: + import_module_from_model_dir(model_dir) + except KeyError: + logger.warning( + 'Multi component keys in the hub are registered in same file') + pass + + +class PluginsManager(object): + + def __init__(self, + cache_dir=MODELSCOPE_FILE_DIR, + plugins_file=PLUGINS_FILENAME): + cache_dir = os.getenv('MODELSCOPE_CACHE', cache_dir) + plugins_file = os.getenv('MODELSCOPE_PLUGINS_FILE', plugins_file) + self._file_path = os.path.join(cache_dir, plugins_file) + + @property + def file_path(self): + return self._file_path + + @file_path.setter + def file_path(self, value): + self._file_path = value + + @staticmethod + def check_plugin_installed(package): + try: + importlib.reload(pkg_resources) + package_meta_info = pkg_resources.working_set.by_key[package] + version = package_meta_info.version + installed = True + except KeyError: + version = '' + installed = False + + return installed, version + + @staticmethod + def pip_command( + command, + command_args: List[str], + ): + """ + + Args: + command: install, uninstall command + command_args: the args to be used with command, should be in list + such as ['-r', 'requirements'] + + Returns: + + """ + from pip._internal.commands import create_command + importlib.reload(pkg_resources) + command = create_command(command) + options, args = command.parse_args(command_args) + + status_code = command.main(command_args) + return status_code, options, args + + def install_plugins(self, + install_args: List[str], + index_url: Optional[str] = None, + force_update=False) -> Any: + """Install packages via pip + Args: + install_args (list): List of arguments passed to `pip install`. + index_url (str, optional): The pypi index url. + """ + + if len(install_args) == 0: + return 0, [] + + if index_url is not None: + install_args += ['-i', index_url] + + if force_update is not False: + install_args += ['-f'] + + status_code, options, args = PluginsManager.pip_command( + 'install', + install_args, + ) + + if status_code == 0: + logger.info(f'The plugins {",".join(args)} is installed') + + # TODO Add Ast index for ast update record + + # Add the plugins info to the local record + installed_package = self.parse_args_info(args, options) + self.update_plugins_file(installed_package) + + return status_code, install_args + + def parse_args_info(self, args: List[str], options): + installed_package = [] + + # the case of install with requirements + if len(args) == 0: + src_dir = options.src_dir + requirements = options.requirments + for requirement in requirements: + package_info = { + 'name': requirement, + 'url': os.path.join(src_dir, requirement), + 'desc': '', + 'version': '' + } + + installed_package.append(package_info) + + def get_package_info(package_name): + from pathlib import Path + package_info = { + 'name': package_name, + 'url': options.index_url, + 'desc': '' + } + + # the case with git + http + if package_name.split('.')[-1] == 'git': + package_name = Path(package_name).stem + + plugin_installed, version = self.check_plugin_installed( + package_name) + if plugin_installed: + package_info['version'] = version + package_info['name'] = package_name + else: + logger.warning( + f'The package {package_name} is not in the lib, this might be happened' + f' when installing the package with git+https method, should be ignored' + ) + package_info['version'] = '' + + return package_info + + for package in args: + package_info = get_package_info(package) + installed_package.append(package_info) + + return installed_package + + def uninstall_plugins(self, + uninstall_args: Union[str, List], + is_yes=False): + if is_yes is not None: + uninstall_args += ['-y'] + + status_code, options, args = PluginsManager.pip_command( + 'uninstall', + uninstall_args, + ) + + if status_code == 0: + logger.info(f'The plugins {",".join(args)} is uninstalled') + + # TODO Add Ast index for ast update record + + # Add to the local record + self.remove_plugins_from_file(args) + + return status_code, uninstall_args + + def _get_plugins_from_file(self): + """ get plugins from file + + """ + logger.info(f'Loading plugins information from {self.file_path}') + if os.path.exists(self.file_path): + local_plugins_info_bytes = storage.read(self.file_path) + local_plugins_info = json.loads(local_plugins_info_bytes) + else: + local_plugins_info = {} + return local_plugins_info + + def _update_plugins( + self, + new_plugins_list, + local_plugins_info, + override=False, + ): + for item in new_plugins_list: + package_name = item.pop('name') + + # update package information if existed + if package_name in local_plugins_info and not override: + original_item = local_plugins_info[package_name] + from pkg_resources import parse_version + item_version = parse_version( + item['version'] if item['version'] != '' else '0.0.0') + origin_version = parse_version( + original_item['version'] + if original_item['version'] != '' else '0.0.0') + desc = item['desc'] + if original_item['desc'] != '' and desc == '': + desc = original_item['desc'] + item = item if item_version > origin_version else original_item + item['desc'] = desc + + # Double-check if the item is installed with the version number + if item['version'] == '': + plugin_installed, version = self.check_plugin_installed( + package_name) + item['version'] = version + + local_plugins_info[package_name] = item + + return local_plugins_info + + def _print_plugins_info(self, local_plugins_info): + print('{:<15} |{:<10} |{:<100}'.format('NAME', 'VERSION', + 'DESCRIPTION')) + print('') + for k, v in local_plugins_info.items(): + print('{:<15} |{:<10} |{:<100}'.format(k, v['version'], v['desc'])) + + def list_plugins( + self, + show_all=False, + ): + """ + + Args: + show_all: show installed and official supported if True, else only those installed + + Returns: + + """ + local_plugins_info = self._get_plugins_from_file() + + # update plugins with default + + local_official_plugins = copy.deepcopy(OFFICIAL_PLUGINS) + local_plugins_info = self._update_plugins(local_official_plugins, + local_plugins_info) + + if show_all is True: + self._print_plugins_info(local_plugins_info) + return local_plugins_info + + # Consider those package with version is installed + not_installed_list = [] + for item in local_plugins_info: + if local_plugins_info[item]['version'] == '': + not_installed_list.append(item) + + for item in not_installed_list: + local_plugins_info.pop(item) + + self._print_plugins_info(local_plugins_info) + return local_plugins_info + + def update_plugins_file( + self, + plugins_list, + override=False, + ): + """update the plugins file in order to maintain the latest plugins information + + Args: + plugins_list: The plugins list contain the information of plugins + name, version, introduction, install url and the status of delete or update + override: Override the file by the list if True, else only update. + + Returns: + + """ + local_plugins_info = self._get_plugins_from_file() + + # local_plugins_info is empty if first time loading, should add OFFICIAL_PLUGINS information + if local_plugins_info == {}: + plugins_list.extend(copy.deepcopy(OFFICIAL_PLUGINS)) + + local_plugins_info = self._update_plugins(plugins_list, + local_plugins_info, override) + + local_plugins_info_json = json.dumps(local_plugins_info) + storage.write(local_plugins_info_json.encode(), self.file_path) + + return local_plugins_info_json + + def remove_plugins_from_file( + self, + package_names: Union[str, list], + ): + """ + + Args: + package_names: package name + + Returns: + + """ + local_plugins_info = self._get_plugins_from_file() + + if type(package_names) is str: + package_names = list(package_names) + + for item in package_names: + if item in local_plugins_info: + local_plugins_info.pop(item) + + local_plugins_info_json = json.dumps(local_plugins_info) + storage.write(local_plugins_info_json.encode(), self.file_path) + + return local_plugins_info_json + + +class EnvsManager(object): + name = 'envs' + + def __init__(self, + model_id, + model_revision=DEFAULT_MODEL_REVISION, + cache_dir=MODELSCOPE_FILE_DIR): + """ + + Args: + model_id: id of the model, not dir + model_revision: revision of the model, default as master + cache_dir: the system modelscope cache dir + """ + cache_dir = os.getenv('MODELSCOPE_CACHE', cache_dir) + self.env_dir = os.path.join(cache_dir, EnvsManager.name, model_id) + model_dir = snapshot_download(model_id, revision=model_revision) + cfg = read_config(model_dir) + self.plugins = cfg.get('plugins', []) + self.allow_remote = cfg.get('allow_remote', False) + self.env_builder = venv.EnvBuilder( + system_site_packages=True, + clear=False, + symlinks=True, + with_pip=False) + + def get_env_dir(self): + return self.env_dir + + def get_activate_dir(self): + return os.path.join(self.env_dir, 'bin', 'activate') + + def check_if_need_env(self): + if len(self.plugins) or self.allow_remote: + return True + else: + return False + + def create_env(self): + if not os.path.exists(self.env_dir): + os.makedirs(self.env_dir) + try: + self.env_builder.create(self.env_dir) + except Exception as e: + self.clean_env() + raise EnvironmentError( + f'Failed to create virtual env at {self.env_dir} with error: {e}' + ) + + def clean_env(self): + if os.path.exists(self.env_dir): + self.env_builder.clear_directory(self.env_dir) + + @staticmethod + def run_process(cmd): + import subprocess + status, result = subprocess.getstatusoutput(cmd) + logger.debug('The status and the results are: {}, {}'.format( + status, result)) + if status != 0: + raise Exception( + 'running the cmd: {} failed, with message: {}'.format( + cmd, result)) + return result + + +if __name__ == '__main__': + install_requirements_by_files(['adaseq']) diff --git a/tests/cli/test_plugins_cmd.py b/tests/cli/test_plugins_cmd.py new file mode 100644 index 00000000..b11c67ab --- /dev/null +++ b/tests/cli/test_plugins_cmd.py @@ -0,0 +1,50 @@ +import subprocess +import unittest + +from modelscope.utils.plugins import PluginsManager + + +class PluginsCMDTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.package = 'adaseq' + self.plugins_manager = PluginsManager() + + def tearDown(self): + super().tearDown() + + def test_plugins_install(self): + cmd = f'python -m modelscope.cli.cli plugin install {self.package}' + stat, output = subprocess.getstatusoutput(cmd) + self.assertEqual(stat, 0) + + # move this from tear down to avoid unexpected uninstall + uninstall_args = [self.package, '-y'] + self.plugins_manager.uninstall_plugins(uninstall_args) + + def test_plugins_uninstall(self): + # move this from tear down to avoid unexpected uninstall + uninstall_args = [self.package, '-y'] + self.plugins_manager.uninstall_plugins(uninstall_args) + + cmd = f'python -m modelscope.cli.cli plugin install {self.package}' + stat, output = subprocess.getstatusoutput(cmd) + self.assertEqual(stat, 0) + + cmd = f'python -m modelscope.cli.cli plugin uninstall {self.package}' + stat, output = subprocess.getstatusoutput(cmd) + self.assertEqual(stat, 0) + + # move this from tear down to avoid unexpected uninstall + uninstall_args = [self.package, '-y'] + self.plugins_manager.uninstall_plugins(uninstall_args) + + def test_plugins_list(self): + cmd = 'python -m modelscope.cli.cli plugin list' + stat, output = subprocess.getstatusoutput(cmd) + self.assertEqual(stat, 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/adaseq_pipelines/__init__.py b/tests/pipelines/plugin_remote_pipelines/__init__.py similarity index 100% rename from tests/pipelines/adaseq_pipelines/__init__.py rename to tests/pipelines/plugin_remote_pipelines/__init__.py diff --git a/tests/pipelines/plugin_remote_pipelines/test_allow_remote_model.py b/tests/pipelines/plugin_remote_pipelines/test_allow_remote_model.py new file mode 100644 index 00000000..0453cf64 --- /dev/null +++ b/tests/pipelines/plugin_remote_pipelines/test_allow_remote_model.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.pipelines import pipeline +from modelscope.utils.plugins import PluginsManager +from modelscope.utils.test_utils import test_level + + +class AllowRemoteModelTest(unittest.TestCase): + + def setUp(self): + self.package = 'moviepy' + + def tearDown(self): + # make sure uninstalled after installing + uninstall_args = [self.package, '-y'] + PluginsManager.pip_command('uninstall', uninstall_args) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_bilibili_image(self): + + model_path = snapshot_download( + 'bilibili/cv_bilibili_image-super-resolution', revision='v1.0.5') + file_path = f'{model_path}/demos/title-compare1.png' + weight_path = f'{model_path}/weights_v3/up2x-latest-denoise3x.pth' + inference = pipeline( + 'image-super-resolution', + model='bilibili/cv_bilibili_image-super-resolution', + weight_path=weight_path, + device='cpu', + half=False) # GPU环境可以设置为True + + output = inference(file_path, tile_mode=0, cache_mode=1, alpha=1) + print(output) diff --git a/tests/pipelines/adaseq_pipelines/test_named_entity_recognition.py b/tests/pipelines/plugin_remote_pipelines/test_plugin_model.py similarity index 82% rename from tests/pipelines/adaseq_pipelines/test_named_entity_recognition.py rename to tests/pipelines/plugin_remote_pipelines/test_plugin_model.py index 4ddc3131..40124dac 100644 --- a/tests/pipelines/adaseq_pipelines/test_named_entity_recognition.py +++ b/tests/pipelines/plugin_remote_pipelines/test_plugin_model.py @@ -4,12 +4,20 @@ import unittest from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.plugins import PluginsManager from modelscope.utils.test_utils import test_level -class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): +class PluginModelTest(unittest.TestCase, DemoCompatibilityCheck): def setUp(self): + self.package = 'adaseq' + + def tearDown(self): + # make sure uninstalled after installing + uninstall_args = [self.package, '-y'] + PluginsManager.pip_command('uninstall', uninstall_args) + super().tearDown() import subprocess result = subprocess.run( ['pip', 'install', 'adaseq>=0.6.2', '--no-deps'], diff --git a/tests/utils/test_envs.py b/tests/utils/test_envs.py new file mode 100644 index 00000000..e87297ac --- /dev/null +++ b/tests/utils/test_envs.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.utils.plugins import EnvsManager + + +class PluginTest(unittest.TestCase): + + def setUp(self): + self.model_id = 'damo/nlp_nested-ner_named-entity-recognition_chinese-base-med' + self.env_manager = EnvsManager(self.model_id) + + def tearDown(self): + self.env_manager.clean_env() + super().tearDown() + + def test_create_env(self): + need_env = self.env_manager.check_if_need_env() + self.assertEqual(need_env, True) + activate_dir = self.env_manager.create_env() + remote = 'source {}'.format(activate_dir) + cmd = f'{remote};' + print(cmd) + # EnvsManager.run_process(cmd) no sh in ci env, so skip diff --git a/tests/utils/test_plugin.py b/tests/utils/test_plugin.py index 40d86f9d..447ce1c9 100644 --- a/tests/utils/test_plugin.py +++ b/tests/utils/test_plugin.py @@ -1,16 +1,29 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile import unittest from modelscope.models.builder import MODELS -from modelscope.utils.plugins import (discover_plugins, import_all_plugins, - import_file_plugins, import_plugins, - pushd) +from modelscope.utils.plugins import (PluginsManager, discover_plugins, + import_all_plugins, import_file_plugins, + import_plugins, pushd) +from modelscope.utils.test_utils import test_level class PluginTest(unittest.TestCase): def setUp(self): self.plugins_root = 'tests/utils/plugins/' + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + self.package = 'adaseq' + self.plugins_manager = PluginsManager() + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() def test_no_plugins(self): available_plugins = set(discover_plugins()) @@ -39,3 +52,75 @@ class PluginTest(unittest.TestCase): import_all_plugins() assert MODELS.get('dummy-model', 'dummy-group') is not None + + def test_install_plugins(self): + """ + examples for the modelscope install method + > modelscope install adaseq ofasys + > modelscope install git+https://github.com/modelscope/AdaSeq.git + > modelscope install adaseq -i -f + > modelscope install adaseq --extra-index-url --trusted-host + """ + install_args = [self.package] + status_code, install_args = self.plugins_manager.install_plugins( + install_args) + self.assertEqual(status_code, 0) + + install_args = ['random_blabla'] + status_code, install_args = self.plugins_manager.install_plugins( + install_args) + self.assertEqual(status_code, 1) + + install_args = [self.package, 'random_blabla'] + status_code, install_args = self.plugins_manager.install_plugins( + install_args) + self.assertEqual(status_code, 1) + + # move this from tear down to avoid unexpected uninstall + uninstall_args = [self.package, '-y'] + self.plugins_manager.uninstall_plugins(uninstall_args) + + @unittest.skip + def test_install_plugins_with_git(self): + + install_args = ['git+https://github.com/modelscope/AdaSeq.git'] + status_code, install_args = self.plugins_manager.install_plugins( + install_args) + self.assertEqual(status_code, 0) + + # move this from tear down to avoid unexpected uninstall + uninstall_args = ['git+https://github.com/modelscope/AdaSeq.git', '-y'] + self.plugins_manager.uninstall_plugins(uninstall_args) + + def test_uninstall_plugins(self): + """ + examples for the modelscope uninstall method + > modelscope uninstall adaseq + > modelscope uninstall -y adaseq + """ + install_args = [self.package] + status_code, install_args = self.plugins_manager.install_plugins( + install_args) + self.assertEqual(status_code, 0) + + uninstall_args = [self.package, '-y'] + status_code, uninstall_args = self.plugins_manager.uninstall_plugins( + uninstall_args) + self.assertEqual(status_code, 0) + + def test_list_plugins(self): + """ + examples for the modelscope list method + > modelscope list + > modelscope list --all + > modelscope list -a + # """ + modelscope_plugin = os.path.join(self.tmp_dir, 'modelscope_plugin') + self.plugins_manager.file_path = modelscope_plugin + result = self.plugins_manager.list_plugins() + self.assertEqual(len(result.items()), 0) + + from modelscope.utils.plugins import OFFICIAL_PLUGINS + + result = self.plugins_manager.list_plugins(show_all=True) + self.assertEqual(len(result.items()), len(OFFICIAL_PLUGINS))