mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
1016 lines
34 KiB
Python
1016 lines
34 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
# This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
|
# Part of the implementation is borrowed from wimglenn/johnnydep
|
|
|
|
import copy
|
|
import importlib
|
|
import os
|
|
import pkgutil
|
|
import shutil
|
|
import sys
|
|
import venv
|
|
from contextlib import contextmanager
|
|
from fnmatch import fnmatch
|
|
from pathlib import Path
|
|
from typing import Any, Iterable, List, Optional, Set, Union
|
|
|
|
import json
|
|
import pkg_resources
|
|
|
|
from modelscope.fileio.file import LocalStorage
|
|
from modelscope.utils.ast_utils import FilesAstScanning
|
|
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
|
from modelscope.utils.file_utils import get_default_cache_dir
|
|
from modelscope.utils.hub import read_config, snapshot_download
|
|
from modelscope.utils.logger import get_logger
|
|
|
|
logger = get_logger()
|
|
storage = LocalStorage()
|
|
|
|
MODELSCOPE_FILE_DIR = get_default_cache_dir()
|
|
PLUGINS_FILENAME = '.modelscope_plugins'
|
|
OFFICIAL_PLUGINS = [
|
|
{
|
|
'name': 'adaseq',
|
|
'desc':
|
|
'Provide hundreds of additions NERs algorithms, check: https://github.com/modelscope/AdaSeq',
|
|
'version': '',
|
|
'url': ''
|
|
},
|
|
]
|
|
|
|
LOCAL_PLUGINS_FILENAME = '.modelscope_plugins'
|
|
GLOBAL_PLUGINS_FILENAME = os.path.join(Path.home(), '.modelscope', 'plugins')
|
|
DEFAULT_PLUGINS = []
|
|
|
|
|
|
@contextmanager
def pushd(new_dir: str, verbose: bool = False):
    """
    Temporarily change the current working directory to ``new_dir``.

    Intended for use in a ``with`` statement; the previous working directory
    is restored on exit, even if the body raises.  Note: despite earlier
    documentation, this does NOT touch ``sys.path`` — see `push_python_path`
    for that.

    Args:
        new_dir: Directory to switch into for the duration of the block.
        verbose: Log the directory changes at INFO level when True.
    """
    previous_dir = os.getcwd()
    if verbose:
        logger.info(f'Changing directory to {new_dir}')  # type: ignore
    os.chdir(new_dir)
    try:
        yield
    finally:
        # Always restore, even when the body raised.
        if verbose:
            logger.info(f'Changing directory back to {previous_dir}')
        os.chdir(previous_dir)
|
|
|
|
|
|
@contextmanager
def push_python_path(path: str):
    """
    Temporarily prepend ``path`` (resolved to an absolute path) to ``sys.path``.

    Intended for use in a ``with`` statement; the inserted entry is removed
    again on exit, even if the body raises.
    """
    resolved = str(Path(path).resolve())
    sys.path.insert(0, resolved)
    try:
        yield
    finally:
        sys.path.remove(resolved)
|
|
|
|
|
|
def discover_file_plugins(
        filename: str = LOCAL_PLUGINS_FILENAME) -> Iterable[str]:
    """
    Yield each non-empty line of ``filename`` as a plugin module name.
    """
    with open(filename) as plugin_file:
        for raw_line in plugin_file:
            name = raw_line.strip()
            if name:
                yield name
|
|
|
|
|
|
def discover_plugins(requirement_path=None) -> Iterable[str]:
    """
    Discover plugins, yielding each module name at most once.

    Args:
        requirement_path: The file path of requirement.  When None, the local
            and global plugin files are consulted instead.
    """
    seen: Set[str] = set()

    def _emit(filename):
        # Yield names from one plugins file, skipping any already seen.
        for name in discover_file_plugins(filename):
            if name in seen:
                continue
            yield name
            seen.add(name)

    if requirement_path is None:
        if os.path.isfile(LOCAL_PLUGINS_FILENAME):
            # the local file is read with '.' importable
            with push_python_path('.'):
                yield from _emit(LOCAL_PLUGINS_FILENAME)
        if os.path.isfile(GLOBAL_PLUGINS_FILENAME):
            yield from _emit(GLOBAL_PLUGINS_FILENAME)
    else:
        if os.path.isfile(requirement_path):
            yield from _emit(requirement_path)
