diff --git a/modelscope/utils/plugins.py b/modelscope/utils/plugins.py index a83ca03c..9d238e7d 100644 --- a/modelscope/utils/plugins.py +++ b/modelscope/utils/plugins.py @@ -263,12 +263,11 @@ def import_module_and_submodules(package_name: str, def install_module_from_requirements(requirement_path, ): - """ + """ install module from requirements Args: requirement_path: The path of requirement file - Returns: - + No returns, raise error if failed """ install_list = [] @@ -292,6 +291,15 @@ def install_module_from_requirements(requirement_path, ): def import_module_from_file(module_name, file_path): + """ install module by name with file path + + Args: + module_name: the module name need to be import + file_path: the related file path that matched with the module name + + Returns: return the module class + + """ spec = importlib.util.spec_from_file_location(module_name, file_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) @@ -299,6 +307,14 @@ def import_module_from_file(module_name, file_path): def import_module_from_model_dir(model_dir): + """ import all the necessary module from a model dir + + Args: + model_dir: model file location + + No returns, raise error if failed + + """ from pathlib import Path file_scanner = FilesAstScanning() file_scanner.traversal_files(model_dir) @@ -317,6 +333,14 @@ def import_module_from_model_dir(model_dir): def install_requirements_by_names(plugins: List[str]): + """ install the requirements by names + + Args: + plugins: name of plugins (pai-easyscv, transformers) + + No returns, raise error if failed + + """ plugins_manager = PluginsManager() uninstalled_plugins = [] for plugin in plugins: @@ -333,6 +357,14 @@ def install_requirements_by_names(plugins: List[str]): def install_requirements_by_files(requirements: List[str]): + """ install the requriements by files + + Args: + requirements: a list of files including requirements info (requirements.txt) + + No returns, raise error if failed + + """ for requirement in requirements: install_module_from_requirements(requirement) @@ -343,7 +375,8 @@ def register_plugins_repo(plugins: List[str]) -> None: install_requirements_by_names(plugins) modules = [] for plugin in plugins: - modules.extend(get_modules_from_package(plugin)) + module_name, module_version, _ = get_modules_from_package(plugin) + modules.extend(module_name) import_plugins(modules) @@ -362,12 +395,15 @@ DEFAULT_INDEX = 'https://pypi.org/simple/' def get_modules_from_package(package): - """ to get the modules from a installed package + """ to get the modules from an installed package Args: package: The distribution name or package name Returns: + import_names: The modules that in the package distribution + import_version: The version of those modules, should be same and identical + package_name: The package name, if installed by whl file, the package is unknown, should be passed """ from zipfile import ZipFile @@ -378,8 +414,6 @@ def get_modules_from_package(package): from urllib.parse import urlparse from urllib import request as urllib2 from pip._internal.utils.packaging import get_requirement - req = get_requirement(package) - package = req.name def urlretrieve(url, filename, data=None, auth=None): if auth is not None: @@ -591,24 +625,58 @@ def get_modules_from_package(package): return result def discover_import_names(whl_file): + import re logger.debug('finding import names') zipfile = ZipFile(file=whl_file) namelist = zipfile.namelist() [top_level_fname ] = [x for x in namelist if x.endswith('top_level.txt')] + [metadata_fname + ] = [x for x in namelist if x.endswith('.dist-info/METADATA')] all_names = zipfile.read(top_level_fname).decode( 'utf-8').strip().splitlines() + metadata = zipfile.read(metadata_fname).decode('utf-8') public_names = [n for n in all_names if not n.startswith('_')] - return public_names + + version_pattern = re.compile(r'^Version: (?P.+)$', + re.MULTILINE) + name_pattern = re.compile(r'^Name: (?P.+)$', re.MULTILINE) + + version_match = version_pattern.search(metadata) + name_match = name_pattern.search(metadata) + + module_version = version_match.group('version') + module_name = name_match.group('name') + + return public_names, module_version, module_name tmpdir = mkdtemp() - data = get(package, tmpdir=tmpdir) - import_names = discover_import_names(data['path']) + if package.endswith('.whl'): + """if user using .whl file then parse the whl to get the module name""" + if not os.path.isfile(package): + file_name = os.path.basename(package) + file_path = os.path.join(tmpdir, file_name) + whl_file, _ = _download_dist(package, file_path, None, None) + else: + whl_file = package + else: + """if user using package name then generate whl file and parse the file to get the module name by + the discover_import_names method + """ + req = get_requirement(package) + package = req.name + data = get(package, tmpdir=tmpdir) + whl_file = data['path'] + import_names, import_version, package_name = discover_import_names( + whl_file) shutil.rmtree(tmpdir) - return import_names + return import_names, import_version, package_name class PluginsManager(object): + """ + plugins manager class + """ def __init__(self, cache_dir=MODELSCOPE_FILE_DIR, @@ -633,12 +701,26 @@ class PluginsManager(object): package: the package name need to be installed Returns: + if_installed: True if installed + version: the version of installed or None if not installed """ if package.split('.')[-1] == 'whl': - return False, '' + # install from whl should test package name instead of module name + _, module_version, package_name = get_modules_from_package(package) + local_installed, version = PluginsManager._check_plugin_installed( + package_name) + if local_installed and module_version != version: + return False, version + elif not local_installed: + return False, version + return True, module_version + else: + return PluginsManager._check_plugin_installed(package) + @staticmethod + def _check_plugin_installed(package, verified_version=None): from pip._internal.utils.packaging import get_requirement, specifiers req = get_requirement(package) @@ -656,11 +738,15 @@ class PluginsManager(object): if not installed_valid_version: installed = False break + except KeyError: version = '' installed = False - return installed, version + if installed and verified_version is not None and verified_version != version: + return False, verified_version + else: + return installed, version @staticmethod def pip_command( @@ -675,6 +761,9 @@ class PluginsManager(object): such as ['-r', 'requirements'] Returns: + status_code: The pip command status code, 0 if success, else is failed + options: parsed option from system args by pip command + args: the unknown args that could be parsed by pip command """ from pip._internal.commands import create_command @@ -702,6 +791,7 @@ class PluginsManager(object): Args: install_args (list): List of arguments passed to `pip install`. index_url (str, optional): The pypi index url. + force_update: If force update on or off """ if len(install_args) == 0: @@ -730,6 +820,16 @@ class PluginsManager(object): return status_code, install_args def parse_args_info(self, args: List[str], options): + """ + parse arguments input info + Args: + args: the list of args from pip command output + options: the options that parsed from system args by pip command method + + Returns: + installed_package: generate installed package info in order to store in the file + the info includes: name, url and desc of the package + """ installed_package = [] # the case of install with requirements @@ -781,6 +881,15 @@ class PluginsManager(object): def uninstall_plugins(self, uninstall_args: Union[str, List], is_yes=False): + """ + uninstall plugins + Args: + uninstall_args: args used to uninstall by pip command + is_yes: force yes without verified + + Returns: status code, and uninstall args + + """ if is_yes is not None: uninstall_args += ['-y'] @@ -862,6 +971,7 @@ class PluginsManager(object): show_all: show installed and official supported if True, else only those installed Returns: + local_plugins_info: show the list of plugins info """ local_plugins_info = self._get_plugins_from_file() @@ -901,6 +1011,7 @@ class PluginsManager(object): override: Override the file by the list if True, else only update. Returns: + local_plugins_info_json: the json version of updated plugins info """ local_plugins_info = self._get_plugins_from_file() @@ -921,12 +1032,12 @@ class PluginsManager(object): self, package_names: Union[str, list], ): - """ - + """remove the plugins from file Args: package_names: package name Returns: + local_plugins_info_json: the json version of updated plugins info """ local_plugins_info = self._get_plugins_from_file() @@ -1012,4 +1123,5 @@ class EnvsManager(object): if __name__ == '__main__': install_requirements_by_files(['adaseq']) - import_name = get_modules_from_package('pai-easycv') + import_name, import_version, package_name = get_modules_from_package( + 'pai-easycv')