mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
[Fix]oss utils sts auth expire issue (#1589)
This commit is contained in:
@@ -24,7 +24,6 @@ class DataDownloadManager(DownloadManager):
|
||||
url_or_filename = str(url_or_filename)
|
||||
|
||||
oss_utilities = OssUtilities(
|
||||
oss_config=download_config.oss_config,
|
||||
dataset_name=download_config.dataset_name,
|
||||
namespace=download_config.namespace,
|
||||
revision=download_config.version)
|
||||
@@ -56,7 +55,6 @@ class DataStreamingDownloadManager(StreamingDownloadManager):
|
||||
def _download(self, url_or_filename: str) -> str:
|
||||
url_or_filename = str(url_or_filename)
|
||||
oss_utilities = OssUtilities(
|
||||
oss_config=self.download_config.oss_config,
|
||||
dataset_name=self.download_config.dataset_name,
|
||||
namespace=self.download_config.namespace,
|
||||
revision=self.download_config.version)
|
||||
|
||||
@@ -3,8 +3,12 @@
|
||||
from __future__ import print_function
|
||||
import multiprocessing
|
||||
import os
|
||||
import threading
|
||||
|
||||
import oss2
|
||||
from datasets.utils.file_utils import hash_url_to_filename
|
||||
from oss2 import CredentialsProvider
|
||||
from oss2.credentials import Credentials
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.msdatasets.download.download_config import DataDownloadConfig
|
||||
@@ -23,14 +27,69 @@ BACK_DIR = 'BackupDir'
|
||||
DIR = 'Dir'
|
||||
|
||||
|
||||
class OssUtilities:
|
||||
class CredentialProviderWrapper(CredentialsProvider):
|
||||
"""
|
||||
A custom credentials provider for oss2 that fetches temporary credentials
|
||||
"""
|
||||
|
||||
def __init__(self, oss_config, dataset_name, namespace, revision):
|
||||
self._do_init(oss_config=oss_config)
|
||||
def __init__(self, api: HubApi, dataset_name: str, namespace: str,
             revision: str):
    """
    Remember everything needed to request dataset access credentials later.

    Nothing is fetched here; the arguments are only stored on the instance.

    Args:
        api (HubApi): Hub client whose
            `get_dataset_access_config_session` method is used by
            `get_credentials` to obtain the access config.
        dataset_name (str): The name of the dataset.
        namespace (str): The namespace of the dataset.
        revision (str): The revision of the dataset.
    """
    # Held by `get_credentials` while it requests the access config.
    self._lock = threading.Lock()

    # Identity of the dataset the credentials are requested for.
    self.revision = revision
    self.namespace = namespace
    self.dataset_name = dataset_name

    self.api = api
|
||||
|
||||
def get_credentials(self):
    """
    Request a dataset access config from the hub and wrap it as oss2 credentials.

    Each call issues a new `get_dataset_access_config_session` request
    (nothing is cached on the instance); the request is made while holding
    `self._lock`.

    NOTE(review): this is the `CredentialsProvider` hook that the oss2 SDK
    invokes when it needs to authenticate; exactly when and how often the
    SDK calls it is not visible from this code -- confirm against oss2.

    Returns:
        Credentials: built from the ACCESS_ID, ACCESS_SECRET and
        SECURITY_TOKEN entries of the returned access config.
    """
    with self._lock:
        session_config = self.api.get_dataset_access_config_session(
            dataset_name=self.dataset_name,
            namespace=self.namespace,
            check_cookie=True,
            revision=self.revision,
        )

    key_id = session_config[ACCESS_ID]
    key_secret = session_config[ACCESS_SECRET]
    sts_token = session_config[SECURITY_TOKEN]
    return Credentials(
        access_key_id=key_id,
        access_key_secret=key_secret,
        security_token=sts_token,
    )
|
||||
|
||||
|
||||
class OssUtilities:
|
||||
"""
|
||||
A utility class for handling Alibaba Cloud OSS operations such as upload and download.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_name, namespace, revision):
|
||||
"""
|
||||
Initializes the OssUtilities with the given OSS configuration and dataset information.
|
||||
"""
|
||||
self.dataset_name = dataset_name
|
||||
self.namespace = namespace
|
||||
self.revision = revision
|
||||
|
||||
self.api = HubApi()
|
||||
oss_config = self.api.get_dataset_access_config_session(
|
||||
dataset_name=self.dataset_name,
|
||||
namespace=self.namespace,
|
||||
check_cookie=True,
|
||||
revision=self.revision)
|
||||
|
||||
if os.getenv('ENABLE_DATASET_ACCELERATION') == 'True':
|
||||
self.endpoint = DEFAULT_DATA_ACCELERATION_ENDPOINT
|
||||
else:
|
||||
self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com"
|
||||
|
||||
self.resumable_store_root_path = os.path.join(MS_CACHE_HOME,
|
||||
'tmp/resumable_store')
|
||||
@@ -39,38 +98,27 @@ class OssUtilities:
|
||||
self.multipart_threshold = 50 * 1024 * 1024
|
||||
self.max_retries = 3
|
||||
|
||||
import oss2
|
||||
self.resumable_store_download = oss2.ResumableDownloadStore(
|
||||
root=self.resumable_store_root_path)
|
||||
self.resumable_store_upload = oss2.ResumableStore(
|
||||
root=self.resumable_store_root_path)
|
||||
self.api = HubApi()
|
||||
|
||||
def _do_init(self, oss_config):
|
||||
import oss2
|
||||
|
||||
self.key = oss_config[ACCESS_ID]
|
||||
self.secret = oss_config[ACCESS_SECRET]
|
||||
self.token = oss_config[SECURITY_TOKEN]
|
||||
if os.getenv('ENABLE_DATASET_ACCELERATION') == 'True':
|
||||
self.endpoint = DEFAULT_DATA_ACCELERATION_ENDPOINT
|
||||
else:
|
||||
self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com"
|
||||
self.bucket_name = oss_config[BUCKET]
|
||||
auth = oss2.StsAuth(self.key, self.secret, self.token)
|
||||
self.bucket = oss2.Bucket(
|
||||
auth, self.endpoint, self.bucket_name, connect_timeout=120)
|
||||
self.oss_dir = oss_config[DIR]
|
||||
self.oss_backup_dir = oss_config[BACK_DIR]
|
||||
|
||||
def _reload_sts(self):
|
||||
logger.info('Reloading sts token automatically.')
|
||||
oss_config_refresh = self.api.get_dataset_access_config_session(
|
||||
credential_provider = CredentialProviderWrapper(
|
||||
api=self.api,
|
||||
dataset_name=self.dataset_name,
|
||||
namespace=self.namespace,
|
||||
check_cookie=True,
|
||||
revision=self.revision)
|
||||
self._do_init(oss_config_refresh)
|
||||
auth = oss2.ProviderAuthV4(credential_provider)
|
||||
|
||||
self.bucket_name = oss_config[BUCKET]
|
||||
self.bucket = oss2.Bucket(
|
||||
auth=auth,
|
||||
endpoint=self.endpoint,
|
||||
bucket_name=self.bucket_name,
|
||||
region=oss_config['Region'].lstrip('oss-'),
|
||||
)
|
||||
self.oss_dir = oss_config[DIR]
|
||||
self.oss_backup_dir = oss_config[BACK_DIR]
|
||||
|
||||
@staticmethod
|
||||
def _percentage(consumed_bytes, total_bytes):
|
||||
@@ -79,8 +127,17 @@ class OssUtilities:
|
||||
print('\r{0}% '.format(rate), end='', flush=True)
|
||||
|
||||
def download(self, oss_file_name: str,
|
||||
download_config: DataDownloadConfig):
|
||||
import oss2
|
||||
download_config: DataDownloadConfig) -> str:
|
||||
"""
|
||||
Downloads a file from OSS to the local cache.
|
||||
|
||||
Args:
|
||||
oss_file_name (str): The name of the file in OSS to download.
|
||||
download_config (DataDownloadConfig): Configuration for the download process.
|
||||
|
||||
Returns:
|
||||
str: The local path to the downloaded file.
|
||||
"""
|
||||
cache_dir = download_config.cache_dir
|
||||
candidate_key = os.path.join(self.oss_dir, oss_file_name)
|
||||
candidate_key_backup = os.path.join(self.oss_backup_dir, oss_file_name)
|
||||
@@ -95,7 +152,6 @@ class OssUtilities:
|
||||
retry_count = 0
|
||||
while True:
|
||||
try:
|
||||
retry_count += 1
|
||||
# big_data is True when the dataset contains large number of objects
|
||||
if big_data:
|
||||
file_oss_key = candidate_key
|
||||
@@ -108,9 +164,9 @@ class OssUtilities:
|
||||
if download_config.force_download or not os.path.exists(
|
||||
local_path):
|
||||
oss2.resumable_download(
|
||||
self.bucket,
|
||||
file_oss_key,
|
||||
local_path,
|
||||
bucket=self.bucket,
|
||||
key=file_oss_key,
|
||||
filename=local_path,
|
||||
store=self.resumable_store_download,
|
||||
multiget_threshold=self.multipart_threshold,
|
||||
part_size=self.part_size,
|
||||
@@ -118,10 +174,13 @@ class OssUtilities:
|
||||
num_threads=self.num_threads)
|
||||
break
|
||||
except Exception as e:
|
||||
if e.__dict__.get('status') == 403:
|
||||
self._reload_sts()
|
||||
logger.warning(
|
||||
f'Error downloading {oss_file_name}: {e}, trying again...')
|
||||
retry_count += 1
|
||||
if retry_count >= self.max_retries:
|
||||
logger.warning(f'Failed to download {oss_file_name}')
|
||||
logger.error(
|
||||
f'Failed to download {oss_file_name} due to exceeded retries.'
|
||||
)
|
||||
raise e
|
||||
|
||||
return local_path
|
||||
@@ -129,7 +188,18 @@ class OssUtilities:
|
||||
def upload(self, oss_object_name: str, local_file_path: str,
|
||||
indicate_individual_progress: bool,
|
||||
upload_mode: UploadMode) -> str:
|
||||
import oss2
|
||||
"""
|
||||
Uploads a local file to OSS.
|
||||
|
||||
Args:
|
||||
oss_object_name (str): The name of the object in OSS.
|
||||
local_file_path (str): The local file path to upload.
|
||||
indicate_individual_progress (bool): Whether to show individual progress.
|
||||
upload_mode (UploadMode): The upload mode (e.g., OVERWRITE, APPEND).
|
||||
|
||||
Returns:
|
||||
str: The OSS object key where the file is uploaded.
|
||||
"""
|
||||
retry_count = 0
|
||||
object_key = os.path.join(self.oss_dir, oss_object_name)
|
||||
|
||||
@@ -140,7 +210,6 @@ class OssUtilities:
|
||||
|
||||
while True:
|
||||
try:
|
||||
retry_count += 1
|
||||
exist = self.bucket.object_exists(object_key)
|
||||
if upload_mode == UploadMode.APPEND and exist:
|
||||
logger.info(
|
||||
@@ -149,9 +218,9 @@ class OssUtilities:
|
||||
break
|
||||
|
||||
oss2.resumable_upload(
|
||||
self.bucket,
|
||||
object_key,
|
||||
local_file_path,
|
||||
bucket=self.bucket,
|
||||
key=object_key,
|
||||
filename=local_file_path,
|
||||
store=self.resumable_store_upload,
|
||||
multipart_threshold=self.multipart_threshold,
|
||||
part_size=self.part_size,
|
||||
@@ -159,9 +228,13 @@ class OssUtilities:
|
||||
num_threads=self.num_threads)
|
||||
break
|
||||
except Exception as e:
|
||||
if e.__dict__.get('status') == 403:
|
||||
self._reload_sts()
|
||||
logger.warning(
|
||||
f'Error uploading {oss_object_name}: {e}, trying again...')
|
||||
retry_count += 1
|
||||
if retry_count >= self.max_retries:
|
||||
raise
|
||||
logger.error(
|
||||
f'Failed to upload {oss_object_name} due to exceeded retries.'
|
||||
)
|
||||
raise e
|
||||
|
||||
return object_key
|
||||
|
||||
@@ -13,18 +13,9 @@ class DatasetUploadManager(object):
|
||||
|
||||
def __init__(self, dataset_name: str, namespace: str, version: str):
|
||||
from modelscope.hub.api import HubApi
|
||||
_hub_api = HubApi()
|
||||
_oss_config = _hub_api.get_dataset_access_config_session(
|
||||
dataset_name=dataset_name,
|
||||
namespace=namespace,
|
||||
check_cookie=False,
|
||||
revision=version)
|
||||
|
||||
self.oss_utilities = OssUtilities(
|
||||
oss_config=_oss_config,
|
||||
dataset_name=dataset_name,
|
||||
namespace=namespace,
|
||||
revision=version)
|
||||
dataset_name=dataset_name, namespace=namespace, revision=version)
|
||||
|
||||
def upload(self, object_name: str, local_file_path: str,
|
||||
upload_mode: UploadMode) -> str:
|
||||
|
||||
Reference in New Issue
Block a user