Files
modelscope/modelscope/msdatasets/data_loader/data_loader_manager.py
2025-12-15 20:31:17 +08:00

156 lines
6.4 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import enum
import os
from abc import ABC, abstractmethod
from datasets import load_dataset as hf_data_loader
from modelscope.hub.api import HubApi
from modelscope.msdatasets.context.dataset_context_config import \
DatasetContextConfig
from modelscope.msdatasets.data_loader.data_loader import OssDownloader
from modelscope.utils.constant import EXTENSIONS_TO_LOAD
from modelscope.utils.logger import get_logger
logger = get_logger()
class LocalDataLoaderType(enum.Enum):
""" Supported data loader types for local dataset: huggingface, PyTorch, Tensorflow """
HF_DATA_LOADER = 'hf_data_loader'
TORCH_DATA_LOADER = 'torch_data_loader'
TF_DATA_LOADER = 'tf_data_loader'
class RemoteDataLoaderType(enum.Enum):
""" Supported data loader types for remote dataset: huggingface, modelscope """
HF_DATA_LOADER = 'hf_data_loader'
MS_DATA_LOADER = 'ms_data_loader'
class DataLoaderManager(ABC):
"""Data loader manager, base class."""
def __init__(self, dataset_context_config: DatasetContextConfig):
self.dataset_context_config = dataset_context_config
@abstractmethod
def load_dataset(self, data_loader_type: enum.Enum):
...
class LocalDataLoaderManager(DataLoaderManager):
"""Data loader manager for loading local data."""
def __init__(self, dataset_context_config: DatasetContextConfig):
super().__init__(dataset_context_config=dataset_context_config)
def load_dataset(self, data_loader_type: enum.Enum):
# Get args from context
dataset_name = self.dataset_context_config.dataset_name
subset_name = self.dataset_context_config.subset_name
version = self.dataset_context_config.version
split = self.dataset_context_config.split
data_dir = self.dataset_context_config.data_dir
data_files = self.dataset_context_config.data_files
cache_root_dir = self.dataset_context_config.cache_root_dir
download_mode = self.dataset_context_config.download_mode
use_streaming = self.dataset_context_config.use_streaming
trust_remote_code = self.dataset_context_config.trust_remote_code
input_config_kwargs = self.dataset_context_config.config_kwargs
# load local single file
if os.path.isfile(dataset_name):
file_ext = os.path.splitext(dataset_name)[1].strip('.')
if file_ext in EXTENSIONS_TO_LOAD:
split = None
data_files = [dataset_name]
dataset_name = EXTENSIONS_TO_LOAD.get(file_ext)
# Select local data loader
# TODO: more loaders to be supported.
if data_loader_type == LocalDataLoaderType.HF_DATA_LOADER:
if trust_remote_code:
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {dataset_name}. Please make '
'sure that you can trust the external codes.')
# Build huggingface data loader and return dataset.
return hf_data_loader(
dataset_name,
name=subset_name,
revision=version,
split=split,
data_dir=data_dir,
data_files=data_files,
cache_dir=cache_root_dir,
download_mode=download_mode.value,
streaming=use_streaming,
trust_remote_code=trust_remote_code,
**input_config_kwargs)
raise f'Expected local data loader type: {LocalDataLoaderType.HF_DATA_LOADER.value}.'
class RemoteDataLoaderManager(DataLoaderManager):
"""Data loader manager for loading remote data."""
def __init__(self, dataset_context_config: DatasetContextConfig):
super().__init__(dataset_context_config=dataset_context_config)
self.api = HubApi(token=dataset_context_config.token)
def load_dataset(self, data_loader_type: enum.Enum):
# Get args from context
dataset_name = self.dataset_context_config.dataset_name
namespace = self.dataset_context_config.namespace
subset_name = self.dataset_context_config.subset_name
version = self.dataset_context_config.version
split = self.dataset_context_config.split
data_dir = self.dataset_context_config.data_dir
data_files = self.dataset_context_config.data_files
download_mode_val = self.dataset_context_config.download_mode.value
use_streaming = self.dataset_context_config.use_streaming
input_config_kwargs = self.dataset_context_config.config_kwargs
trust_remote_code = self.dataset_context_config.trust_remote_code
token = self.dataset_context_config.token
# To use the huggingface data loader
if data_loader_type == RemoteDataLoaderType.HF_DATA_LOADER:
if trust_remote_code:
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {dataset_name}. Please make '
'sure that you can trust the external codes.')
dataset_ret = hf_data_loader(
dataset_name,
name=subset_name,
revision=version,
split=split,
data_dir=data_dir,
data_files=data_files,
download_mode=download_mode_val,
streaming=use_streaming,
trust_remote_code=trust_remote_code,
token=token,
**input_config_kwargs)
# download statistics
self.api.dataset_download_statistics(
dataset_name=dataset_name,
namespace=namespace,
use_streaming=use_streaming)
return dataset_ret
# To use the modelscope data loader
elif data_loader_type == RemoteDataLoaderType.MS_DATA_LOADER:
oss_downloader = OssDownloader(
dataset_context_config=self.dataset_context_config)
oss_downloader.process()
# download statistics
self.api.dataset_download_statistics(
dataset_name=dataset_name,
namespace=namespace,
use_streaming=use_streaming)
return oss_downloader.dataset
else:
raise f'Expected remote data loader type: {RemoteDataLoaderType.HF_DATA_LOADER.value}/' \
f'{RemoteDataLoaderType.MS_DATA_LOADER.value}, but got {data_loader_type} .'