Files
modelscope/tests/hub/test_hub_operation.py
Yingda Chen 65075191f7 clean directory structure (#1201)
默认
模型:
~/.cache/modelscope/hub/models/{model_owner}/{model_name}
数据集:
~/.cache/modelscope/hub/datasets/{dataset_owner}/{dataset_name}

配置MODELSCOPE_CACHE环境变量:
模型:
$MODELSCOPE_CACHE/models/{model_owner}/{model_name}
数据集:
$MODELSCOPE_CACHE/datasets/{dataset_owner}/{dataset_name}

Co-authored-by: Yingda Chen <yingda.chen@alibaba-inc.com>
2025-01-23 17:31:49 +08:00

202 lines
7.8 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import uuid
from pathlib import Path
from shutil import rmtree
import requests
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModelFile
from modelscope.utils.file_utils import get_model_cache_dir
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)
# NOTE(review): presumably the name of the git executable; it is not
# referenced by any code visible in this file -- confirm whether still needed.
DEFAULT_GIT_PATH = 'git'
# Name of the file that HubOperationTest.prepare_case() writes into the cloned
# repository and that the download tests subsequently fetch.
download_model_file_name = 'test.bin'
class HubOperationTest(unittest.TestCase):
def setUp(self):
self.api = HubApi()
self.api.login(TEST_ACCESS_TOKEN1)
self.model_name = 'op-%s' % (uuid.uuid4().hex)
self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
self.revision = 'v0.1_test_revision'
self.api.create_model(
model_id=self.model_id,
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2,
chinese_name=TEST_MODEL_CHINESE_NAME,
)
def tearDown(self):
self.api.delete_model(model_id=self.model_id)
def prepare_case(self):
temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name)
repo = Repository(self.model_dir, clone_from=self.model_id)
os.system("echo 'testtest'>%s"
% os.path.join(self.model_dir, download_model_file_name))
repo.push('add model')
repo.tag_and_push(self.revision, 'Test revision')
def test_model_repo_creation(self):
# change to proper model names before use.
try:
info = self.api.get_model(model_id=self.model_id)
assert info['Name'] == self.model_name
except KeyError as ke:
if ke.args[0] == 'name':
print(f'model {self.model_name} already exists, ignore')
else:
raise
def test_download_single_file(self):
self.prepare_case()
downloaded_file = model_file_download(
model_id=self.model_id,
file_path=download_model_file_name,
revision=self.revision)
assert os.path.exists(downloaded_file)
mdtime1 = os.path.getmtime(downloaded_file)
# download again
downloaded_file = model_file_download(
model_id=self.model_id, file_path=download_model_file_name)
mdtime2 = os.path.getmtime(downloaded_file)
assert mdtime1 == mdtime2
def test_snapshot_download(self):
self.prepare_case()
snapshot_path = snapshot_download(model_id=self.model_id)
downloaded_file_path = os.path.join(snapshot_path,
download_model_file_name)
assert os.path.exists(downloaded_file_path)
mdtime1 = os.path.getmtime(downloaded_file_path)
# download again
snapshot_path = snapshot_download(
model_id=self.model_id, revision=self.revision)
mdtime2 = os.path.getmtime(downloaded_file_path)
assert mdtime1 == mdtime2
def test_download_public_without_login(self):
try:
self.prepare_case()
rmtree(ModelScopeConfig.path_credential)
snapshot_path = snapshot_download(
model_id=self.model_id, revision=self.revision)
downloaded_file_path = os.path.join(snapshot_path,
download_model_file_name)
assert os.path.exists(downloaded_file_path)
temporary_dir = tempfile.mkdtemp()
downloaded_file = model_file_download(
model_id=self.model_id,
file_path=download_model_file_name,
revision=self.revision,
cache_dir=temporary_dir)
assert os.path.exists(downloaded_file)
finally:
self.api.login(TEST_ACCESS_TOKEN1)
def test_snapshot_delete_download_cache_file(self):
self.prepare_case()
snapshot_path = snapshot_download(
model_id=self.model_id, revision=self.revision)
downloaded_file_path = os.path.join(snapshot_path,
download_model_file_name)
assert os.path.exists(downloaded_file_path)
os.remove(downloaded_file_path)
# download again in cache
file_download_path = model_file_download(
model_id=self.model_id,
file_path=ModelFile.README,
revision=self.revision)
assert os.path.exists(file_download_path)
# deleted file need download again
file_download_path = model_file_download(
model_id=self.model_id,
file_path=download_model_file_name,
revision=self.revision)
assert os.path.exists(file_download_path)
def test_snapshot_download_default_revision(self):
pass # TOTO
def test_file_download_default_revision(self):
pass # TODO
def get_model_download_times(self):
url = f'{self.api.endpoint}/api/v1/models/{self.model_id}/downloads'
cookies = ModelScopeConfig.get_cookies()
r = requests.get(url, cookies=cookies)
if r.status_code == 200:
return r.json()['Data']['Downloads']
else:
r.raise_for_status()
return None
@unittest.skip('temp skip')
def test_list_model(self):
data = self.api.list_models(TEST_MODEL_ORG)
assert len(data['Models']) >= 1
def test_snapshot_download_location(self):
self.prepare_case()
snapshot_download_path = snapshot_download(
model_id=self.model_id, revision=self.revision)
print(snapshot_download_path)
assert os.path.exists(snapshot_download_path)
assert 'models' in snapshot_download_path
shutil.rmtree(snapshot_download_path)
# download with cache_dir
cache_dir = '/tmp/snapshot_download_cache_test'
snapshot_download_path = snapshot_download(
self.model_id, revision=self.revision, cache_dir=cache_dir)
expect_path = os.path.join(cache_dir, self.model_id)
assert snapshot_download_path == expect_path
assert os.path.exists(
os.path.join(snapshot_download_path, ModelFile.README))
shutil.rmtree(cache_dir)
# download with local_dir
local_dir = '/tmp/snapshot_download_local_dir'
snapshot_download_path = snapshot_download(
self.model_id, revision=self.revision, local_dir=local_dir)
assert snapshot_download_path == local_dir
assert os.path.exists(os.path.join(local_dir, ModelFile.README))
shutil.rmtree(local_dir)
# download with local_dir and cache dir, with local first.
local_dir = '/tmp/snapshot_download_local_dir'
snapshot_download_path = snapshot_download(
self.model_id,
revision=self.revision,
cache_dir=cache_dir,
local_dir=local_dir)
assert snapshot_download_path == local_dir
assert os.path.exists(os.path.join(local_dir, ModelFile.README))
def test_snapshot_download_ignore_file_pattern_test(self):
self.prepare_case()
snapshot_download_path = snapshot_download(
model_id=self.model_id,
revision=self.revision,
ignore_file_pattern=['.*.pt', '.*.safetensors', '.*.bin'])
for _, _, files in os.walk(snapshot_download_path):
for file in files:
assert not file.endswith('pt') and not file.endswith(
'safetensors') and not file.endswith('bin')
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()