[Fix] lazy import oss2 (#1649)

* lazy import oss2

* fix lint
This commit is contained in:
Xingjun.Wang
2026-03-17 16:03:42 +08:00
committed by GitHub
parent a420e68953
commit 832d88cb02
4 changed files with 44 additions and 39 deletions

View File

@@ -405,7 +405,8 @@ RUN pip install --no-cache-dir -U icecream soundfile pybind11 py-spy
"""
version_args = (
f'{self.args.torch_version} {self.args.torchvision_version} {self.args.torchaudio_version} '
f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version} {self.args.optimum_version}'
f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version} '
# Trailing space is required: adjacent string literals are concatenated
# verbatim, so without it this value would run straight into the
# flashattn_version field that follows instead of staying space-separated
# like every other field in this expression.
f'{self.args.optimum_version} '
f'{self.args.flashattn_version}')
with open('docker/Dockerfile.ubuntu', 'r') as f:
content = f.read()

View File

@@ -21,7 +21,6 @@ from multiprocessing.pool import ThreadPool as Pool
import imageio
import json
import numpy as np
import oss2 as oss
import requests
import skvideo.io
import torch
@@ -169,6 +168,8 @@ def DOWNLOAD_TO_CACHE(oss_key,
def parse_oss_url(path):
import oss2 as oss
if path.startswith('oss://'):
path = path[len('oss://'):]

View File

@@ -22,7 +22,6 @@ from multiprocessing.pool import ThreadPool as Pool
import imageio
import json
import numpy as np
import oss2 as oss
import requests
import skvideo.io
import torch
@@ -81,6 +80,8 @@ def setup_seed(seed):
def parse_oss_url(path):
import oss2 as oss
if path.startswith('oss://'):
path = path[len('oss://'):]

View File

@@ -5,10 +5,7 @@ import multiprocessing
import os
import threading
import oss2
from datasets.utils.file_utils import hash_url_to_filename
from oss2 import CredentialsProvider
from oss2.credentials import Credentials
from modelscope.hub.api import HubApi
from modelscope.msdatasets.download.download_config import DataDownloadConfig
@@ -27,21 +24,19 @@ BACK_DIR = 'BackupDir'
DIR = 'Dir'
class CredentialProviderWrapper(CredentialsProvider):
"""
A custom credentials provider for oss2 that fetches temporary credentials
"""
def _create_credential_provider(api, dataset_name, namespace, revision):
"""Create a credentials provider for oss2 with lazy import.
def __init__(self, api: HubApi, dataset_name: str, namespace: str,
revision: str):
Returns an instance that subclasses oss2.CredentialsProvider so that the
oss2 SDK can call ``get_credentials()`` automatically when the token
expires or authentication is needed.
"""
Initializes the CredentialProviderWrapper with dataset information.
from oss2 import CredentialsProvider
from oss2.credentials import Credentials
Args:
dataset_name (str): The name of the dataset.
namespace (str): The namespace of the dataset.
revision (str): The revision of the dataset.
"""
class _CredentialProviderWrapper(CredentialsProvider):
def __init__(self):
self.api = api
self.dataset_name = dataset_name
self.namespace = namespace
@@ -49,9 +44,8 @@ class CredentialProviderWrapper(CredentialsProvider):
self._lock = threading.Lock()
def get_credentials(self):
"""
oss2 SDK will call this method automatically when it finds the token is expired or needs authentication.
"""
"""oss2 SDK will call this method automatically when it finds
the token is expired or needs authentication."""
with self._lock:
oss_config = self.api.get_dataset_access_config_session(
dataset_name=self.dataset_name,
@@ -65,6 +59,8 @@ class CredentialProviderWrapper(CredentialsProvider):
security_token=oss_config[SECURITY_TOKEN],
)
return _CredentialProviderWrapper()
class OssUtilities:
"""
@@ -98,12 +94,14 @@ class OssUtilities:
self.multipart_threshold = 50 * 1024 * 1024
self.max_retries = 3
import oss2
self.resumable_store_download = oss2.ResumableDownloadStore(
root=self.resumable_store_root_path)
self.resumable_store_upload = oss2.ResumableStore(
root=self.resumable_store_root_path)
credential_provider = CredentialProviderWrapper(
credential_provider = _create_credential_provider(
api=self.api,
dataset_name=self.dataset_name,
namespace=self.namespace,
@@ -150,6 +148,8 @@ class OssUtilities:
big_data = args_dict.get(MetaDataFields.ARGS_BIG_DATA)
retry_count = 0
import oss2
while True:
try:
# big_data is True when the dataset contains large number of objects
@@ -208,6 +208,8 @@ class OssUtilities:
else:
progress_callback = None
import oss2
while True:
try:
exist = self.bucket.object_exists(object_key)