|
|
|
|
|
|
def import_all_plugins(plugins: List[str] = None) -> List[str]:
    """
    Import default plugins, the explicitly given plugins and the plugins
    discovered from plugin files; return the names successfully imported.
    """
    # Only the registry builder modules are imported eagerly; the heavy
    # implementation packages are excluded.
    builder_modules = {
        f'modelscope.{pkg}.builder'
        for pkg in
        ('metrics', 'models', 'pipelines', 'preprocessors', 'trainers')
    }
    skipped_patterns = {
        'modelscope.metrics.*',
        'modelscope.models.*',
        'modelscope.pipelines.*',
        'modelscope.preprocessors.*',
        'modelscope.trainers.*',
        'modelscope.msdatasets',
        'modelscope.utils',
        'modelscope.exporters',
    }
    import_module_and_submodules(
        'modelscope', include=builder_modules, exclude=skipped_patterns)

    imported: List[str] = []
    imported.extend(import_plugins(DEFAULT_PLUGINS))
    imported.extend(import_plugins(plugins))
    imported.extend(import_file_plugins())
    return imported
|
|
|
|
|
|
def import_plugins(plugins: List[str] = None) -> List[str]:
    """
    Import each listed plugin (with submodules); return those that loaded.
    """
    if not plugins:
        return []

    # Workaround for a presumed Python issue where spawned processes can't
    # find modules in the current directory.
    cwd = os.getcwd()
    if cwd not in sys.path:
        sys.path.append(cwd)

    imported: List[str] = []
    for name in plugins:
        try:
            import_module_and_submodules(name)
        except ModuleNotFoundError as err:
            logger.error(f'Plugin {name} could not be loaded: {err}')
        else:
            logger.info('Plugin %s available', name)
            imported.append(name)
    return imported
|
|
|
|
|
|
def import_file_plugins(requirement_path=None) -> List[str]:
    """
    Import the plugins found with `discover_plugins()`.

    Args:
        requirement_path: The file path of requirement.
    """
    # Workaround for a presumed Python issue where spawned processes can't
    # find modules in the current directory.
    cwd = os.getcwd()
    if cwd not in sys.path:
        sys.path.append(cwd)

    imported: List[str] = []
    for name in discover_plugins(requirement_path):
        try:
            importlib.import_module(name)
        except ModuleNotFoundError as err:
            logger.error(f'Plugin {name} could not be loaded: {err}')
        else:
            logger.info('Plugin %s available', name)
            imported.append(name)
    return imported
|
|
|
|
|
|
def import_module_and_submodules(package_name: str,
                                 include: Optional[Set[str]] = None,
                                 exclude: Optional[Set[str]] = None) -> None:
    """
    Import all public submodules under the given package.

    Args:
        package_name: Dotted name of the package to import recursively.
        include: fnmatch patterns that force a name to be imported even when
            it also matches an ``exclude`` pattern.  Note that only ``exclude``
            is propagated to recursive calls; ``include`` applies only at the
            level it is passed.
        exclude: fnmatch patterns of package names to skip.
    """
    # take care of None
    include = include if include else set()
    exclude = exclude if exclude else set()

    def fn_in(package_name: str, pattern_set: Set[str]) -> bool:
        # True when the name matches any fnmatch-style pattern in the set.
        for pattern in pattern_set:
            if fnmatch(package_name, pattern):
                return True
        return False

    # Skip excluded packages unless explicitly included.
    if not fn_in(package_name, include) and fn_in(package_name, exclude):
        return

    importlib.invalidate_caches()

    # For some reason, python doesn't always add this by default to your path, but you pretty much
    # always want it when using `--include-package`. And if it's already there, adding it again at
    # the end won't hurt anything.
    with push_python_path('.'):
        # Import at top level
        try:
            module = importlib.import_module(package_name)
            path = getattr(module, '__path__', [])
            path_string = '' if not path else path[0]

            # walk_packages only finds immediate children, so need to recurse.
            for module_finder, name, _ in pkgutil.walk_packages(path):
                # Sometimes when you import third-party libraries that are on your path,
                # `pkgutil.walk_packages` returns those too, so we need to skip them.
                if path_string and module_finder.path != path_string:  # type: ignore[union-attr]
                    continue
                if name.startswith('_'):
                    # skip directly importing private subpackages
                    continue
                if name.startswith('test'):
                    # skip tests
                    continue
                subpackage = f'{package_name}.{name}'
                import_module_and_submodules(subpackage, exclude=exclude)
        except SystemExit as e:
            # this case is specific for easy_cv's tools/predict.py exit
            logger.warning(
                f'{package_name} not imported: {str(e)}, but should continue')
            pass
        except Exception as e:
            # A failing submodule is tolerated; a failing top-level package
            # (no dot in the name) is treated as "not installed" and raised.
            logger.warning(f'{package_name} not imported: {str(e)}')
            if len(package_name.split('.')) == 1:
                raise ModuleNotFoundError('Package not installed')
