mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-17 00:37:43 +01:00
97 lines
3.2 KiB
Python
97 lines
3.2 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import os
|
|
import unittest
|
|
|
|
from modelscope.hub.api import HubApi
|
|
from modelscope.msdatasets import MsDataset
|
|
from modelscope.msdatasets.dataset_cls.dataset import VirgoDataset
|
|
from modelscope.utils.constant import DownloadMode, Hubs, VirgoDatasetConfig
|
|
from modelscope.utils.logger import get_logger
|
|
|
|
# Module-level logger used by the tests below to log the first sample of each dataset.
logger = get_logger()
|
|
|
|
# Please use your own access token for your BUC account.
|
|
# Placeholder value; it is passed to HubApi.login() in TestVirgoDataset.setUp.
YOUR_ACCESS_TOKEN = 'your_access_token'
|
|
# Please use your own virgo dataset id and ensure you have access to it.
|
|
# Placeholder value; it is passed as `dataset_name` to MsDataset.load() in every test below.
VIRGO_DATASET_ID = 'your_virgo_dataset_id'
|
|
|
|
|
|
class TestVirgoDataset(unittest.TestCase):
    """Local-only tests for loading datasets through ``MsDataset`` with ``hub=Hubs.virgo``.

    Every test carries ``unittest.skip`` and is therefore not run by default.
    To run one locally, remove its decorator and replace the module-level
    ``YOUR_ACCESS_TOKEN`` and ``VIRGO_DATASET_ID`` placeholders with real values.
    """

    def setUp(self):
        """Create a ``HubApi`` client and log in with ``YOUR_ACCESS_TOKEN``."""
        self.api = HubApi()
        self.api.login(YOUR_ACCESS_TOKEN)

    def _assert_cache_file_exists(self, ds, sample):
        """Check that ``ds`` was loaded with file download on and ``sample`` points to an existing file.

        Args:
            ds: The dataset object returned by ``MsDataset.load``.
            sample: The first item yielded by iterating ``ds``.
        """
        self.assertTrue(sample)
        self.assertIsInstance(ds, VirgoDataset)
        self.assertTrue(ds.download_virgo_files)
        self.assertIn(VirgoDatasetConfig.col_cache_file, sample)
        cache_file_path = sample[VirgoDatasetConfig.col_cache_file]
        self.assertTrue(os.path.exists(cache_file_path))

    @unittest.skip('to be used for local test only')
    def test_download_virgo_dataset_meta(self):
        """Load the dataset with default options and check the columns of the first sample."""
        ds = MsDataset.load(dataset_name=VIRGO_DATASET_ID, hub=Hubs.virgo)
        ds_one = next(iter(ds))
        logger.info(ds_one)

        self.assertTrue(ds_one)
        self.assertIsInstance(ds, VirgoDataset)
        self.assertIn(VirgoDatasetConfig.col_id, ds_one)
        self.assertIn(VirgoDatasetConfig.col_meta_info, ds_one)
        self.assertIn(VirgoDatasetConfig.col_analysis_result, ds_one)
        self.assertIn(VirgoDatasetConfig.col_external_info, ds_one)

    @unittest.skip('to be used for local test only')
    def test_download_virgo_dataset_files(self):
        """Load with ``download_virgo_files=True`` and check the first sample's cache file exists."""
        ds = MsDataset.load(
            dataset_name=VIRGO_DATASET_ID,
            hub=Hubs.virgo,
            download_virgo_files=True)

        ds_one = next(iter(ds))
        logger.info(ds_one)

        self._assert_cache_file_exists(ds, ds_one)

    @unittest.skip('to be used for local test only')
    def test_force_download_virgo_dataset_files(self):
        """Same as the file-download test, but with ``DownloadMode.FORCE_REDOWNLOAD``."""
        ds = MsDataset.load(
            dataset_name=VIRGO_DATASET_ID,
            hub=Hubs.virgo,
            download_mode=DownloadMode.FORCE_REDOWNLOAD,
            download_virgo_files=True)

        ds_one = next(iter(ds))
        logger.info(ds_one)

        self._assert_cache_file_exists(ds, ds_one)

    @unittest.skip('to be used for local test only')
    def test_download_virgo_dataset_odps(self):
        """Load with the ``odps_*`` options and check the first sample is a non-empty DataFrame."""
        # Note: the samplingType must be 1, which means to get the dataset from MaxCompute(ODPS).
        import pandas as pd

        ds = MsDataset.load(
            dataset_name=VIRGO_DATASET_ID,
            hub=Hubs.virgo,
            odps_batch_size=100,
            odps_limit=2000,
            odps_drop_last=True)

        ds_one = next(iter(ds))
        logger.info(ds_one)

        self.assertIsInstance(ds, VirgoDataset)
        # assertTrue(x, msg) treats its second argument as the failure message,
        # so the type has to be checked with assertIsInstance instead.
        self.assertIsInstance(ds_one, pd.DataFrame)
        # The truth value of a DataFrame is ambiguous (bool() raises ValueError),
        # so non-emptiness is checked through `.empty`.
        self.assertFalse(ds_one.empty)
        logger.info(f'The shape of sample: {ds_one.shape}')
|
|
|
|
|
|
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|