refactor: extract shared download_file helper to deduplicate the parallel/serial download, integrity-check, and cache-put logic

This commit is contained in:
mulin.lyh
2024-07-12 11:44:22 +08:00
parent eea7575ab5
commit 759c564659
3 changed files with 39 additions and 111 deletions

View File

@@ -11,12 +11,8 @@ from modelscope.hub.errors import NotExistError
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
from modelscope.utils.file_utils import get_dataset_cache_root
from modelscope.utils.logger import get_logger
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
from .file_download import (create_temporary_directory_and_cache,
http_get_model_file, parallel_download)
from .utils.utils import (file_integrity_validation,
model_id_to_group_owner_name)
from .file_download import create_temporary_directory_and_cache, download_file
from .utils.utils import model_id_to_group_owner_name
logger = get_logger()
@@ -138,32 +134,8 @@ def dataset_file_download(
namespace=group_or_owner,
revision=revision)
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta_to_download[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
parallel_download(
url,
temporary_cache_dir,
file_meta_to_download['Name'], # TODO use Path
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=file_meta_to_download['Size'])
else:
http_get_model_file(
url,
temporary_cache_dir,
file_meta_to_download['Name'],
file_size=file_meta_to_download['Size'],
headers=headers,
cookies=cookies)
# check file integrity
temp_file = os.path.join(temporary_cache_dir,
file_meta_to_download['Name'])
if FILE_HASH in file_meta_to_download:
file_integrity_validation(temp_file,
file_meta_to_download[FILE_HASH])
# put file into the cache
return cache.put_file(file_meta_to_download, temp_file)
return download_file(url, file_meta_to_download, temporary_cache_dir,
cache, headers, cookies)
def dataset_snapshot_download(
@@ -299,30 +271,8 @@ def dataset_snapshot_download(
namespace=group_or_owner,
revision=revision)
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
parallel_download(
url,
temporary_cache_dir,
file_meta['Name'], # TODO use Path
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=file_meta['Size'])
else:
http_get_model_file(
url,
temporary_cache_dir,
file_meta['Name'],
file_size=file_meta['Size'],
headers=headers,
cookies=cookies)
# check file integrity
temp_file = os.path.join(temporary_cache_dir, file_meta['Name'])
if FILE_HASH in file_meta:
file_integrity_validation(temp_file, file_meta[FILE_HASH])
# put file into the cache
cache.put_file(file_meta, temp_file)
download_file(url, file_meta, temporary_cache_dir, cache, headers,
cookies)
cache.save_model_version(revision_info=revision)
return os.path.join(cache.get_root_location())

View File

@@ -134,32 +134,8 @@ def model_file_download(
# we need to download again
url_to_download = get_file_download_url(model_id, file_path, revision)
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_to_download_info[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
parallel_download(
url_to_download,
temporary_cache_dir,
file_path,
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=file_to_download_info['Size'])
else:
http_get_model_file(
url_to_download,
temporary_cache_dir,
file_path,
file_size=file_to_download_info['Size'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict())
temp_file_path = os.path.join(temporary_cache_dir, file_path)
# for download with commit we can't get Sha256
if file_to_download_info[FILE_HASH] is not None:
file_integrity_validation(temp_file_path,
file_to_download_info[FILE_HASH])
return cache.put_file(file_to_download_info,
os.path.join(temporary_cache_dir, file_path))
return download_file(url_to_download, file_to_download_info,
temporary_cache_dir, cache, headers, cookies)
def create_temporary_directory_and_cache(model_id: str, local_dir: str,
@@ -454,3 +430,31 @@ def http_get_file(
logger.error(msg)
raise FileDownloadError(msg)
os.replace(temp_file.name, os.path.join(local_dir, file_name))
def download_file(url, file_meta, temporary_cache_dir, cache, headers,
                  cookies):
    """Download one file to a temporary dir, verify it, and move it to cache.

    Shared helper extracted from model/dataset download paths: picks a
    parallel or single-stream download based on file size, optionally
    validates the file hash, then hands the file to the cache.

    Args:
        url: Fully qualified download URL for the file.
        file_meta: Metadata dict for the file; this code reads 'Path' and
            'Size', plus an optional FILE_HASH entry for integrity checks.
            NOTE(review): the pre-refactor dataset code paths keyed the
            filename as 'Name', not 'Path' — confirm all callers supply
            metas with a 'Path' key.
        temporary_cache_dir: Directory the file is first written into.
        cache: Cache object; the downloaded file is stored via put_file.
        headers: HTTP headers forwarded to the download helpers.
        cookies: Optional requests cookie jar; converted to a plain dict
            only for parallel_download, which expects a mapping.

    Returns:
        The return value of cache.put_file (presumably the cached file
        location — confirm against the cache implementation).
    """
    # Files above the size threshold are fetched with multiple parallel
    # range requests; anything smaller uses a single streaming GET.
    if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
            'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:  # parallel download large file.
        parallel_download(
            url,
            temporary_cache_dir,
            file_meta['Path'],
            headers=headers,
            cookies=None if cookies is None else cookies.get_dict(),
            file_size=file_meta['Size'])
    else:
        http_get_model_file(
            url,
            temporary_cache_dir,
            file_meta['Path'],
            file_size=file_meta['Size'],
            headers=headers,
            cookies=cookies)
    # check file integrity only when the server supplied a hash entry
    temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])
    if FILE_HASH in file_meta:
        file_integrity_validation(temp_file, file_meta[FILE_HASH])
    # put the validated file into the cache
    return cache.put_file(file_meta, temp_file)

View File

@@ -11,12 +11,8 @@ from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.utils.logger import get_logger
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
from .file_download import (create_temporary_directory_and_cache,
get_file_download_url, http_get_model_file,
parallel_download)
from .utils.utils import file_integrity_validation
download_file, get_file_download_url)
logger = get_logger()
@@ -157,30 +153,8 @@ def snapshot_download(
file_path=model_file['Path'],
revision=revision)
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < model_file[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1:
parallel_download(
url,
temporary_cache_dir,
model_file['Name'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=model_file['Size'])
else:
http_get_model_file(
url,
temporary_cache_dir,
model_file['Name'],
file_size=model_file['Size'],
headers=headers,
cookies=cookies)
# check file integrity
temp_file = os.path.join(temporary_cache_dir, model_file['Name'])
if FILE_HASH in model_file:
file_integrity_validation(temp_file, model_file[FILE_HASH])
# put file into the cache
cache.put_file(model_file, temp_file)
download_file(url, model_file, temporary_cache_dir, cache, headers,
cookies)
cache.save_model_version(revision_info=revision_detail)
return os.path.join(cache.get_root_location())