|
|
|
|
|
|
def install_module_from_requirements(requirement_path):
    """Install any requirement listed in ``requirement_path`` that is not
    already satisfied in the current environment.

    Args:
        requirement_path: The path of requirement file.

    Raises:
        ImportError: If pip fails to install the missing requirements.
    """
    install_list = []
    with open(requirement_path, 'r', encoding='utf-8') as f:
        for req in f.read().splitlines():
            req = req.strip()
            # Skip blank lines and '#' comment lines, both of which are legal
            # in pip requirements files.
            if not req or req.startswith('#'):
                continue
            installed, _ = PluginsManager.check_plugin_installed(req)
            if not installed:
                install_list.append(req)

    if len(install_list) > 0:
        status_code, _, args = PluginsManager.pip_command(
            'install',
            install_list,
        )
        if status_code != 0:
            raise ImportError(
                f'Failed to install requirements from {requirement_path}')
|
|
|
|
|
|
def import_module_from_file(module_name, file_path):
    """Load a python source file as a module object.

    Args:
        module_name: Name to give the loaded module.
        file_path: Path of the ``.py`` file to execute.

    Returns:
        The freshly executed module object (not inserted into sys.modules).
    """
    # `importlib.util` is a submodule and is not guaranteed to be reachable
    # through the bare `import importlib` at the top of this file, so import
    # it explicitly here.
    import importlib.util
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
|
|
|
|
|
def import_module_from_model_dir(model_dir):
    """Install a remote model's requirements, then import every python source
    file found under the model directory."""
    scanner = FilesAstScanning()
    scanner.traversal_files(model_dir)

    # install the requirements firstly, so the modules can resolve them
    install_requirements_by_files(scanner.requirement_dirs)

    # then import the modules, with the model dir importable
    sys.path.insert(0, model_dir)
    for source_file in scanner.file_dirs:
        import_module_from_file(Path(source_file).stem, source_file)
|
|
|
|
|
|
def install_requirements_by_names(plugins: List[str]):
    """Ensure every named plugin package is installed; raise when pip fails."""
    manager = PluginsManager()
    missing = [
        name for name in plugins
        if not manager.check_plugin_installed(name)[0]
    ]
    status, _ = manager.install_plugins(missing)
    if status != 0:
        raise EnvironmentError(
            f'The required packages {",".join(missing)} are not installed.',
            f'Please run the command `modelscope plugin install {" ".join(missing)}` to install them.'
        )
|
|
|
|
|
|
def install_requirements_by_files(requirements: List[str]):
    """Install the packages listed in each of the given requirement files."""
    for requirement_file in requirements:
        install_module_from_requirements(requirement_file)
|
|
|
|
|
|
def register_plugins_repo(plugins: List[str]) -> None:
    """ Try to install and import plugins from repo"""
    if plugins is None:
        return
    install_requirements_by_names(plugins)
    modules: List[str] = []
    for plugin in plugins:
        modules.extend(get_modules_from_package(plugin))
    import_plugins(modules)
|
|
|
|
|
|
def register_modelhub_repo(model_dir, allow_remote=False) -> None:
    """ Try to install and import remote model from modelhub"""
    if not allow_remote:
        return
    try:
        import_module_from_model_dir(model_dir)
    except KeyError:
        # several components registered under the same key is tolerated
        logger.warning(
            'Multi component keys in the hub are registered in same file')
