diff --git a/modelscope/msdatasets/download/download_manager.py b/modelscope/msdatasets/download/download_manager.py index bc75641d..06ec0d26 100644 --- a/modelscope/msdatasets/download/download_manager.py +++ b/modelscope/msdatasets/download/download_manager.py @@ -24,7 +24,6 @@ class DataDownloadManager(DownloadManager): url_or_filename = str(url_or_filename) oss_utilities = OssUtilities( - oss_config=download_config.oss_config, dataset_name=download_config.dataset_name, namespace=download_config.namespace, revision=download_config.version) @@ -56,7 +55,6 @@ class DataStreamingDownloadManager(StreamingDownloadManager): def _download(self, url_or_filename: str) -> str: url_or_filename = str(url_or_filename) oss_utilities = OssUtilities( - oss_config=self.download_config.oss_config, dataset_name=self.download_config.dataset_name, namespace=self.download_config.namespace, revision=self.download_config.version) diff --git a/modelscope/msdatasets/utils/oss_utils.py b/modelscope/msdatasets/utils/oss_utils.py index b1457f4b..bff846fd 100644 --- a/modelscope/msdatasets/utils/oss_utils.py +++ b/modelscope/msdatasets/utils/oss_utils.py @@ -3,8 +3,12 @@ from __future__ import print_function import multiprocessing import os +import threading +import oss2 from datasets.utils.file_utils import hash_url_to_filename +from oss2 import CredentialsProvider +from oss2.credentials import Credentials from modelscope.hub.api import HubApi from modelscope.msdatasets.download.download_config import DataDownloadConfig @@ -23,14 +27,70 @@ BACK_DIR = 'BackupDir' DIR = 'Dir' -class OssUtilities: +class CredentialProviderWrapper(CredentialsProvider): + """ + A custom credentials provider for oss2 that fetches temporary credentials + """ - def __init__(self, oss_config, dataset_name, namespace, revision): - self._do_init(oss_config=oss_config) + def __init__(self, api: HubApi, dataset_name: str, namespace: str, + revision: str): + """ + Initializes the CredentialProviderWrapper with dataset 
information. + Args: + api (HubApi): Hub API client used to fetch the dataset access config. + dataset_name (str): The name of the dataset. + namespace (str): The namespace of the dataset. + revision (str): The revision of the dataset. + """ + self.api = api self.dataset_name = dataset_name self.namespace = namespace self.revision = revision + self._lock = threading.Lock() + + def get_credentials(self): + """ + Called by the oss2 auth provider whenever it needs credentials; every call fetches a fresh access config from the hub (no caching). + """ + with self._lock: + oss_config = self.api.get_dataset_access_config_session( + dataset_name=self.dataset_name, + namespace=self.namespace, + check_cookie=True, + revision=self.revision) + + return Credentials( + access_key_id=oss_config[ACCESS_ID], + access_key_secret=oss_config[ACCESS_SECRET], + security_token=oss_config[SECURITY_TOKEN], + ) + + +class OssUtilities: + """ + A utility class for handling Alibaba Cloud OSS operations such as upload and download. + """ + + def __init__(self, dataset_name, namespace, revision): + """ + Initializes the OssUtilities for a dataset, fetching its OSS access configuration from the hub. 
+ """ + self.dataset_name = dataset_name + self.namespace = namespace + self.revision = revision + + self.api = HubApi() + oss_config = self.api.get_dataset_access_config_session( + dataset_name=self.dataset_name, + namespace=self.namespace, + check_cookie=True, + revision=self.revision) + + if os.getenv('ENABLE_DATASET_ACCELERATION') == 'True': + self.endpoint = DEFAULT_DATA_ACCELERATION_ENDPOINT + else: + self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com" self.resumable_store_root_path = os.path.join(MS_CACHE_HOME, 'tmp/resumable_store') @@ -39,38 +99,32 @@ class OssUtilities: self.multipart_threshold = 50 * 1024 * 1024 self.max_retries = 3 - import oss2 self.resumable_store_download = oss2.ResumableDownloadStore( root=self.resumable_store_root_path) self.resumable_store_upload = oss2.ResumableStore( root=self.resumable_store_root_path) - self.api = HubApi() - def _do_init(self, oss_config): - import oss2 - - self.key = oss_config[ACCESS_ID] - self.secret = oss_config[ACCESS_SECRET] - self.token = oss_config[SECURITY_TOKEN] - if os.getenv('ENABLE_DATASET_ACCELERATION') == 'True': - self.endpoint = DEFAULT_DATA_ACCELERATION_ENDPOINT - else: - self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com" - self.bucket_name = oss_config[BUCKET] - auth = oss2.StsAuth(self.key, self.secret, self.token) - self.bucket = oss2.Bucket( - auth, self.endpoint, self.bucket_name, connect_timeout=120) - self.oss_dir = oss_config[DIR] - self.oss_backup_dir = oss_config[BACK_DIR] - - def _reload_sts(self): - logger.info('Reloading sts token automatically.') - oss_config_refresh = self.api.get_dataset_access_config_session( + credential_provider = CredentialProviderWrapper( + api=self.api, dataset_name=self.dataset_name, namespace=self.namespace, - check_cookie=True, revision=self.revision) - self._do_init(oss_config_refresh) + auth = oss2.ProviderAuthV4(credential_provider) + + self.bucket_name = oss_config[BUCKET] + # Remove only the literal 'oss-' prefix: str.lstrip('oss-') strips a + # character set and could also eat leading letters of the region id. + region = oss_config['Region'] + if region.startswith('oss-'): + region = region[len('oss-'):] + self.bucket = oss2.Bucket( + auth=auth, 
+ endpoint=self.endpoint, + bucket_name=self.bucket_name, + region=region, + ) + self.oss_dir = oss_config[DIR] + self.oss_backup_dir = oss_config[BACK_DIR] @staticmethod def _percentage(consumed_bytes, total_bytes): @@ -79,8 +133,17 @@ class OssUtilities: print('\r{0}% '.format(rate), end='', flush=True) def download(self, oss_file_name: str, - download_config: DataDownloadConfig): - import oss2 + download_config: DataDownloadConfig) -> str: + """ + Downloads a file from OSS to the local cache. + + Args: + oss_file_name (str): The name of the file in OSS to download. + download_config (DataDownloadConfig): Configuration for the download process. + + Returns: + str: The local path to the downloaded file. + """ cache_dir = download_config.cache_dir candidate_key = os.path.join(self.oss_dir, oss_file_name) candidate_key_backup = os.path.join(self.oss_backup_dir, oss_file_name) @@ -95,7 +158,6 @@ class OssUtilities: retry_count = 0 while True: try: - retry_count += 1 # big_data is True when the dataset contains large number of objects if big_data: file_oss_key = candidate_key @@ -108,9 +170,9 @@ class OssUtilities: if download_config.force_download or not os.path.exists( local_path): oss2.resumable_download( - self.bucket, - file_oss_key, - local_path, + bucket=self.bucket, + key=file_oss_key, + filename=local_path, store=self.resumable_store_download, multiget_threshold=self.multipart_threshold, part_size=self.part_size, @@ -118,10 +180,13 @@ class OssUtilities: num_threads=self.num_threads) break except Exception as e: - if e.__dict__.get('status') == 403: - self._reload_sts() + logger.warning( + f'Error downloading {oss_file_name}: {e}, trying again...') + retry_count += 1 if retry_count >= self.max_retries: - logger.warning(f'Failed to download {oss_file_name}') + logger.error( + f'Failed to download {oss_file_name} due to exceeded retries.' 
+ ) raise e return local_path @@ -129,7 +194,18 @@ class OssUtilities: def upload(self, oss_object_name: str, local_file_path: str, indicate_individual_progress: bool, upload_mode: UploadMode) -> str: - import oss2 + """ + Uploads a local file to OSS. + + Args: + oss_object_name (str): The name of the object in OSS. + local_file_path (str): The local file path to upload. + indicate_individual_progress (bool): Whether to show individual progress. + upload_mode (UploadMode): The upload mode (e.g., OVERWRITE, APPEND). + + Returns: + str: The OSS object key where the file is uploaded. + """ retry_count = 0 object_key = os.path.join(self.oss_dir, oss_object_name) @@ -140,7 +216,6 @@ class OssUtilities: while True: try: - retry_count += 1 exist = self.bucket.object_exists(object_key) if upload_mode == UploadMode.APPEND and exist: logger.info( @@ -149,9 +224,9 @@ class OssUtilities: break oss2.resumable_upload( - self.bucket, - object_key, - local_file_path, + bucket=self.bucket, + key=object_key, + filename=local_file_path, store=self.resumable_store_upload, multipart_threshold=self.multipart_threshold, part_size=self.part_size, @@ -159,9 +234,13 @@ class OssUtilities: num_threads=self.num_threads) break except Exception as e: - if e.__dict__.get('status') == 403: - self._reload_sts() + logger.warning( + f'Error uploading {oss_object_name}: {e}, trying again...') + retry_count += 1 if retry_count >= self.max_retries: - raise + logger.error( + f'Failed to upload {oss_object_name} due to exceeded retries.' 
+ ) + raise e return object_key diff --git a/modelscope/msdatasets/utils/upload_utils.py b/modelscope/msdatasets/utils/upload_utils.py index a176be6a..5c86cbaa 100644 --- a/modelscope/msdatasets/utils/upload_utils.py +++ b/modelscope/msdatasets/utils/upload_utils.py @@ -13,18 +13,9 @@ class DatasetUploadManager(object): def __init__(self, dataset_name: str, namespace: str, version: str): from modelscope.hub.api import HubApi - _hub_api = HubApi() - _oss_config = _hub_api.get_dataset_access_config_session( - dataset_name=dataset_name, - namespace=namespace, - check_cookie=False, - revision=version) self.oss_utilities = OssUtilities( - oss_config=_oss_config, - dataset_name=dataset_name, - namespace=namespace, - revision=version) + dataset_name=dataset_name, namespace=namespace, revision=version) def upload(self, object_name: str, local_file_path: str, upload_mode: UploadMode) -> str: