# Files
# modelscope/tests/hub/test_hub_operation.py
#
# 206 lines
# 7.9 KiB
# Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import uuid
from shutil import rmtree
import requests
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)
# Module-level logger, obtained from modelscope's get_logger() helper.
logger = get_logger()
# NOTE(review): not referenced anywhere in the visible part of this file;
# confirm whether it is still needed.
DEFAULT_GIT_PATH = 'git'
# Name of the file that prepare_case() adds to the test repository and that
# the download tests fetch back.
download_model_file_name = 'test.bin'
class HubOperationTest(unittest.TestCase):
    """Integration tests for hub download operations.

    Each test works on its own public model repository, which ``setUp``
    creates through ``HubApi`` and ``tearDown`` deletes again.
    """

    def setUp(self):
        """Log in and create a uniquely named public model for the test."""
        self.api = HubApi()
        self.api.login(TEST_ACCESS_TOKEN1)
        # A random hex suffix gives every test its own repository name.
        self.model_name = f'op-{uuid.uuid4().hex}'
        self.model_id = f'{TEST_MODEL_ORG}/{self.model_name}'
        self.revision = 'v0.1_test_revision'
        self.api.create_model(
            model_id=self.model_id,
            visibility=ModelVisibility.PUBLIC,
            license=Licenses.APACHE_V2,
            chinese_name=TEST_MODEL_CHINESE_NAME,
        )

    def tearDown(self):
        """Delete the model created in ``setUp`` (best effort)."""
        try:
            self.api.delete_model(model_id=self.model_id)
        except Exception as e:
            # Deliberately best-effort: a failed cleanup must not mask the
            # real outcome of the test, so it is only logged.
            logger.warning(f'delete model {self.model_id} failed, {e}')

    def prepare_case(self):
        """Clone the test repo, add one file, push it and tag the revision.

        Sets ``self.model_dir`` to the local clone. The clone lives in a
        temporary directory that is removed automatically after the test.
        """
        temporary_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temporary_dir, ignore_errors=True)
        self.model_dir = os.path.join(temporary_dir, self.model_name)
        repo = Repository(self.model_dir, clone_from=self.model_id)
        # Write the fixture directly instead of shelling out to ``echo``:
        # no shell quoting problems and identical content on every platform
        # (``echo 'testtest'`` produces the text plus a trailing newline).
        fixture_path = os.path.join(self.model_dir, download_model_file_name)
        with open(fixture_path, 'w', encoding='utf-8', newline='\n') as f:
            f.write('testtest\n')
        repo.push('add model')
        repo.tag_and_push(self.revision, 'Test revision')

    def test_model_repo_creation(self):
        """The model created in ``setUp`` is retrievable under its name."""
        # change to proper model names before use.
        try:
            info = self.api.get_model(model_id=self.model_id)
            self.assertEqual(info['Name'], self.model_name)
        except KeyError as ke:
            # NOTE(review): the lookup above uses the key 'Name' while this
            # branch tests for 'name'; confirm which key is intended.
            if ke.args[0] == 'name':
                print(f'model {self.model_name} already exists, ignore')
            else:
                raise

    def test_download_single_file(self):
        """A second download of the same file leaves its mtime unchanged."""
        self.prepare_case()
        downloaded_file = model_file_download(
            model_id=self.model_id,
            file_path=download_model_file_name,
            revision=self.revision)
        self.assertTrue(os.path.exists(downloaded_file))
        mdtime1 = os.path.getmtime(downloaded_file)
        # download again, this time without an explicit revision
        downloaded_file = model_file_download(
            model_id=self.model_id, file_path=download_model_file_name)
        mdtime2 = os.path.getmtime(downloaded_file)
        self.assertEqual(mdtime1, mdtime2)

    def test_snapshot_download(self):
        """A second snapshot download leaves the file's mtime unchanged."""
        self.prepare_case()
        snapshot_path = snapshot_download(model_id=self.model_id)
        downloaded_file_path = os.path.join(snapshot_path,
                                            download_model_file_name)
        self.assertTrue(os.path.exists(downloaded_file_path))
        mdtime1 = os.path.getmtime(downloaded_file_path)
        # download again, this time with an explicit revision
        snapshot_path = snapshot_download(
            model_id=self.model_id, revision=self.revision)
        mdtime2 = os.path.getmtime(downloaded_file_path)
        self.assertEqual(mdtime1, mdtime2)

    def test_download_public_without_login(self):
        """A public model downloads after local credentials are removed."""
        try:
            self.prepare_case()
            # Remove the stored credentials before downloading.
            rmtree(ModelScopeConfig.path_credential)
            snapshot_path = snapshot_download(
                model_id=self.model_id, revision=self.revision)
            downloaded_file_path = os.path.join(snapshot_path,
                                                download_model_file_name)
            self.assertTrue(os.path.exists(downloaded_file_path))
            temporary_dir = tempfile.mkdtemp()
            self.addCleanup(shutil.rmtree, temporary_dir, ignore_errors=True)
            downloaded_file = model_file_download(
                model_id=self.model_id,
                file_path=download_model_file_name,
                revision=self.revision,
                cache_dir=temporary_dir)
            self.assertTrue(os.path.exists(downloaded_file))
        finally:
            # Log in again so that tearDown can delete the model.
            self.api.login(TEST_ACCESS_TOKEN1)

    def test_snapshot_delete_download_cache_file(self):
        """A file deleted from the snapshot directory can be fetched again."""
        self.prepare_case()
        snapshot_path = snapshot_download(
            model_id=self.model_id, revision=self.revision)
        downloaded_file_path = os.path.join(snapshot_path,
                                            download_model_file_name)
        self.assertTrue(os.path.exists(downloaded_file_path))
        os.remove(downloaded_file_path)
        # download again in cache
        file_download_path = model_file_download(
            model_id=self.model_id,
            file_path=ModelFile.README,
            revision=self.revision)
        self.assertTrue(os.path.exists(file_download_path))
        # deleted file need download again
        file_download_path = model_file_download(
            model_id=self.model_id,
            file_path=download_model_file_name,
            revision=self.revision)
        self.assertTrue(os.path.exists(file_download_path))

    def test_snapshot_download_default_revision(self):
        """Placeholder: snapshot download with the default revision."""
        pass  # TODO

    def test_file_download_default_revision(self):
        """Placeholder: single-file download with the default revision."""
        pass  # TODO

    def get_model_download_times(self):
        """Return the download count the hub reports for the test model.

        Returns:
            The ``Data.Downloads`` value of the JSON response when the
            request answers with HTTP 200, otherwise ``None``.

        Raises:
            requests.HTTPError: if the response carries an error status.
        """
        url = f'{self.api.endpoint}/api/v1/models/{self.model_id}/downloads'
        cookies = ModelScopeConfig.get_cookies()
        r = requests.get(url, cookies=cookies)
        if r.status_code == 200:
            return r.json()['Data']['Downloads']
        r.raise_for_status()
        # Reached only for non-200 responses that are not errors.
        return None

    @unittest.skip('temp skip')
    def test_list_model(self):
        """The test organisation lists at least one model."""
        data = self.api.list_models(TEST_MODEL_ORG)
        self.assertGreaterEqual(len(data['Models']), 1)

    def test_snapshot_download_location(self):
        """Check the returned path for default, cache_dir and local_dir."""
        self.prepare_case()
        snapshot_download_path = snapshot_download(
            model_id=self.model_id, revision=self.revision)
        print(snapshot_download_path)
        self.assertTrue(os.path.exists(snapshot_download_path))
        self.assertIn('models', snapshot_download_path)
        shutil.rmtree(snapshot_download_path)

        # Use a private scratch root instead of fixed /tmp paths: concurrent
        # runs cannot collide and nothing is left behind when an assertion
        # fails. The target directories themselves do not exist up front,
        # exactly as with the former fixed paths.
        scratch_root = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, scratch_root, ignore_errors=True)
        cache_dir = os.path.join(scratch_root, 'snapshot_download_cache_test')
        local_dir = os.path.join(scratch_root, 'snapshot_download_local_dir')

        # download with cache_dir
        snapshot_download_path = snapshot_download(
            self.model_id, revision=self.revision, cache_dir=cache_dir)
        expect_path = os.path.join(cache_dir, self.model_id)
        self.assertEqual(snapshot_download_path, expect_path)
        self.assertTrue(
            os.path.exists(
                os.path.join(snapshot_download_path, ModelFile.README)))
        shutil.rmtree(cache_dir)

        # download with local_dir
        snapshot_download_path = snapshot_download(
            self.model_id, revision=self.revision, local_dir=local_dir)
        self.assertEqual(snapshot_download_path, local_dir)
        self.assertTrue(
            os.path.exists(os.path.join(local_dir, ModelFile.README)))
        shutil.rmtree(local_dir)

        # download with local_dir and cache dir, with local first.
        snapshot_download_path = snapshot_download(
            self.model_id,
            revision=self.revision,
            cache_dir=cache_dir,
            local_dir=local_dir)
        self.assertEqual(snapshot_download_path, local_dir)
        self.assertTrue(
            os.path.exists(os.path.join(local_dir, ModelFile.README)))

    def test_snapshot_download_ignore_file_pattern_test(self):
        """Files matching ``ignore_file_pattern`` are not downloaded."""
        self.prepare_case()
        snapshot_download_path = snapshot_download(
            model_id=self.model_id,
            revision=self.revision,
            ignore_file_pattern=['.*.pt', '.*.safetensors', '.*.bin'])
        ignored_suffixes = ('pt', 'safetensors', 'bin')
        for _, _, files in os.walk(snapshot_download_path):
            for file in files:
                self.assertFalse(
                    file.endswith(ignored_suffixes),
                    f'{file} should have been ignored')
if __name__ == '__main__':
    # Run every test in this module when the file is executed directly.
    unittest.main()