|
|
|
|
|
|
DEFAULT_INDEX = 'https://pypi.org/simple/'
|
|
|
|
|
|
def get_modules_from_package(package):
    """Get the importable module names provided by an installed package.

    Downloads/builds the distribution's wheel in a scratch directory and reads
    its ``top_level.txt`` metadata (adapted from wimglenn/johnnydep).

    Args:
        package: The distribution name or package name (pip requirement
            string).

    Returns:
        List of public top-level import names exposed by the package.
    """
    from zipfile import ZipFile
    from tempfile import mkdtemp
    # CalledProcessError is what check_output actually raises on failure; the
    # previous `except ChildProcessError` could never fire.
    from subprocess import check_output, STDOUT, CalledProcessError
    from glob import glob
    import hashlib
    from urllib.parse import urlparse
    from urllib import request as urllib2
    from pip._internal.utils.packaging import get_requirement
    req = get_requirement(package)
    package = req.name

    def urlretrieve(url, filename, data=None, auth=None):
        # Download `url` to `filename`, optionally with HTTP basic auth.
        if auth is not None:
            # https://docs.python.org/2.7/howto/urllib2.html#id6
            password_mgr = urllib2.HTTPPasswordMgrWithDefaultRealm()

            # Add the username and password.
            # If we knew the realm, we could use it instead of None.
            username, password = auth
            top_level_url = urlparse(url).netloc
            password_mgr.add_password(None, top_level_url, username, password)

            handler = urllib2.HTTPBasicAuthHandler(password_mgr)

            # create "opener" (OpenerDirector instance)
            opener = urllib2.build_opener(handler)
        else:
            opener = urllib2.build_opener()

        res = opener.open(url, data=data)
        headers = res.info()
        with open(filename, 'wb') as fp:
            fp.write(res.read())
        return filename, headers

    def compute_checksum(target, algorithm='sha256', blocksize=2**13):
        # Hash `target` in chunks with the named hashlib algorithm.
        # NOTE: the structlog-style `logger.debug('msg', key=val)` calls of
        # the original would raise TypeError on this std-logging logger.
        hashtype = getattr(hashlib, algorithm)
        hash_ = hashtype()
        logger.debug('computing checksum target=%s algorithm=%s', target,
                     algorithm)
        with open(target, 'rb') as f:
            for chunk in iter(lambda: f.read(blocksize), b''):
                hash_.update(chunk)
        result = hash_.hexdigest()
        logger.debug('computed checksum result=%s', result)
        return result

    def _get_pip_version():
        # try to get pip version without actually importing pip
        # setuptools gets upset if you import pip before importing setuptools..
        try:
            import importlib.metadata  # Python 3.8+
            return importlib.metadata.version('pip')
        except Exception:
            pass
        import pip
        return pip.__version__

    def _download_dist(url, scratch_file, index_url, extra_index_url):
        # Re-download the dist, forwarding credentials embedded in the index
        # URLs when the hostnames match.
        auth = None
        if index_url:
            parsed = urlparse(index_url)
            if parsed.username and parsed.password and parsed.hostname == urlparse(
                    url).hostname:
                # handling private PyPI credentials in index_url
                auth = (parsed.username, parsed.password)
        if extra_index_url:
            parsed = urlparse(extra_index_url)
            if parsed.username and parsed.password and parsed.hostname == urlparse(
                    url).hostname:
                # handling private PyPI credentials in extra_index_url
                auth = (parsed.username, parsed.password)
        target, _headers = urlretrieve(url, scratch_file, auth=auth)
        return target, _headers

    def _get_wheel_args(index_url, env, extra_index_url):
        # Build the `pip wheel` command line, adapted to the pip version.
        args = [
            sys.executable,
            '-m',
            'pip',
            'wheel',
            '-vvv',  # --verbose x3
            '--no-deps',
            '--no-cache-dir',
            '--disable-pip-version-check',
        ]
        if index_url is not None:
            args += ['--index-url', index_url]
            if index_url != DEFAULT_INDEX:
                hostname = urlparse(index_url).hostname
                if hostname:
                    args += ['--trusted-host', hostname]
        if extra_index_url is not None:
            args += [
                '--extra-index-url', extra_index_url, '--trusted-host',
                urlparse(extra_index_url).hostname
            ]
        if env is None:
            pip_version = _get_pip_version()
        else:
            pip_version = dict(env)['pip_version']
            args[0] = dict(env)['python_executable']
        pip_major, pip_minor = pip_version.split('.')[0:2]
        pip_major = int(pip_major)
        pip_minor = int(pip_minor)
        if pip_major >= 10:
            args.append('--progress-bar=off')
        if (20, 3) <= (pip_major, pip_minor) < (21, 1):
            # See https://github.com/pypa/pip/issues/9139#issuecomment-735443177
            args.append('--use-deprecated=legacy-resolver')
        return args

    def get(dist_name,
            index_url=None,
            env=None,
            extra_index_url=None,
            tmpdir=None,
            ignore_errors=False):
        # Build the wheel in a scratch dir and parse pip's log for the
        # download link; returns {'path': whl, 'url': ..., 'checksum': ...}.
        args = _get_wheel_args(index_url, env, extra_index_url) + [dist_name]
        scratch_dir = mkdtemp(dir=tmpdir)
        logger.debug('wheeling and dealing scratch_dir=%s args=%s',
                     os.path.abspath(scratch_dir), ' '.join(args))
        try:
            out = check_output(
                args, stderr=STDOUT, cwd=scratch_dir).decode('utf-8')
        except CalledProcessError as err:
            out = getattr(err, 'output', b'').decode('utf-8')
            logger.warning(out)
            if not ignore_errors:
                raise
        logger.debug('wheel command completed ok dist_name=%s', dist_name)
        links = []
        local_links = []
        lines = out.splitlines()
        for i, line in enumerate(lines):
            line = line.strip()
            if line.startswith('Downloading from URL '):
                parts = line.split()
                link = parts[3]
                links.append(link)
            elif line.startswith('Downloading '):
                parts = line.split()
                last = parts[-1]
                if len(parts) == 3 and last.startswith('(') and last.endswith(
                        ')'):
                    link = parts[-2]
                elif len(parts) == 4 and parts[-2].startswith(
                        '(') and last.endswith(')'):
                    link = parts[-3]
                    if not urlparse(link).scheme:
                        # newest pip versions have changed to not log the full url
                        # in the download event. it is becoming more and more annoying
                        # to preserve compatibility across a wide range of pip versions
                        next_line = lines[i + 1].strip()
                        if next_line.startswith(
                                'Added ') and ' to build tracker' in next_line:
                            link = next_line.split(
                                ' to build tracker')[0].split()[-1]
                else:
                    link = last
                links.append(link)
            elif line.startswith(
                    'Source in ') and 'which satisfies requirement' in line:
                link = line.split()[-1]
                links.append(link)
            elif line.startswith('Added ') and ' from file://' in line:
                [link] = [x for x in line.split() if x.startswith('file://')]
                local_links.append(link)
        if not links:
            # prefer http scheme over file
            links += local_links
        links = list(dict.fromkeys(links))  # order-preserving dedupe
        if not links:
            logger.warning('could not find download link:\n%s', out)
            raise Exception('failed to collect dist')
        if len(links) == 2:
            # sometimes we collect the same link, once with a url fragment/checksum and once without
            first, second = links
            if first.startswith(second):
                del links[1]
            elif second.startswith(first):
                del links[0]
        if len(links) > 1:
            logger.debug('more than 1 link collected links=%s', links)
            # Since PEP 517, maybe an sdist will also need to collect other distributions
            # for the build system, even with --no-deps specified. pendulum==1.4.4 is one
            # example, which uses poetry and doesn't publish any python37 wheel to PyPI.
            # However, the dist itself should still be the first one downloaded.
        link = links[0]
        whls = glob(os.path.join(os.path.abspath(scratch_dir), '*.whl'))
        try:
            [whl] = whls
        except ValueError:
            if ignore_errors:
                whl = ''
            else:
                raise
        url, _sep, checksum = link.partition('#')
        url = url.replace(
            '/%2Bf/', '/+f/'
        )  # some versions of pip did not unquote this fragment in the log
        if not checksum.startswith('md5=') and not checksum.startswith(
                'sha256='):
            # PyPI gives you the checksum in url fragment, as a convenience. But not all indices are so kind.
            algorithm = 'md5'
            if os.path.basename(whl).lower() == url.rsplit('/',
                                                           1)[-1].lower():
                target = whl
            else:
                scratch_file = os.path.join(scratch_dir,
                                            os.path.basename(url))
                target, _headers = _download_dist(url, scratch_file,
                                                  index_url, extra_index_url)
            checksum = compute_checksum(target=target, algorithm=algorithm)
            checksum = '='.join([algorithm, checksum])
        result = {'path': whl, 'url': url, 'checksum': checksum}
        return result

    def discover_import_names(whl_file):
        # Read the wheel's top_level.txt metadata for public import names.
        logger.debug('finding import names')
        zipfile = ZipFile(file=whl_file)
        namelist = zipfile.namelist()
        [top_level_fname
         ] = [x for x in namelist if x.endswith('top_level.txt')]
        all_names = zipfile.read(top_level_fname).decode(
            'utf-8').strip().splitlines()
        public_names = [n for n in all_names if not n.startswith('_')]
        return public_names

    tmpdir = mkdtemp()
    data = get(package, tmpdir=tmpdir)
    import_names = discover_import_names(data['path'])
    shutil.rmtree(tmpdir)
    return import_names
