# Files: modelscope/tests/hub/test_hub_upload.py
# Timestamp: 2025-03-04 15:18:37 +08:00
# 190 lines, 7.8 KiB, Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import time
import unittest
import uuid
from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import GitError, HTTPError, NotLoginException
from modelscope.hub.push_to_hub import push_to_hub, push_to_hub_async
from modelscope.hub.repository import Repository
from modelscope.utils.constant import REPO_TYPE_DATASET, ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1, TEST_MODEL_ORG,
delete_credential, test_level)
# Module-level logger used by the test case below for progress messages.
logger = get_logger()
class HubUploadTest(unittest.TestCase):
    """Tests for pushing model files to the hub.

    Every test gets a uniquely named model id under ``TEST_MODEL_ORG`` and a
    private temporary working directory containing a ``finetune_path``
    directory (the upload source) and a ``repo_path`` location (where the
    repository is cloned back for verification).
    """

    @staticmethod
    def _write_line(path, text):
        """Write ``text`` followed by a newline to ``path``.

        Produces the same file content as the shell ``echo 'text' > path``
        without building a shell command out of a file-system path.
        """
        with open(path, 'w', encoding='utf-8') as f:
            f.write(text + '\n')

    def _clone_repo(self, revision=None):
        """Clone the test model into ``self.repo_path``.

        ``revision`` is only forwarded when given, so that the default
        revision of ``Repository`` is used otherwise.
        """
        kwargs = dict(
            model_dir=self.repo_path, clone_from=self.create_model_name)
        if revision is not None:
            kwargs['revision'] = revision
        Repository(**kwargs)

    def _assert_in_clone(self, *relative_parts):
        """Assert that the given relative path exists inside the clone."""
        path = os.path.join(self.repo_path, *relative_parts)
        self.assertTrue(os.path.exists(path), '%s is missing' % path)

    def _discard_clone(self):
        """Remove the local clone so the next clone starts from scratch."""
        shutil.rmtree(self.repo_path, ignore_errors=True)

    def setUp(self):
        """Create a unique model name and a scratch directory with a config."""
        logger.info('SetUp')
        self.api = HubApi()
        self.user = TEST_MODEL_ORG
        logger.info(self.user)
        self.create_model_name = '%s/%s_%s' % (self.user, 'test_model_upload',
                                               uuid.uuid4().hex)
        logger.info('create %s' % self.create_model_name)
        temporary_dir = tempfile.mkdtemp()
        self.work_dir = temporary_dir
        self.model_dir = os.path.join(temporary_dir, self.create_model_name)
        self.finetune_path = os.path.join(self.work_dir, 'finetune_path')
        self.repo_path = os.path.join(self.work_dir, 'repo_path')
        os.mkdir(self.finetune_path)
        self._write_line(
            os.path.join(self.finetune_path, ModelFile.CONFIGURATION), '{}')
        # Remember the previous value so tearDown can restore it instead of
        # leaking the test value into the rest of the process.
        self._saved_train_id = os.environ.get('MODELSCOPE_TRAIN_ID')
        os.environ['MODELSCOPE_TRAIN_ID'] = 'test-id'

    def tearDown(self):
        """Remove the scratch directory and best-effort delete the model."""
        logger.info('TearDown')
        # Remove the whole mkdtemp() directory; model_dir, finetune_path and
        # repo_path all live underneath it.
        shutil.rmtree(self.work_dir, ignore_errors=True)
        try:
            self.api.delete_model(model_id=self.create_model_name)
        except Exception as e:
            # Deliberately best-effort: not every test creates the model.
            logger.info('delete %s skipped: %s' % (self.create_model_name, e))
        finally:
            if self._saved_train_id is None:
                os.environ.pop('MODELSCOPE_TRAIN_ID', None)
            else:
                os.environ['MODELSCOPE_TRAIN_ID'] = self._saved_train_id

    def test_repo_exist(self):
        """``repo_exists`` is true for known repos and false for bogus ones."""
        res = self.api.repo_exists('Qwen/Qwen2.5-7B-Instruct')
        self.assertTrue(res)
        res = self.api.repo_exists('Qwen/not-a-repo')
        self.assertFalse(res)
        res = self.api.repo_exists(
            'Qwen/ProcessBench', repo_type=REPO_TYPE_DATASET)
        self.assertTrue(res)
        res = self.api.repo_exists(
            'Qwen/not-a-repo', repo_type=REPO_TYPE_DATASET)
        self.assertFalse(res)

    def test_upload_exits_repo_master(self):
        """Push to an existing model: default revision, new revisions,
        a sub-directory, and a file removal, verifying each via a clone."""
        logger.info('basic test for upload!')
        self.api.login(TEST_ACCESS_TOKEN1)
        self.api.create_model(
            model_id=self.create_model_name,
            visibility=ModelVisibility.PUBLIC,
            license=Licenses.APACHE_V2)

        # 1. Push to the default revision.
        self._write_line(os.path.join(self.finetune_path, 'add1.py'), '111')
        self.api.push_model(
            model_id=self.create_model_name, model_dir=self.finetune_path)
        self._clone_repo()
        self._assert_in_clone('add1.py')
        self._discard_clone()

        # 2. Push to a new revision.
        self._write_line(os.path.join(self.finetune_path, 'add2.py'), '222')
        self.api.push_model(
            model_id=self.create_model_name,
            model_dir=self.finetune_path,
            revision='new_revision/version1')
        self._clone_repo(revision='new_revision/version1')
        self._assert_in_clone('add2.py')
        self._discard_clone()

        # 3. Push to another new revision with an explicit commit message.
        self._write_line(os.path.join(self.finetune_path, 'add3.py'), '333')
        self.api.push_model(
            model_id=self.create_model_name,
            model_dir=self.finetune_path,
            revision='new_revision/version2',
            commit_message='add add3.py')
        self._clone_repo(revision='new_revision/version2')
        self._assert_in_clone('add2.py')
        self._assert_in_clone('add3.py')
        self._discard_clone()

        # 4. Push a file that lives in a sub-directory.
        add4_path = os.path.join(self.finetune_path, 'temp')
        os.mkdir(add4_path)
        self._write_line(os.path.join(add4_path, 'add4.py'), '444')
        self.api.push_model(
            model_id=self.create_model_name,
            model_dir=self.finetune_path,
            revision='new_revision/version1')
        self._clone_repo(revision='new_revision/version1')
        # Check the clone, not the local source directory: the source file
        # was just written above, so checking it would always pass.
        self._assert_in_clone('temp', 'add4.py')
        self._discard_clone()

        # 5. Remove a file locally and push; it must disappear from the repo.
        local_add3 = os.path.join(self.finetune_path, 'add3.py')
        self.assertTrue(os.path.exists(local_add3))
        os.remove(local_add3)
        self.api.push_model(
            model_id=self.create_model_name,
            model_dir=self.finetune_path,
            revision='new_revision/version1')
        self._clone_repo(revision='new_revision/version1')
        self.assertFalse(
            os.path.exists(os.path.join(self.repo_path, 'add3.py')))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_upload_non_exists_repo(self):
        """Push to a model that was not created beforehand."""
        logger.info('test upload non exists repo!')
        self.api.login(TEST_ACCESS_TOKEN1)
        self._write_line(os.path.join(self.finetune_path, 'add1.py'), '111')
        self.api.push_model(
            model_id=self.create_model_name,
            model_dir=self.finetune_path,
            revision='new_model_new_revision',
            visibility=ModelVisibility.PUBLIC,
            license=Licenses.APACHE_V2)
        self._clone_repo(revision='new_model_new_revision')
        self._assert_in_clone('add1.py')
        self._discard_clone()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_upload_without_token(self):
        """Pushing after the credential was deleted raises NotLoginException."""
        logger.info('test upload without login!')
        self.api.login(TEST_ACCESS_TOKEN1)
        delete_credential()
        with self.assertRaises(NotLoginException):
            self.api.push_model(
                model_id=self.create_model_name,
                model_dir=self.finetune_path,
                visibility=ModelVisibility.PUBLIC,
                license=Licenses.APACHE_V2)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_upload_invalid_repo(self):
        """Pushing to 'speech_tts/invalid_model_test' raises HTTPError or
        GitError."""
        logger.info('test upload to invalid repo!')
        self.api.login(TEST_ACCESS_TOKEN1)
        with self.assertRaises((HTTPError, GitError)):
            self.api.push_model(
                model_id='%s/%s' % ('speech_tts', 'invalid_model_test'),
                model_dir=self.finetune_path,
                visibility=ModelVisibility.PUBLIC,
                license=Licenses.APACHE_V2)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_push_to_hub(self):
        """``push_to_hub`` returns exactly ``True``."""
        ret = push_to_hub(
            repo_name=self.create_model_name,
            output_dir=self.finetune_path,
            token=TEST_ACCESS_TOKEN1)
        self.assertIs(ret, True)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_push_to_hub_async(self):
        """``push_to_hub_async`` returns a future with a truthy result."""
        future = push_to_hub_async(
            repo_name=self.create_model_name,
            output_dir=self.finetune_path,
            token=TEST_ACCESS_TOKEN1)
        # Poll until the background upload has finished.
        while not future.done():
            time.sleep(1)
        self.assertTrue(future.result())
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()