Files
modelscope/tests/msdatasets/test_virgo_dataset.py
xingjun.wxj e630621599 Virgo SDK supports odps data source
1. Support ODPS datasource for virgo sdk.
2. Adapt "inner_url" parser for single-modal and multi-modal datasets.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12583695

* add ODPS data support to the Virgo SDK

* refine virgo sdk args and odps fetch batch data pipeline

* add inner import for odps

* remove the VirgoDataset import from MsDataset

* fix import VirgoDataset

* support inner url downloading

* refine dataset.py and maxcompute_utils.py

* add a unit test for Virgo ODPS batch data

* reset the unifold unit-test level to 1

* refine virgo batch data
2023-05-12 17:27:19 +08:00

97 lines
3.2 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from modelscope.hub.api import HubApi
from modelscope.msdatasets import MsDataset
from modelscope.msdatasets.dataset_cls.dataset import VirgoDataset
from modelscope.utils.constant import DownloadMode, Hubs, VirgoDatasetConfig
from modelscope.utils.logger import get_logger
logger = get_logger()

# Please use your own access token for your BUC account. It can also be
# supplied through the VIRGO_ACCESS_TOKEN environment variable, so the secret
# never has to be written into (and accidentally committed with) this file.
YOUR_ACCESS_TOKEN = os.environ.get('VIRGO_ACCESS_TOKEN', 'your_access_token')
# Please use your own Virgo dataset id and ensure you have access to it.
# It can also be supplied through the VIRGO_DATASET_ID environment variable.
VIRGO_DATASET_ID = os.environ.get('VIRGO_DATASET_ID', 'your_virgo_dataset_id')
class TestVirgoDataset(unittest.TestCase):
    """Local-only tests for loading Virgo datasets through ``MsDataset``.

    Every test is skipped by default. Running one requires a real access
    token (``YOUR_ACCESS_TOKEN``) and a Virgo dataset id
    (``VIRGO_DATASET_ID``) that the account is allowed to read.
    """

    def setUp(self):
        # Log in to the hub before each test that is actually run.
        self.api = HubApi()
        self.api.login(YOUR_ACCESS_TOKEN)

    def _check_virgo_file_sample(self, ds):
        """Check the first sample of a dataset loaded with its files.

        Asserts that ``ds`` is a ``VirgoDataset`` with
        ``download_virgo_files`` enabled. Also asserts that the first
        sample's cache-file column holds a path that exists on disk.
        """
        ds_one = next(iter(ds))
        logger.info(ds_one)
        self.assertTrue(ds_one)
        self.assertIsInstance(ds, VirgoDataset)
        self.assertTrue(ds.download_virgo_files)
        self.assertIn(VirgoDatasetConfig.col_cache_file, ds_one)
        cache_file_path = ds_one[VirgoDatasetConfig.col_cache_file]
        self.assertTrue(os.path.exists(cache_file_path))

    def _check_odps_batch(self, batch):
        """Assert that an ODPS batch is a non-empty pandas DataFrame.

        ``assertTrue(batch)`` cannot be used here. ``bool()`` of a DataFrame
        raises ``ValueError``, so emptiness is checked through
        ``DataFrame.empty`` after the type check.
        """
        import pandas as pd
        self.assertIsInstance(batch, pd.DataFrame)
        self.assertFalse(batch.empty)

    @unittest.skip('to be used for local test only')
    def test_download_virgo_dataset_meta(self):
        ds = MsDataset.load(dataset_name=VIRGO_DATASET_ID, hub=Hubs.virgo)
        ds_one = next(iter(ds))
        logger.info(ds_one)
        self.assertTrue(ds_one)
        self.assertIsInstance(ds, VirgoDataset)
        self.assertIn(VirgoDatasetConfig.col_id, ds_one)
        self.assertIn(VirgoDatasetConfig.col_meta_info, ds_one)
        self.assertIn(VirgoDatasetConfig.col_analysis_result, ds_one)
        self.assertIn(VirgoDatasetConfig.col_external_info, ds_one)

    @unittest.skip('to be used for local test only')
    def test_download_virgo_dataset_files(self):
        ds = MsDataset.load(
            dataset_name=VIRGO_DATASET_ID,
            hub=Hubs.virgo,
            download_virgo_files=True)
        self._check_virgo_file_sample(ds)

    @unittest.skip('to be used for local test only')
    def test_force_download_virgo_dataset_files(self):
        ds = MsDataset.load(
            dataset_name=VIRGO_DATASET_ID,
            hub=Hubs.virgo,
            download_mode=DownloadMode.FORCE_REDOWNLOAD,
            download_virgo_files=True)
        self._check_virgo_file_sample(ds)

    @unittest.skip('to be used for local test only')
    def test_download_virgo_dataset_odps(self):
        # Note: the samplingType must be 1, which means to get the dataset
        # from MaxCompute(ODPS).
        ds = MsDataset.load(
            dataset_name=VIRGO_DATASET_ID,
            hub=Hubs.virgo,
            odps_batch_size=100,
            odps_limit=2000,
            odps_drop_last=True)
        ds_one = next(iter(ds))
        logger.info(ds_one)
        self._check_odps_batch(ds_one)
        self.assertIsInstance(ds, VirgoDataset)
        logger.info(f'The shape of sample: {ds_one.shape}')
if __name__ == '__main__':
    # Allow running this module directly with the standard unittest runner.
    unittest.main()