|
|
|
|
|
|
class PluginsManager(object):
    """Manage modelscope plugin packages: install/uninstall them through pip
    and keep a local JSON file recording name/version/description."""

    def __init__(self,
                 cache_dir=MODELSCOPE_FILE_DIR,
                 plugins_file=PLUGINS_FILENAME):
        # Environment variables take precedence over the defaults.
        cache_dir = os.getenv('MODELSCOPE_CACHE', cache_dir)
        plugins_file = os.getenv('MODELSCOPE_PLUGINS_FILE', plugins_file)
        self._file_path = os.path.join(cache_dir, plugins_file)

    @property
    def file_path(self):
        # Path of the JSON file that records installed plugins.
        return self._file_path

    @file_path.setter
    def file_path(self, value):
        self._file_path = value

    @staticmethod
    def check_plugin_installed(package):
        """ Check if the plugin is installed, and if the version is valid

        Args:
            package: the package name need to be installed, as a pip
                requirement string (may carry version specifiers).

        Returns:
            Tuple (installed: bool, version: str); version is '' when the
            package is not found in the working set.
        """
        # A wheel file path is never treated as already installed.
        if package.split('.')[-1] == 'whl':
            return False, ''

        from pip._internal.utils.packaging import get_requirement
        req = get_requirement(package)

        try:
            # reload to pick up packages installed after interpreter start
            importlib.reload(pkg_resources)
            package_meta_info = pkg_resources.working_set.by_key[req.name]
            version = package_meta_info.version

            # To test if the package is installed
            installed = True

            # If installed, test if the version is correct
            for spec in req.specifier:
                installed_valid_version = spec.contains(version)
                if not installed_valid_version:
                    installed = False
                    break
        except KeyError:
            # the distribution is not present in the working set
            version = ''
            installed = False

        return installed, version

    @staticmethod
    def pip_command(
        command,
        command_args: List[str],
    ):
        """Run a pip command in-process.

        Args:
            command: install, uninstall command
            command_args: the args to be used with command, should be in list
                such as ['-r', 'requirements']

        Returns:
            Tuple (status_code, parsed options, positional args).
        """
        from pip._internal.commands import create_command
        importlib.reload(pkg_resources)
        if command == 'install':
            # add the modelscope release repo as an extra find-links source
            command_args.append('-f')
            command_args.append(
                'https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html'
            )
        command = create_command(command)
        options, args = command.parse_args(command_args)

        status_code = command.main(command_args)

        # reload the pkg_resources in order to get the latest pkgs information
        importlib.reload(pkg_resources)

        return status_code, options, args

    def install_plugins(self,
                        install_args: List[str],
                        index_url: Optional[str] = None,
                        force_update=False) -> Any:
        """Install packages via pip

        Args:
            install_args (list): List of arguments passed to `pip install`.
            index_url (str, optional): The pypi index url.
            force_update: Reinstall the packages even when already installed.
        """
        if len(install_args) == 0:
            return 0, []

        if index_url is not None:
            install_args += ['-i', index_url]

        if force_update:
            # FIX: a bare '-f' is invalid here (pip's -f/--find-links expects
            # an argument); '--force-reinstall' is the intended behavior.
            install_args += ['--force-reinstall']

        status_code, options, args = PluginsManager.pip_command(
            'install',
            install_args,
        )

        if status_code == 0:
            logger.info(f'The plugins {",".join(args)} is installed')

            # TODO Add Ast index for ast update record

            # Add the plugins info to the local record
            installed_package = self.parse_args_info(args, options)
            self.update_plugins_file(installed_package)

        return status_code, install_args

    def parse_args_info(self, args: List[str], options):
        """Build plugin-info dicts for the packages just installed."""
        installed_package = []

        # the case of install with requirements file(s): no positional args
        if len(args) == 0:
            src_dir = options.src_dir
            # FIX: pip stores '-r' files in `options.requirements`
            # (the original `options.requirments` raised AttributeError).
            requirements = options.requirements
            for requirement in requirements:
                package_info = {
                    'name': requirement,
                    'url': os.path.join(src_dir, requirement),
                    'desc': '',
                    'version': ''
                }

                installed_package.append(package_info)

        def get_package_info(package_name):
            # One info dict per installed package name.
            package_info = {
                'name': package_name,
                'url': options.index_url,
                'desc': ''
            }

            # the case with git + http: use the repository stem as the name
            if package_name.split('.')[-1] == 'git':
                package_name = Path(package_name).stem

            plugin_installed, version = self.check_plugin_installed(
                package_name)
            if plugin_installed:
                package_info['version'] = version
                package_info['name'] = package_name
            else:
                logger.warning(
                    f'The package {package_name} is not in the lib, this might be happened'
                    f' when installing the package with git+https method, should be ignored'
                )
                package_info['version'] = ''

            return package_info

        for package in args:
            package_info = get_package_info(package)
            installed_package.append(package_info)

        return installed_package

    def uninstall_plugins(self,
                          uninstall_args: Union[str, List],
                          is_yes=False):
        """Uninstall packages via pip and drop them from the local record."""
        # NOTE(review): `is_yes is not None` is always true with the False
        # default, so '-y' is appended unconditionally; kept as-is to stay
        # non-interactive for existing callers.
        if is_yes is not None:
            uninstall_args += ['-y']

        status_code, options, args = PluginsManager.pip_command(
            'uninstall',
            uninstall_args,
        )

        if status_code == 0:
            logger.info(f'The plugins {",".join(args)} is uninstalled')

            # TODO Add Ast index for ast update record

            # Add to the local record
            self.remove_plugins_from_file(args)

        return status_code, uninstall_args

    def _get_plugins_from_file(self):
        """Load the local plugins record; return {} when the file is absent."""
        logger.info(f'Loading plugins information from {self.file_path}')
        if os.path.exists(self.file_path):
            local_plugins_info_bytes = storage.read(self.file_path)
            local_plugins_info = json.loads(local_plugins_info_bytes)
        else:
            local_plugins_info = {}
        return local_plugins_info

    def _update_plugins(
        self,
        new_plugins_list,
        local_plugins_info,
        override=False,
    ):
        """Merge new plugin entries into the record.

        Unless `override` is True, an existing entry wins when its version is
        newer; a non-empty description is always preserved.
        """
        for item in new_plugins_list:
            package_name = item.pop('name')

            # update package information if existed
            if package_name in local_plugins_info and not override:
                original_item = local_plugins_info[package_name]
                from pkg_resources import parse_version
                item_version = parse_version(
                    item['version'] if item['version'] != '' else '0.0.0')
                origin_version = parse_version(
                    original_item['version']
                    if original_item['version'] != '' else '0.0.0')
                desc = item['desc']
                if original_item['desc'] != '' and desc == '':
                    desc = original_item['desc']
                item = item if item_version > origin_version else original_item
                item['desc'] = desc

            # Double-check if the item is installed with the version number
            if item['version'] == '':
                plugin_installed, version = self.check_plugin_installed(
                    package_name)
                item['version'] = version

            local_plugins_info[package_name] = item

        return local_plugins_info

    def _print_plugins_info(self, local_plugins_info):
        """Pretty-print the plugins table to stdout."""
        print('{:<15} |{:<10} |{:<100}'.format('NAME', 'VERSION',
                                               'DESCRIPTION'))
        print('')
        for k, v in local_plugins_info.items():
            print('{:<15} |{:<10} |{:<100}'.format(k, v['version'], v['desc']))

    def list_plugins(
        self,
        show_all=False,
    ):
        """List plugins.

        Args:
            show_all: show installed and official supported if True, else only
                those installed

        Returns:
            The plugins info dict that was printed.
        """
        local_plugins_info = self._get_plugins_from_file()

        # update plugins with default
        local_official_plugins = copy.deepcopy(OFFICIAL_PLUGINS)
        local_plugins_info = self._update_plugins(local_official_plugins,
                                                  local_plugins_info)

        if show_all is True:
            self._print_plugins_info(local_plugins_info)
            return local_plugins_info

        # Consider those package with version is installed
        not_installed_list = []
        for item in local_plugins_info:
            if local_plugins_info[item]['version'] == '':
                not_installed_list.append(item)

        for item in not_installed_list:
            local_plugins_info.pop(item)

        self._print_plugins_info(local_plugins_info)
        return local_plugins_info

    def update_plugins_file(
        self,
        plugins_list,
        override=False,
    ):
        """update the plugins file in order to maintain the latest plugins information

        Args:
            plugins_list: The plugins list contain the information of plugins
                name, version, introduction, install url and the status of
                delete or update
            override: Override the file by the list if True, else only update.

        Returns:
            The JSON string written to the record file.
        """
        local_plugins_info = self._get_plugins_from_file()

        # local_plugins_info is empty if first time loading, should add OFFICIAL_PLUGINS information
        if local_plugins_info == {}:
            plugins_list.extend(copy.deepcopy(OFFICIAL_PLUGINS))

        local_plugins_info = self._update_plugins(plugins_list,
                                                  local_plugins_info, override)

        local_plugins_info_json = json.dumps(local_plugins_info)
        storage.write(local_plugins_info_json.encode(), self.file_path)

        return local_plugins_info_json

    def remove_plugins_from_file(
        self,
        package_names: Union[str, list],
    ):
        """Remove the given packages from the local plugins record.

        Args:
            package_names: a single package name or a list of package names.

        Returns:
            The JSON string written to the record file.
        """
        local_plugins_info = self._get_plugins_from_file()

        # FIX: a single name must become a one-element list; the original
        # `list(package_names)` exploded the string into characters.
        if type(package_names) is str:
            package_names = [package_names]

        for item in package_names:
            if item in local_plugins_info:
                local_plugins_info.pop(item)

        local_plugins_info_json = json.dumps(local_plugins_info)
        storage.write(local_plugins_info_json.encode(), self.file_path)

        return local_plugins_info_json
