[Fix] lazy import oss2 (#1649)

* lazy import oss2

* fix lint
This commit is contained in:
Xingjun.Wang
2026-03-17 16:03:42 +08:00
committed by GitHub
parent a420e68953
commit 832d88cb02
4 changed files with 44 additions and 39 deletions

View File

@@ -405,7 +405,8 @@ RUN pip install --no-cache-dir -U icecream soundfile pybind11 py-spy
"""
version_args = (
f'{self.args.torch_version} {self.args.torchvision_version} {self.args.torchaudio_version} '
f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version} {self.args.optimum_version}'
f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version} '
f'{self.args.optimum_version}'
f'{self.args.flashattn_version}')
with open('docker/Dockerfile.ubuntu', 'r') as f:
content = f.read()

View File

@@ -21,7 +21,6 @@ from multiprocessing.pool import ThreadPool as Pool
import imageio
import json
import numpy as np
import oss2 as oss
import requests
import skvideo.io
import torch
@@ -169,6 +168,8 @@ def DOWNLOAD_TO_CACHE(oss_key,
def parse_oss_url(path):
import oss2 as oss
if path.startswith('oss://'):
path = path[len('oss://'):]

View File

@@ -22,7 +22,6 @@ from multiprocessing.pool import ThreadPool as Pool
import imageio
import json
import numpy as np
import oss2 as oss
import requests
import skvideo.io
import torch
@@ -81,6 +80,8 @@ def setup_seed(seed):
def parse_oss_url(path):
import oss2 as oss
if path.startswith('oss://'):
path = path[len('oss://'):]

View File

@@ -5,10 +5,7 @@ import multiprocessing
import os
import threading
import oss2
from datasets.utils.file_utils import hash_url_to_filename
from oss2 import CredentialsProvider
from oss2.credentials import Credentials
from modelscope.hub.api import HubApi
from modelscope.msdatasets.download.download_config import DataDownloadConfig
@@ -27,43 +24,42 @@ BACK_DIR = 'BackupDir'
DIR = 'Dir'
class CredentialProviderWrapper(CredentialsProvider):
"""
A custom credentials provider for oss2 that fetches temporary credentials
def _create_credential_provider(api, dataset_name, namespace, revision):
"""Create a credentials provider for oss2 with lazy import.
Returns an instance that subclasses oss2.CredentialsProvider so that the
oss2 SDK can call ``get_credentials()`` automatically when the token
expires or authentication is needed.
"""
from oss2 import CredentialsProvider
from oss2.credentials import Credentials
def __init__(self, api: HubApi, dataset_name: str, namespace: str,
revision: str):
"""
Initializes the CredentialProviderWrapper with dataset information.
class _CredentialProviderWrapper(CredentialsProvider):
Args:
dataset_name (str): The name of the dataset.
namespace (str): The namespace of the dataset.
revision (str): The revision of the dataset.
"""
self.api = api
self.dataset_name = dataset_name
self.namespace = namespace
self.revision = revision
self._lock = threading.Lock()
def __init__(self):
self.api = api
self.dataset_name = dataset_name
self.namespace = namespace
self.revision = revision
self._lock = threading.Lock()
def get_credentials(self):
"""
oss2 SDK will call this method automatically when it finds the token is expired or needs authentication.
"""
with self._lock:
oss_config = self.api.get_dataset_access_config_session(
dataset_name=self.dataset_name,
namespace=self.namespace,
check_cookie=False,
revision=self.revision)
def get_credentials(self):
"""oss2 SDK will call this method automatically when it finds
the token is expired or needs authentication."""
with self._lock:
oss_config = self.api.get_dataset_access_config_session(
dataset_name=self.dataset_name,
namespace=self.namespace,
check_cookie=False,
revision=self.revision)
return Credentials(
access_key_id=oss_config[ACCESS_ID],
access_key_secret=oss_config[ACCESS_SECRET],
security_token=oss_config[SECURITY_TOKEN],
)
return Credentials(
access_key_id=oss_config[ACCESS_ID],
access_key_secret=oss_config[ACCESS_SECRET],
security_token=oss_config[SECURITY_TOKEN],
)
return _CredentialProviderWrapper()
class OssUtilities:
@@ -98,12 +94,14 @@ class OssUtilities:
self.multipart_threshold = 50 * 1024 * 1024
self.max_retries = 3
import oss2
self.resumable_store_download = oss2.ResumableDownloadStore(
root=self.resumable_store_root_path)
self.resumable_store_upload = oss2.ResumableStore(
root=self.resumable_store_root_path)
credential_provider = CredentialProviderWrapper(
credential_provider = _create_credential_provider(
api=self.api,
dataset_name=self.dataset_name,
namespace=self.namespace,
@@ -150,6 +148,8 @@ class OssUtilities:
big_data = args_dict.get(MetaDataFields.ARGS_BIG_DATA)
retry_count = 0
import oss2
while True:
try:
# big_data is True when the dataset contains large number of objects
@@ -208,6 +208,8 @@ class OssUtilities:
else:
progress_callback = None
import oss2
while True:
try:
exist = self.bucket.object_exists(object_key)