mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
[to #47860410]plugin with cli tool
1. 支持 plugin方式接入外部 repo、github repo,本地repo,并进行外部插件管理 2. 支持allow_remote方式接入modelhub repo,该类型属于model 范畴不做额外插件管理 3. 支持cli 安装plugin相关 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11775456
This commit is contained in:
committed by
wenmeng.zwm
parent
2b1af959d5
commit
8a19e9645d
@@ -5,6 +5,7 @@ import argparse
|
||||
from modelscope.cli.download import DownloadCMD
|
||||
from modelscope.cli.modelcard import ModelCardCMD
|
||||
from modelscope.cli.pipeline import PipelineCMD
|
||||
from modelscope.cli.plugins import PluginsCMD
|
||||
|
||||
|
||||
def run_cmd():
|
||||
@@ -13,6 +14,7 @@ def run_cmd():
|
||||
subparsers = parser.add_subparsers(help='modelscope commands helpers')
|
||||
|
||||
DownloadCMD.define_args(subparsers)
|
||||
PluginsCMD.define_args(subparsers)
|
||||
PipelineCMD.define_args(subparsers)
|
||||
ModelCardCMD.define_args(subparsers)
|
||||
|
||||
|
||||
118
modelscope/cli/plugins.py
Normal file
118
modelscope/cli/plugins.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.utils.plugins import PluginsManager
|
||||
|
||||
plugins_manager = PluginsManager()
|
||||
|
||||
|
||||
def subparser_func(args):
    """Instantiate the plugin command object for the parsed CLI arguments."""
    return PluginsCMD(args)
|
||||
|
||||
|
||||
class PluginsCMD(CLICommand):
    """CLI entry for `modelscope plugin`, dispatching to the
    install/uninstall/list sub-commands.
    """

    name = 'plugin'

    def __init__(self, args):
        # Parsed argparse namespace for this invocation.
        self.args = args

    @staticmethod
    def define_args(parsers: ArgumentParser):
        """Register the `plugin` sub-parser and its sub-commands.

        Args:
            parsers: the sub-parsers action of the top-level CLI parser.
        """
        parser = parsers.add_parser(PluginsCMD.name)
        # `dest='command'` records which sub-command was chosen so
        # `execute` can dispatch on it.
        subparsers = parser.add_subparsers(dest='command')

        PluginsInstallCMD.define_args(subparsers)
        PluginsUninstallCMD.define_args(subparsers)
        PluginsListCMD.define_args(subparsers)

        parser.set_defaults(func=subparser_func)

    def execute(self):
        """Dispatch to the selected sub-command.

        BUGFIX: removed a leftover debug `print(self.args)`; the
        sub-commands are mutually exclusive, so use an elif chain.
        """
        if self.args.command == PluginsInstallCMD.name:
            PluginsInstallCMD.execute(self.args)
        elif self.args.command == PluginsUninstallCMD.name:
            PluginsUninstallCMD.execute(self.args)
        elif self.args.command == PluginsListCMD.name:
            PluginsListCMD.execute(self.args)
|
||||
|
||||
|
||||
class PluginsInstallCMD(PluginsCMD):
    """`modelscope plugin install <package...>`: install plugin packages."""

    name = 'install'

    @staticmethod
    def define_args(parsers: ArgumentParser):
        """Register args for the `install` sub-command."""
        install = parsers.add_parser(PluginsInstallCMD.name)
        install.add_argument(
            'package',
            type=str,
            nargs='+',
            default=None,
            help='Name of the package to be installed.')
        install.add_argument(
            '--index_url',
            '-i',
            type=str,
            default=None,
            help='Base URL of the Python Package Index.')
        # BUGFIX: was `type=str, default=False`, which made `-f` consume a
        # string value; this is a boolean switch.
        install.add_argument(
            '--force_update',
            '-f',
            action='store_true',
            help='If force update the package')

    @staticmethod
    def execute(args):
        """Install the requested packages via the shared plugins manager."""
        plugins_manager.install_plugins(
            list(args.package),
            index_url=args.index_url,
            force_update=args.force_update)
|
||||
|
||||
|
||||
class PluginsUninstallCMD(PluginsCMD):
    """`modelscope plugin uninstall <package...>`: remove plugin packages."""

    name = 'uninstall'

    @staticmethod
    def define_args(parsers: ArgumentParser):
        """Register args for the `uninstall` sub-command."""
        uninstall = parsers.add_parser(PluginsUninstallCMD.name)
        uninstall.add_argument(
            'package',
            type=str,
            nargs='+',
            default=None,
            # BUGFIX: help text said "installed" for the uninstall command.
            help='Name of the package to be uninstalled.')
        # BUGFIX: was `type=str, default=False` with a help string copied
        # from `--index_url`; `-y` is a boolean confirmation switch.
        uninstall.add_argument(
            '--yes',
            '-y',
            action='store_true',
            help='Do not ask for confirmation before uninstalling.')

    @staticmethod
    def execute(args):
        """Uninstall the requested packages via the shared plugins manager."""
        plugins_manager.uninstall_plugins(list(args.package), is_yes=args.yes)
|
||||
|
||||
|
||||
class PluginsListCMD(PluginsCMD):
    """`modelscope plugin list`: show known plugins."""

    name = 'list'

    @staticmethod
    def define_args(parsers: ArgumentParser):
        """Register args for the `list` sub-command."""
        list_parser = parsers.add_parser(PluginsListCMD.name)
        # BUGFIX: was `type=str, default=None`; `-a` is a boolean switch.
        list_parser.add_argument(
            '--all',
            '-a',
            action='store_true',
            help='Show all of the plugins including those not installed.')

    @staticmethod
    def execute(args):
        """List plugins via the shared plugins manager.

        BUGFIX: previously called `list_plugins(show_all=all)`, passing the
        *builtin* `all` instead of the parsed flag, so `-a` had no effect.
        """
        plugins_manager.list_plugins(show_all=args.all)
|
||||
@@ -12,6 +12,8 @@ from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
|
||||
from modelscope.utils.device import verify_device
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.plugins import (register_modelhub_repo,
|
||||
register_plugins_repo)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -126,6 +128,11 @@ class Model(ABC):
|
||||
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
|
||||
model_cfg.type = model_cfg.model_type
|
||||
model_cfg.model_dir = local_model_dir
|
||||
|
||||
# install and import remote repos before build
|
||||
register_plugins_repo(cfg.safe_get('plugins'))
|
||||
register_modelhub_repo(local_model_dir, cfg.get('allow_remote', False))
|
||||
|
||||
for k, v in kwargs.items():
|
||||
model_cfg[k] = v
|
||||
if device is not None:
|
||||
|
||||
@@ -9,6 +9,8 @@ from modelscope.models.base import Model
|
||||
from modelscope.utils.config import ConfigDict, check_config
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
|
||||
from modelscope.utils.hub import read_config
|
||||
from modelscope.utils.plugins import (register_modelhub_repo,
|
||||
register_plugins_repo)
|
||||
from modelscope.utils.registry import Registry, build_from_cfg
|
||||
from .base import Pipeline
|
||||
from .util import is_official_hub_path
|
||||
@@ -63,7 +65,6 @@ def pipeline(task: str = None,
|
||||
framework: str = None,
|
||||
device: str = 'gpu',
|
||||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
plugins: List[str] = None,
|
||||
**kwargs) -> Pipeline:
|
||||
""" Factory method to build an obj:`Pipeline`.
|
||||
|
||||
@@ -96,8 +97,6 @@ def pipeline(task: str = None,
|
||||
if task is None and pipeline_name is None:
|
||||
raise ValueError('task or pipeline_name is required')
|
||||
|
||||
try_import_plugins(plugins)
|
||||
|
||||
model = normalize_model_input(model, model_revision)
|
||||
pipeline_props = {'type': pipeline_name}
|
||||
if pipeline_name is None:
|
||||
@@ -111,7 +110,8 @@ def pipeline(task: str = None,
|
||||
model, str) else read_config(
|
||||
model[0], revision=model_revision)
|
||||
check_config(cfg)
|
||||
try_import_plugins(cfg.safe_get('plugins'))
|
||||
register_plugins_repo(cfg.safe_get('plugins'))
|
||||
register_modelhub_repo(model, cfg.get('allow_remote', False))
|
||||
pipeline_props = cfg.pipeline
|
||||
elif model is not None:
|
||||
# get pipeline info from Model object
|
||||
@@ -120,7 +120,6 @@ def pipeline(task: str = None,
|
||||
# model is instantiated by user, we should parse config again
|
||||
cfg = read_config(first_model.model_dir)
|
||||
check_config(cfg)
|
||||
try_import_plugins(cfg.safe_get('plugins'))
|
||||
first_model.pipeline = cfg.pipeline
|
||||
pipeline_props = first_model.pipeline
|
||||
else:
|
||||
@@ -178,10 +177,3 @@ def get_default_pipeline_info(task):
|
||||
else:
|
||||
pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
|
||||
return pipeline_name, default_model
|
||||
|
||||
|
||||
def try_import_plugins(plugins: List[str]) -> None:
|
||||
""" Try to import plugins """
|
||||
if plugins is not None:
|
||||
from modelscope.utils.plugins import import_plugins
|
||||
import_plugins(plugins)
|
||||
|
||||
@@ -48,7 +48,7 @@ class NamedEntityRecognitionPipeline(TokenClassificationPipeline):
|
||||
>>> input = '这与温岭市新河镇的一个神秘的传说有关。'
|
||||
>>> print(pipeline_ins(input))
|
||||
|
||||
To view other examples plese check the tests/pipelines/test_named_entity_recognition.py.
|
||||
To view other examples plese check the tests/pipelines/test_plugin_model.py.
|
||||
"""
|
||||
super().__init__(
|
||||
model=model,
|
||||
|
||||
@@ -376,6 +376,7 @@ class FilesAstScanning(object):
|
||||
def __init__(self) -> None:
|
||||
self.astScaner = AstScanning()
|
||||
self.file_dirs = []
|
||||
self.requirement_dirs = []
|
||||
|
||||
def _parse_import_path(self,
|
||||
import_package: str,
|
||||
@@ -436,15 +437,15 @@ class FilesAstScanning(object):
|
||||
ignored.add(item)
|
||||
return list(set(output) - set(ignored))
|
||||
|
||||
def traversal_files(self, path, check_sub_dir):
|
||||
def traversal_files(self, path, check_sub_dir=None):
|
||||
self.file_dirs = []
|
||||
if check_sub_dir is None or len(check_sub_dir) == 0:
|
||||
self._traversal_files(path)
|
||||
|
||||
for item in check_sub_dir:
|
||||
sub_dir = os.path.join(path, item)
|
||||
if os.path.isdir(sub_dir):
|
||||
self._traversal_files(sub_dir)
|
||||
else:
|
||||
for item in check_sub_dir:
|
||||
sub_dir = os.path.join(path, item)
|
||||
if os.path.isdir(sub_dir):
|
||||
self._traversal_files(sub_dir)
|
||||
|
||||
def _traversal_files(self, path):
|
||||
dir_list = os.scandir(path)
|
||||
@@ -455,6 +456,8 @@ class FilesAstScanning(object):
|
||||
self._traversal_files(item.path)
|
||||
elif item.is_file() and item.name.endswith('.py'):
|
||||
self.file_dirs.append(item.path)
|
||||
elif item.is_file() and 'requirement' in item.name:
|
||||
self.requirement_dirs.append(item.path)
|
||||
|
||||
def _get_single_file_scan_result(self, file):
|
||||
try:
|
||||
|
||||
@@ -1,17 +1,40 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
||||
import copy
|
||||
import importlib
|
||||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
import venv
|
||||
from contextlib import contextmanager
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional, Set
|
||||
from typing import Any, Iterable, List, Optional, Set, Union
|
||||
|
||||
import json
|
||||
import pkg_resources
|
||||
|
||||
from modelscope.fileio.file import LocalStorage
|
||||
from modelscope.utils.ast_utils import FilesAstScanning
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
||||
from modelscope.utils.file_utils import get_default_cache_dir
|
||||
from modelscope.utils.hub import read_config, snapshot_download
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
storage = LocalStorage()
|
||||
|
||||
MODELSCOPE_FILE_DIR = get_default_cache_dir()
|
||||
PLUGINS_FILENAME = '.modelscope_plugins'
|
||||
OFFICIAL_PLUGINS = [
|
||||
{
|
||||
'name': 'adaseq',
|
||||
'desc':
|
||||
'Provide hundreds of additions NERs algorithms, check: https://github.com/modelscope/AdaSeq',
|
||||
'version': '',
|
||||
'url': ''
|
||||
},
|
||||
]
|
||||
|
||||
LOCAL_PLUGINS_FILENAME = '.modelscope_plugins'
|
||||
GLOBAL_PLUGINS_FILENAME = os.path.join(Path.home(), '.modelscope', 'plugins')
|
||||
@@ -65,24 +88,36 @@ def discover_file_plugins(
|
||||
yield module_name
|
||||
|
||||
|
||||
def discover_plugins() -> Iterable[str]:
|
||||
def discover_plugins(requirement_path=None) -> Iterable[str]:
|
||||
"""
|
||||
Discover plugins
|
||||
|
||||
Args:
|
||||
requirement_path: The file path of requirement
|
||||
|
||||
"""
|
||||
plugins: Set[str] = set()
|
||||
if os.path.isfile(LOCAL_PLUGINS_FILENAME):
|
||||
with push_python_path('.'):
|
||||
for plugin in discover_file_plugins(LOCAL_PLUGINS_FILENAME):
|
||||
if requirement_path is None:
|
||||
if os.path.isfile(LOCAL_PLUGINS_FILENAME):
|
||||
with push_python_path('.'):
|
||||
for plugin in discover_file_plugins(LOCAL_PLUGINS_FILENAME):
|
||||
if plugin in plugins:
|
||||
continue
|
||||
yield plugin
|
||||
plugins.add(plugin)
|
||||
if os.path.isfile(GLOBAL_PLUGINS_FILENAME):
|
||||
for plugin in discover_file_plugins(GLOBAL_PLUGINS_FILENAME):
|
||||
if plugin in plugins:
|
||||
continue
|
||||
yield plugin
|
||||
plugins.add(plugin)
|
||||
else:
|
||||
if os.path.isfile(requirement_path):
|
||||
for plugin in discover_file_plugins(requirement_path):
|
||||
if plugin in plugins:
|
||||
continue
|
||||
yield plugin
|
||||
plugins.add(plugin)
|
||||
if os.path.isfile(GLOBAL_PLUGINS_FILENAME):
|
||||
for plugin in discover_file_plugins(GLOBAL_PLUGINS_FILENAME):
|
||||
if plugin in plugins:
|
||||
continue
|
||||
yield plugin
|
||||
plugins.add(plugin)
|
||||
|
||||
|
||||
def import_all_plugins(plugins: List[str] = None) -> List[str]:
|
||||
@@ -142,9 +177,13 @@ def import_plugins(plugins: List[str] = None) -> List[str]:
|
||||
return imported_plugins
|
||||
|
||||
|
||||
def import_file_plugins() -> List[str]:
|
||||
def import_file_plugins(requirement_path=None) -> List[str]:
|
||||
"""
|
||||
Imports the plugins found with `discover_plugins()`.
|
||||
|
||||
Args:
|
||||
requirement_path: The file path of requirement
|
||||
|
||||
"""
|
||||
imported_plugins: List[str] = []
|
||||
|
||||
@@ -153,7 +192,7 @@ def import_file_plugins() -> List[str]:
|
||||
if cwd not in sys.path:
|
||||
sys.path.append(cwd)
|
||||
|
||||
for module_name in discover_plugins():
|
||||
for module_name in discover_plugins(requirement_path):
|
||||
try:
|
||||
importlib.import_module(module_name)
|
||||
logger.info('Plugin %s available', module_name)
|
||||
@@ -174,7 +213,7 @@ def import_module_and_submodules(package_name: str,
|
||||
include = include if include else set()
|
||||
exclude = exclude if exclude else set()
|
||||
|
||||
def fn_in(packge_name: str, pattern_set: Set[str]) -> bool:
|
||||
def fn_in(package_name: str, pattern_set: Set[str]) -> bool:
|
||||
for pattern in pattern_set:
|
||||
if fnmatch(package_name, pattern):
|
||||
return True
|
||||
@@ -213,3 +252,473 @@ def import_module_and_submodules(package_name: str,
|
||||
logger.warning(f'{package_name} not imported: {str(e)}')
|
||||
if len(package_name.split('.')) == 1:
|
||||
raise ModuleNotFoundError('Package not installed')
|
||||
|
||||
|
||||
def install_module_from_requirements(requirement_path, ):
    """Install every package listed in a pip requirements file.

    Args:
        requirement_path: The path of requirement file

    Raises:
        ImportError: if pip exits with a non-zero status code.
    """
    status_code, _, args = PluginsManager.pip_command(
        'install', ['-r', requirement_path])
    if status_code != 0:
        raise ImportError(
            f'Failed to install requirements from {requirement_path}')
|
||||
|
||||
|
||||
def import_module_from_file(module_name, file_path):
    """Load a Python source file and return it as a module object.

    Args:
        module_name: name to register the module under.
        file_path: path of the `.py` file to execute.

    Returns:
        The freshly executed module object.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    loaded_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded_module)
    return loaded_module
|
||||
|
||||
|
||||
def import_module_from_model_dir(model_dir):
    """Install a model repo's requirements, then import its python files.

    Scans `model_dir` recursively for `.py` files and requirement files,
    installs the requirements first, then imports every python file by its
    stem name with `model_dir` prepended to `sys.path`.
    """
    scanner = FilesAstScanning()
    scanner.traversal_files(model_dir)

    # install the requirements firstly
    install_requirements_by_files(scanner.requirement_dirs)

    # then import the modules; `sys` and `Path` come from module imports
    sys.path.insert(0, model_dir)
    for py_file in scanner.file_dirs:
        import_module_from_file(Path(py_file).stem, py_file)
|
||||
|
||||
|
||||
def install_modelscope_if_need():
    """Make sure the `modelscope` package itself is installed, installing
    it via pip when missing.

    Raises:
        ImportError: if the pip installation fails.
    """
    installed, _ = PluginsManager.check_plugin_installed('modelscope')
    if installed:
        return
    status_code, _, args = PluginsManager.pip_command('install', ['modelscope'])
    if status_code != 0:
        raise ImportError('Failed to install package modelscope')
|
||||
|
||||
|
||||
def install_requirements_by_names(plugins: List[str]):
    """Install any of the named plugin packages that are not yet present,
    then ensure modelscope itself is installed.

    Raises:
        EnvironmentError: if installing the missing plugins fails.
    """
    manager = PluginsManager()
    uninstalled_plugins = [
        plugin for plugin in plugins
        if not manager.check_plugin_installed(plugin)[0]
    ]
    status, _ = manager.install_plugins(uninstalled_plugins)
    if status != 0:
        raise EnvironmentError(
            f'The required packages {",".join(uninstalled_plugins)} are not installed.',
            f'Please run the command `modelscope plugin install {" ".join(uninstalled_plugins)}` to install them.'
        )
    install_modelscope_if_need()
|
||||
|
||||
|
||||
def install_requirements_by_files(requirements: List[str]):
    """Install every given requirements file, then ensure modelscope
    itself is installed."""
    for requirement_file in requirements:
        install_module_from_requirements(requirement_file)
    install_modelscope_if_need()
|
||||
|
||||
|
||||
def register_plugins_repo(plugins: List[str]) -> None:
    """Install (if needed) and import the given plugin repos; no-op when
    `plugins` is None."""
    if plugins is None:
        return
    install_requirements_by_names(plugins)
    import_plugins(plugins)
|
||||
|
||||
|
||||
def register_modelhub_repo(model_dir, allow_remote=False) -> None:
    """Install and import a remote model repo from the modelhub when the
    model opts in via `allow_remote`; otherwise do nothing."""
    if not allow_remote:
        return
    try:
        import_module_from_model_dir(model_dir)
    except KeyError:
        # Several components registered under the same key in one file;
        # treated as best-effort, so only warn.
        logger.warning(
            'Multi component keys in the hub are registered in same file')
|
||||
|
||||
|
||||
class PluginsManager(object):
    """Manage modelscope plugins.

    Installs/uninstalls plugin packages through pip and maintains a JSON
    record file (under the modelscope cache dir) with each plugin's name,
    version, description and install url, so known plugins can be listed
    across sessions.
    """

    def __init__(self,
                 cache_dir=MODELSCOPE_FILE_DIR,
                 plugins_file=PLUGINS_FILENAME):
        # Environment variables take precedence over the call-site values.
        cache_dir = os.getenv('MODELSCOPE_CACHE', cache_dir)
        plugins_file = os.getenv('MODELSCOPE_PLUGINS_FILE', plugins_file)
        self._file_path = os.path.join(cache_dir, plugins_file)

    @property
    def file_path(self):
        """Path of the JSON file that records known plugins."""
        return self._file_path

    @file_path.setter
    def file_path(self, value):
        self._file_path = value

    @staticmethod
    def check_plugin_installed(package):
        """Return ``(installed, version)`` for ``package``.

        Args:
            package: distribution name to look up.

        Returns:
            Tuple ``(installed, version)``; ``version`` is ``''`` when the
            package is not installed.
        """
        try:
            # Reload so distributions installed after interpreter start
            # become visible in the working set.
            importlib.reload(pkg_resources)
            package_meta_info = pkg_resources.working_set.by_key[package]
            version = package_meta_info.version
            installed = True
        except KeyError:
            version = ''
            installed = False

        return installed, version

    @staticmethod
    def pip_command(
        command,
        command_args: List[str],
    ):
        """Run a pip sub-command in-process.

        Args:
            command: install, uninstall command
            command_args: the args to be used with command, should be in list
                such as ['-r', 'requirements']

        Returns:
            ``(status_code, options, args)`` where ``status_code`` is pip's
            exit code (0 on success) and ``options``/``args`` are the parsed
            option object and positional arguments.
        """
        from pip._internal.commands import create_command
        importlib.reload(pkg_resources)
        command = create_command(command)
        options, args = command.parse_args(command_args)

        status_code = command.main(command_args)
        return status_code, options, args

    def install_plugins(self,
                        install_args: List[str],
                        index_url: Optional[str] = None,
                        force_update=False) -> Any:
        """Install packages via pip and record them in the plugins file.

        Args:
            install_args (list): List of arguments passed to `pip install`.
            index_url (str, optional): The pypi index url.
            force_update: when truthy, upgrade already-installed packages.

        Returns:
            ``(status_code, install_args)``.
        """
        if len(install_args) == 0:
            return 0, []

        if index_url is not None:
            install_args += ['-i', index_url]

        if force_update is not False:
            # BUGFIX: was `install_args += ['-f']`, but a bare `-f` is
            # pip's `--find-links`, which requires a value and would make
            # pip fail; `--upgrade` is the flag that forces an update.
            install_args += ['--upgrade']

        status_code, options, args = PluginsManager.pip_command(
            'install',
            install_args,
        )

        if status_code == 0:
            logger.info(f'The plugins {",".join(args)} is installed')

            # TODO Add Ast index for ast update record

            # Add the plugins info to the local record
            installed_package = self.parse_args_info(args, options)
            self.update_plugins_file(installed_package)

        return status_code, install_args

    def parse_args_info(self, args: List[str], options):
        """Build plugin-record dicts from parsed pip args and options."""
        installed_package = []

        # the case of install with requirements file(s) only
        if len(args) == 0:
            src_dir = options.src_dir
            # BUGFIX: the attribute was misspelled `requirments`, raising
            # AttributeError for every `pip install -r ...` invocation.
            requirements = options.requirements
            for requirement in requirements:
                package_info = {
                    'name': requirement,
                    'url': os.path.join(src_dir, requirement),
                    'desc': '',
                    'version': ''
                }

                installed_package.append(package_info)

        def get_package_info(package_name):
            # Path comes from the module-level `from pathlib import Path`.
            package_info = {
                'name': package_name,
                'url': options.index_url,
                'desc': ''
            }

            # the case with git + http: derive the name from the repo url
            if package_name.split('.')[-1] == 'git':
                package_name = Path(package_name).stem

            plugin_installed, version = self.check_plugin_installed(
                package_name)
            if plugin_installed:
                package_info['version'] = version
                package_info['name'] = package_name
            else:
                logger.warning(
                    f'The package {package_name} is not in the lib, this might be happened'
                    f' when installing the package with git+https method, should be ignored'
                )
                package_info['version'] = ''

            return package_info

        for package in args:
            installed_package.append(get_package_info(package))

        return installed_package

    def uninstall_plugins(self,
                          uninstall_args: Union[str, List],
                          is_yes=False):
        """Uninstall packages via pip and drop them from the record file.

        Args:
            uninstall_args: package name(s)/args passed to `pip uninstall`.
            is_yes: when truthy, pass `-y` so pip does not prompt.

        Returns:
            ``(status_code, uninstall_args)``.
        """
        # BUGFIX: was `if is_yes is not None`, which appended '-y' even for
        # the default `is_yes=False`, making uninstall always silent.
        if is_yes:
            uninstall_args += ['-y']

        status_code, options, args = PluginsManager.pip_command(
            'uninstall',
            uninstall_args,
        )

        if status_code == 0:
            logger.info(f'The plugins {",".join(args)} is uninstalled')

            # TODO Add Ast index for ast update record

            # Add to the local record
            self.remove_plugins_from_file(args)

        return status_code, uninstall_args

    def _get_plugins_from_file(self):
        """Load plugin records from the local JSON file ({} when absent)."""
        logger.info(f'Loading plugins information from {self.file_path}')
        if os.path.exists(self.file_path):
            local_plugins_info_bytes = storage.read(self.file_path)
            local_plugins_info = json.loads(local_plugins_info_bytes)
        else:
            local_plugins_info = {}
        return local_plugins_info

    def _update_plugins(
        self,
        new_plugins_list,
        local_plugins_info,
        override=False,
    ):
        """Merge `new_plugins_list` into `local_plugins_info`.

        When a plugin already exists and `override` is False, the entry
        with the higher version wins and a non-empty description is kept.
        """
        for item in new_plugins_list:
            package_name = item.pop('name')

            # update package information if existed
            if package_name in local_plugins_info and not override:
                original_item = local_plugins_info[package_name]
                from pkg_resources import parse_version
                # '' means "not installed"; compare as the lowest version
                item_version = parse_version(
                    item['version'] if item['version'] != '' else '0.0.0')
                origin_version = parse_version(
                    original_item['version']
                    if original_item['version'] != '' else '0.0.0')
                desc = item['desc']
                if original_item['desc'] != '' and desc == '':
                    desc = original_item['desc']
                item = item if item_version > origin_version else original_item
                item['desc'] = desc

            # Double-check if the item is installed with the version number
            if item['version'] == '':
                plugin_installed, version = self.check_plugin_installed(
                    package_name)
                item['version'] = version

            local_plugins_info[package_name] = item

        return local_plugins_info

    def _print_plugins_info(self, local_plugins_info):
        """Print a NAME/VERSION/DESCRIPTION table to stdout."""
        print('{:<15} |{:<10} |{:<100}'.format('NAME', 'VERSION',
                                               'DESCRIPTION'))
        print('')
        for k, v in local_plugins_info.items():
            print('{:<15} |{:<10} |{:<100}'.format(k, v['version'], v['desc']))

    def list_plugins(
        self,
        show_all=False,
    ):
        """Print and return known plugins.

        Args:
            show_all: show installed and official supported if True, else
                only those installed

        Returns:
            The dict of plugin records that was printed.
        """
        local_plugins_info = self._get_plugins_from_file()

        # merge in the official plugins so they can be listed as well
        local_official_plugins = copy.deepcopy(OFFICIAL_PLUGINS)
        local_plugins_info = self._update_plugins(local_official_plugins,
                                                  local_plugins_info)

        # BUGFIX: was `if show_all is True`, which silently ignored truthy
        # non-bool values (e.g. a string coming from the CLI).
        if show_all:
            self._print_plugins_info(local_plugins_info)
            return local_plugins_info

        # Consider those package with version is installed
        not_installed_list = [
            name for name, info in local_plugins_info.items()
            if info['version'] == ''
        ]
        for item in not_installed_list:
            local_plugins_info.pop(item)

        self._print_plugins_info(local_plugins_info)
        return local_plugins_info

    def update_plugins_file(
        self,
        plugins_list,
        override=False,
    ):
        """Update the plugins file with the latest plugins information.

        Args:
            plugins_list: The plugins list contain the information of plugins
                name, version, introduction, install url and the status of
                delete or update
            override: Override the file by the list if True, else only update.

        Returns:
            The JSON string that was written.
        """
        local_plugins_info = self._get_plugins_from_file()

        # first time loading: seed the record with the official plugins
        if local_plugins_info == {}:
            plugins_list.extend(copy.deepcopy(OFFICIAL_PLUGINS))

        local_plugins_info = self._update_plugins(plugins_list,
                                                  local_plugins_info, override)

        local_plugins_info_json = json.dumps(local_plugins_info)
        storage.write(local_plugins_info_json.encode(), self.file_path)

        return local_plugins_info_json

    def remove_plugins_from_file(
        self,
        package_names: Union[str, list],
    ):
        """Remove records for the given packages from the plugins file.

        Args:
            package_names: one package name or a list of names.

        Returns:
            The JSON string that was written.
        """
        local_plugins_info = self._get_plugins_from_file()

        # BUGFIX: `list(package_names)` split a single name into single
        # characters; wrap the string instead.
        if isinstance(package_names, str):
            package_names = [package_names]

        for item in package_names:
            if item in local_plugins_info:
                local_plugins_info.pop(item)

        local_plugins_info_json = json.dumps(local_plugins_info)
        storage.write(local_plugins_info_json.encode(), self.file_path)

        return local_plugins_info_json
|
||||
|
||||
|
||||
class EnvsManager(object):
    """Create and clean a dedicated virtual environment for one model."""

    name = 'envs'

    def __init__(self,
                 model_id,
                 model_revision=DEFAULT_MODEL_REVISION,
                 cache_dir=MODELSCOPE_FILE_DIR):
        """
        Args:
            model_id: id of the model, not dir
            model_revision: revision of the model, default as master
            cache_dir: the system modelscope cache dir
        """
        cache_dir = os.getenv('MODELSCOPE_CACHE', cache_dir)
        self.env_dir = os.path.join(cache_dir, EnvsManager.name, model_id)
        cfg = read_config(
            snapshot_download(model_id, revision=model_revision))
        self.plugins = cfg.get('plugins', [])
        self.allow_remote = cfg.get('allow_remote', False)
        self.env_builder = venv.EnvBuilder(
            system_site_packages=True,
            clear=False,
            symlinks=True,
            with_pip=False)

    def get_env_dir(self):
        return self.env_dir

    def get_activate_dir(self):
        # Path of the `activate` script inside the env.
        return os.path.join(self.env_dir, 'bin', 'activate')

    def check_if_need_env(self):
        # An isolated env is needed when the model declares plugins or
        # allows remote code.
        return bool(self.plugins) or bool(self.allow_remote)

    def create_env(self):
        """Build the virtual env, cleaning up on failure."""
        if not os.path.exists(self.env_dir):
            os.makedirs(self.env_dir)
        try:
            self.env_builder.create(self.env_dir)
        except Exception as e:
            self.clean_env()
            raise EnvironmentError(
                f'Failed to create virtual env at {self.env_dir} with error: {e}'
            )

    def clean_env(self):
        """Remove the contents of the env directory if it exists."""
        if os.path.exists(self.env_dir):
            self.env_builder.clear_directory(self.env_dir)

    @staticmethod
    def run_process(cmd):
        """Run a shell command, returning its output or raising on failure."""
        import subprocess
        status, result = subprocess.getstatusoutput(cmd)
        logger.debug('The status and the results are: {}, {}'.format(
            status, result))
        if status == 0:
            return result
        raise Exception(
            'running the cmd: {} failed, with message: {}'.format(
                cmd, result))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # BUGFIX: previously called install_requirements_by_files(['adaseq']),
    # but 'adaseq' is a package name, not a requirements-file path; use the
    # by-names installer instead.
    install_requirements_by_names(['adaseq'])
|
||||
|
||||
50
tests/cli/test_plugins_cmd.py
Normal file
50
tests/cli/test_plugins_cmd.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import subprocess
|
||||
import unittest
|
||||
|
||||
from modelscope.utils.plugins import PluginsManager
|
||||
|
||||
|
||||
class PluginsCMDTest(unittest.TestCase):
    """End-to-end tests of the `modelscope plugin` CLI via subprocess."""

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.package = 'adaseq'
        self.plugins_manager = PluginsManager()

    def tearDown(self):
        super().tearDown()

    def _uninstall_quietly(self):
        # done here instead of tearDown to avoid unexpected uninstall
        self.plugins_manager.uninstall_plugins([self.package, '-y'])

    def test_plugins_install(self):
        stat, output = subprocess.getstatusoutput(
            f'python -m modelscope.cli.cli plugin install {self.package}')
        self.assertEqual(stat, 0)

        self._uninstall_quietly()

    def test_plugins_uninstall(self):
        # start from a clean state
        self._uninstall_quietly()

        stat, output = subprocess.getstatusoutput(
            f'python -m modelscope.cli.cli plugin install {self.package}')
        self.assertEqual(stat, 0)

        stat, output = subprocess.getstatusoutput(
            f'python -m modelscope.cli.cli plugin uninstall {self.package}')
        self.assertEqual(stat, 0)

        self._uninstall_quietly()

    def test_plugins_list(self):
        stat, output = subprocess.getstatusoutput(
            'python -m modelscope.cli.cli plugin list')
        self.assertEqual(stat, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.plugins import PluginsManager
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class AllowRemoteModelTest(unittest.TestCase):
    """Check that a modelhub repo with remote code runs through pipeline()."""

    def setUp(self):
        self.package = 'moviepy'

    def tearDown(self):
        # make sure uninstalled after installing
        PluginsManager.pip_command('uninstall', [self.package, '-y'])
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_bilibili_image(self):
        model_path = snapshot_download(
            'bilibili/cv_bilibili_image-super-resolution', revision='v1.0.5')
        demo_image = f'{model_path}/demos/title-compare1.png'
        weights = f'{model_path}/weights_v3/up2x-latest-denoise3x.pth'
        inference = pipeline(
            'image-super-resolution',
            model='bilibili/cv_bilibili_image-super-resolution',
            weight_path=weights,
            device='cpu',
            half=False)  # may be set to True in a GPU environment

        result = inference(demo_image, tile_mode=0, cache_mode=1, alpha=1)
        print(result)
|
||||
@@ -4,12 +4,20 @@ import unittest
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.plugins import PluginsManager
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
class PluginModelTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
|
||||
def setUp(self):
|
||||
self.package = 'adaseq'
|
||||
|
||||
def tearDown(self):
|
||||
# make sure uninstalled after installing
|
||||
uninstall_args = [self.package, '-y']
|
||||
PluginsManager.pip_command('uninstall', uninstall_args)
|
||||
super().tearDown()
|
||||
import subprocess
|
||||
result = subprocess.run(
|
||||
['pip', 'install', 'adaseq>=0.6.2', '--no-deps'],
|
||||
25
tests/utils/test_envs.py
Normal file
25
tests/utils/test_envs.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import unittest
|
||||
|
||||
from modelscope.utils.plugins import EnvsManager
|
||||
|
||||
|
||||
class PluginTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.model_id = 'damo/nlp_nested-ner_named-entity-recognition_chinese-base-med'
|
||||
self.env_manager = EnvsManager(self.model_id)
|
||||
|
||||
def tearDown(self):
|
||||
self.env_manager.clean_env()
|
||||
super().tearDown()
|
||||
|
||||
def test_create_env(self):
|
||||
need_env = self.env_manager.check_if_need_env()
|
||||
self.assertEqual(need_env, True)
|
||||
activate_dir = self.env_manager.create_env()
|
||||
remote = 'source {}'.format(activate_dir)
|
||||
cmd = f'{remote};'
|
||||
print(cmd)
|
||||
# EnvsManager.run_process(cmd) no sh in ci env, so skip
|
||||
@@ -1,16 +1,29 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.plugins import (discover_plugins, import_all_plugins,
|
||||
import_file_plugins, import_plugins,
|
||||
pushd)
|
||||
from modelscope.utils.plugins import (PluginsManager, discover_plugins,
|
||||
import_all_plugins, import_file_plugins,
|
||||
import_plugins, pushd)
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class PluginTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.plugins_root = 'tests/utils/plugins/'
|
||||
self.tmp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(self.tmp_dir):
|
||||
os.makedirs(self.tmp_dir)
|
||||
self.package = 'adaseq'
|
||||
self.plugins_manager = PluginsManager()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
super().tearDown()
|
||||
|
||||
def test_no_plugins(self):
|
||||
available_plugins = set(discover_plugins())
|
||||
@@ -39,3 +52,75 @@ class PluginTest(unittest.TestCase):
|
||||
|
||||
import_all_plugins()
|
||||
assert MODELS.get('dummy-model', 'dummy-group') is not None
|
||||
|
||||
def test_install_plugins(self):
|
||||
"""
|
||||
examples for the modelscope install method
|
||||
> modelscope install adaseq ofasys
|
||||
> modelscope install git+https://github.com/modelscope/AdaSeq.git
|
||||
> modelscope install adaseq -i <url> -f <url>
|
||||
> modelscope install adaseq --extra-index-url <url> --trusted-host <hostname>
|
||||
"""
|
||||
install_args = [self.package]
|
||||
status_code, install_args = self.plugins_manager.install_plugins(
|
||||
install_args)
|
||||
self.assertEqual(status_code, 0)
|
||||
|
||||
install_args = ['random_blabla']
|
||||
status_code, install_args = self.plugins_manager.install_plugins(
|
||||
install_args)
|
||||
self.assertEqual(status_code, 1)
|
||||
|
||||
install_args = [self.package, 'random_blabla']
|
||||
status_code, install_args = self.plugins_manager.install_plugins(
|
||||
install_args)
|
||||
self.assertEqual(status_code, 1)
|
||||
|
||||
# move this from tear down to avoid unexpected uninstall
|
||||
uninstall_args = [self.package, '-y']
|
||||
self.plugins_manager.uninstall_plugins(uninstall_args)
|
||||
|
||||
@unittest.skip
|
||||
def test_install_plugins_with_git(self):
|
||||
|
||||
install_args = ['git+https://github.com/modelscope/AdaSeq.git']
|
||||
status_code, install_args = self.plugins_manager.install_plugins(
|
||||
install_args)
|
||||
self.assertEqual(status_code, 0)
|
||||
|
||||
# move this from tear down to avoid unexpected uninstall
|
||||
uninstall_args = ['git+https://github.com/modelscope/AdaSeq.git', '-y']
|
||||
self.plugins_manager.uninstall_plugins(uninstall_args)
|
||||
|
||||
def test_uninstall_plugins(self):
|
||||
"""
|
||||
examples for the modelscope uninstall method
|
||||
> modelscope uninstall adaseq
|
||||
> modelscope uninstall -y adaseq
|
||||
"""
|
||||
install_args = [self.package]
|
||||
status_code, install_args = self.plugins_manager.install_plugins(
|
||||
install_args)
|
||||
self.assertEqual(status_code, 0)
|
||||
|
||||
uninstall_args = [self.package, '-y']
|
||||
status_code, uninstall_args = self.plugins_manager.uninstall_plugins(
|
||||
uninstall_args)
|
||||
self.assertEqual(status_code, 0)
|
||||
|
||||
def test_list_plugins(self):
|
||||
"""
|
||||
examples for the modelscope list method
|
||||
> modelscope list
|
||||
> modelscope list --all
|
||||
> modelscope list -a
|
||||
# """
|
||||
modelscope_plugin = os.path.join(self.tmp_dir, 'modelscope_plugin')
|
||||
self.plugins_manager.file_path = modelscope_plugin
|
||||
result = self.plugins_manager.list_plugins()
|
||||
self.assertEqual(len(result.items()), 0)
|
||||
|
||||
from modelscope.utils.plugins import OFFICIAL_PLUGINS
|
||||
|
||||
result = self.plugins_manager.list_plugins(show_all=True)
|
||||
self.assertEqual(len(result.items()), len(OFFICIAL_PLUGINS))
|
||||
|
||||
Reference in New Issue
Block a user