|
|
|
|
|
|
class EnvsManager(object):
    # sub-directory name, under the cache dir, that holds per-model envs
    name = 'envs'

    def __init__(self,
                 model_id,
                 model_revision=DEFAULT_MODEL_REVISION,
                 cache_dir=MODELSCOPE_FILE_DIR):
        """
        Args:
            model_id: id of the model, not dir
            model_revision: revision of the model, default as master
            cache_dir: the system modelscope cache dir
        """
        # MODELSCOPE_CACHE env var overrides the default cache dir.
        cache_dir = os.getenv('MODELSCOPE_CACHE', cache_dir)
        self.env_dir = os.path.join(cache_dir, EnvsManager.name, model_id)
        # NOTE: downloads the model snapshot to read its configuration.
        model_dir = snapshot_download(model_id, revision=model_revision)
        cfg = read_config(model_dir)
        # plugin list and remote-code flag declared by the model config
        self.plugins = cfg.get('plugins', [])
        self.allow_remote = cfg.get('allow_remote', False)
        # venv without its own pip; packages come from the system
        # site-packages (system_site_packages=True).
        self.env_builder = venv.EnvBuilder(
            system_site_packages=True,
            clear=False,
            symlinks=True,
            with_pip=False)

    def get_env_dir(self):
        # Root directory of this model's virtual env.
        return self.env_dir

    def get_activate_dir(self):
        # Path of the venv's activate script (POSIX layout).
        return os.path.join(self.env_dir, 'bin', 'activate')

    def check_if_need_env(self):
        # A dedicated env is needed when the model declares plugins or
        # allows remote code.
        if len(self.plugins) or self.allow_remote:
            return True
        else:
            return False

    def create_env(self):
        # Create the virtual env directory; clean up and re-raise on failure.
        if not os.path.exists(self.env_dir):
            os.makedirs(self.env_dir)
        try:
            self.env_builder.create(self.env_dir)
        except Exception as e:
            self.clean_env()
            raise EnvironmentError(
                f'Failed to create virtual env at {self.env_dir} with error: {e}'
            )

    def clean_env(self):
        # Remove the contents of the env directory if it exists.
        if os.path.exists(self.env_dir):
            self.env_builder.clear_directory(self.env_dir)

    @staticmethod
    def run_process(cmd):
        # Run a shell command, raising when the exit status is non-zero;
        # returns the command's combined output.
        import subprocess
        status, result = subprocess.getstatusoutput(cmd)
        logger.debug('The status and the results are: {}, {}'.format(
            status, result))
        if status != 0:
            raise Exception(
                'running the cmd: {} failed, with message: {}'.format(
                    cmd, result))
        return result
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test.  'adaseq' is a package name, not a requirements-file path,
    # so it must be installed by name: the previous
    # `install_requirements_by_files(['adaseq'])` tried to open('adaseq')
    # and raised FileNotFoundError.
    install_requirements_by_names(['adaseq'])
    import_name = get_modules_from_package('pai-easycv')
|