From 759c56465951cde5470bcb56a5b41ef95bee711a Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Fri, 12 Jul 2024 11:44:22 +0800 Subject: [PATCH] refactor code --- modelscope/hub/dataset_download.py | 62 +++-------------------------- modelscope/hub/file_download.py | 56 ++++++++++++++------------ modelscope/hub/snapshot_download.py | 32 ++------------- 3 files changed, 39 insertions(+), 111 deletions(-) diff --git a/modelscope/hub/dataset_download.py b/modelscope/hub/dataset_download.py index c857d2d9..413b900b 100644 --- a/modelscope/hub/dataset_download.py +++ b/modelscope/hub/dataset_download.py @@ -11,12 +11,8 @@ from modelscope.hub.errors import NotExistError from modelscope.utils.constant import DEFAULT_DATASET_REVISION from modelscope.utils.file_utils import get_dataset_cache_root from modelscope.utils.logger import get_logger -from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS, - MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB) -from .file_download import (create_temporary_directory_and_cache, - http_get_model_file, parallel_download) -from .utils.utils import (file_integrity_validation, - model_id_to_group_owner_name) +from .file_download import create_temporary_directory_and_cache, download_file +from .utils.utils import model_id_to_group_owner_name logger = get_logger() @@ -138,32 +134,8 @@ def dataset_file_download( namespace=group_or_owner, revision=revision) - if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta_to_download[ - 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file. 
- parallel_download( - url, - temporary_cache_dir, - file_meta_to_download['Name'], # TODO use Path - headers=headers, - cookies=None if cookies is None else cookies.get_dict(), - file_size=file_meta_to_download['Size']) - else: - http_get_model_file( - url, - temporary_cache_dir, - file_meta_to_download['Name'], - file_size=file_meta_to_download['Size'], - headers=headers, - cookies=cookies) - - # check file integrity - temp_file = os.path.join(temporary_cache_dir, - file_meta_to_download['Name']) - if FILE_HASH in file_meta_to_download: - file_integrity_validation(temp_file, - file_meta_to_download[FILE_HASH]) - # put file into to cache - return cache.put_file(file_meta_to_download, temp_file) + return download_file(url, file_meta_to_download, temporary_cache_dir, + cache, headers, cookies) def dataset_snapshot_download( @@ -299,30 +271,8 @@ def dataset_snapshot_download( namespace=group_or_owner, revision=revision) - if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[ - 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file. 
- parallel_download( - url, - temporary_cache_dir, - file_meta['Name'], # TODO use Path - headers=headers, - cookies=None if cookies is None else cookies.get_dict(), - file_size=file_meta['Size']) - else: - http_get_model_file( - url, - temporary_cache_dir, - file_meta['Name'], - file_size=file_meta['Size'], - headers=headers, - cookies=cookies) - - # check file integrity - temp_file = os.path.join(temporary_cache_dir, file_meta['Name']) - if FILE_HASH in file_meta: - file_integrity_validation(temp_file, file_meta[FILE_HASH]) - # put file into to cache - cache.put_file(file_meta, temp_file) + download_file(url, file_meta, temporary_cache_dir, cache, headers, + cookies) cache.save_model_version(revision_info=revision) return os.path.join(cache.get_root_location()) diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index b2f69f64..dce4a7aa 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -134,32 +134,8 @@ def model_file_download( # we need to download again url_to_download = get_file_download_url(model_id, file_path, revision) - - if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_to_download_info[ - 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: - parallel_download( - url_to_download, - temporary_cache_dir, - file_path, - headers=headers, - cookies=None if cookies is None else cookies.get_dict(), - file_size=file_to_download_info['Size']) - else: - http_get_model_file( - url_to_download, - temporary_cache_dir, - file_path, - file_size=file_to_download_info['Size'], - headers=headers, - cookies=None if cookies is None else cookies.get_dict()) - - temp_file_path = os.path.join(temporary_cache_dir, file_path) - # for download with commit we can't get Sha256 - if file_to_download_info[FILE_HASH] is not None: - file_integrity_validation(temp_file_path, - file_to_download_info[FILE_HASH]) - return cache.put_file(file_to_download_info, - os.path.join(temporary_cache_dir, file_path)) + return 
download_file(url_to_download, file_to_download_info, + temporary_cache_dir, cache, headers, cookies) def create_temporary_directory_and_cache(model_id: str, local_dir: str, @@ -454,3 +430,31 @@ def http_get_file( logger.error(msg) raise FileDownloadError(msg) os.replace(temp_file.name, os.path.join(local_dir, file_name)) + + +def download_file(url, file_meta, temporary_cache_dir, cache, headers, + cookies): + if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[ + 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file. + parallel_download( + url, + temporary_cache_dir, + file_meta['Path'], + headers=headers, + cookies=None if cookies is None else cookies.get_dict(), + file_size=file_meta['Size']) + else: + http_get_model_file( + url, + temporary_cache_dir, + file_meta['Path'], + file_size=file_meta['Size'], + headers=headers, + cookies=cookies) + + # check file integrity + temp_file = os.path.join(temporary_cache_dir, file_meta['Path']) + if FILE_HASH in file_meta: + file_integrity_validation(temp_file, file_meta[FILE_HASH]) + # put file into the cache + return cache.put_file(file_meta, temp_file) diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index f6089e60..1ede1521 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -11,12 +11,8 @@ from modelscope.hub.api import HubApi, ModelScopeConfig from modelscope.utils.constant import DEFAULT_MODEL_REVISION from modelscope.utils.file_utils import get_model_cache_root from modelscope.utils.logger import get_logger -from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS, - MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB) from .file_download import (create_temporary_directory_and_cache, - get_file_download_url, http_get_model_file, - parallel_download) -from .utils.utils import file_integrity_validation + download_file, get_file_download_url) logger = get_logger() @@ -157,30 +153,8 @@ def 
snapshot_download( file_path=model_file['Path'], revision=revision) - if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < model_file[ - 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: - parallel_download( - url, - temporary_cache_dir, - model_file['Name'], - headers=headers, - cookies=None if cookies is None else cookies.get_dict(), - file_size=model_file['Size']) - else: - http_get_model_file( - url, - temporary_cache_dir, - model_file['Name'], - file_size=model_file['Size'], - headers=headers, - cookies=cookies) - - # check file integrity - temp_file = os.path.join(temporary_cache_dir, model_file['Name']) - if FILE_HASH in model_file: - file_integrity_validation(temp_file, model_file[FILE_HASH]) - # put file into to cache - cache.put_file(model_file, temp_file) + download_file(url, model_file, temporary_cache_dir, cache, headers, + cookies) cache.save_model_version(revision_info=revision_detail) return os.path.join(cache.get_root_location())