[Fix] fix token auth for downloading (#1674)

This commit is contained in:
Xingjun.Wang
2026-04-11 14:07:46 +08:00
committed by GitHub
parent 87ea8623cf
commit 32d7c7062c
4 changed files with 83 additions and 15 deletions

View File

@@ -145,7 +145,8 @@ class DownloadCMD(CLICommand):
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
revision=self.args.revision,
cookies=cookies)
cookies=cookies,
token=self.args.token)
elif len(
self.args.files) > 1: # download specified multiple files.
snapshot_download(
@@ -155,7 +156,8 @@ class DownloadCMD(CLICommand):
local_dir=self.args.local_dir,
allow_file_pattern=self.args.files,
max_workers=self.args.max_workers,
cookies=cookies)
cookies=cookies,
token=self.args.token)
else: # download repo
snapshot_download(
self.args.model,
@@ -165,7 +167,8 @@ class DownloadCMD(CLICommand):
allow_file_pattern=convert_patterns(self.args.include),
ignore_file_pattern=convert_patterns(self.args.exclude),
max_workers=self.args.max_workers,
cookies=cookies)
cookies=cookies,
token=self.args.token)
print(f'\nSuccessfully Downloaded from model {self.args.model}.\n')
elif self.args.dataset:
dataset_revision: str = self.args.revision if self.args.revision else DEFAULT_DATASET_REVISION
@@ -176,7 +179,8 @@ class DownloadCMD(CLICommand):
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
revision=dataset_revision,
cookies=cookies)
cookies=cookies,
token=self.args.token)
elif len(
self.args.files) > 1: # download specified multiple files.
dataset_snapshot_download(
@@ -186,7 +190,8 @@ class DownloadCMD(CLICommand):
local_dir=self.args.local_dir,
allow_file_pattern=self.args.files,
max_workers=self.args.max_workers,
cookies=cookies)
cookies=cookies,
token=self.args.token)
else: # download repo
dataset_snapshot_download(
self.args.dataset,
@@ -196,7 +201,8 @@ class DownloadCMD(CLICommand):
allow_file_pattern=convert_patterns(self.args.include),
ignore_file_pattern=convert_patterns(self.args.exclude),
max_workers=self.args.max_workers,
cookies=cookies)
cookies=cookies,
token=self.args.token)
print(
f'\nSuccessfully Downloaded from dataset {self.args.dataset}.\n'
)

View File

@@ -308,7 +308,7 @@ def _snapshot_download(
_api = HubApi(token=token)
endpoint = _api.get_endpoint_for_read(
repo_id=repo_id, repo_type=repo_type)
repo_id=repo_id, repo_type=repo_type, token=token)
if cookies is None:
cookies = _api.get_cookies()
if repo_type == REPO_TYPE_MODEL:
@@ -393,8 +393,8 @@ def _snapshot_download(
revision_detail = revision or DEFAULT_DATASET_REVISION
logger.info('Fetching dataset repo file list...')
repo_files = fetch_repo_files(_api, repo_id, revision_detail,
endpoint)
repo_files = fetch_repo_files(
_api, repo_id, revision_detail, endpoint, token=token)
if repo_files is None:
logger.error(
@@ -427,10 +427,13 @@ def _snapshot_download(
return cache_root_path
def fetch_repo_files(_api, repo_id, revision, endpoint):
def fetch_repo_files(_api, repo_id, revision, endpoint, token=None):
_owner, _dataset_name = repo_id.split('/')
_hub_id, _ = _api.get_dataset_id_and_type(
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint)
dataset_name=_dataset_name,
namespace=_owner,
endpoint=endpoint,
token=token)
page_number = 1
page_size = 150
@@ -446,6 +449,7 @@ def fetch_repo_files(_api, repo_id, revision, endpoint):
page_number=page_number,
page_size=page_size,
endpoint=endpoint,
token=token,
dataset_hub_id=_hub_id)
except Exception as e:
logger.error(f'Error fetching dataset files: {e}')

View File

@@ -12,7 +12,7 @@ from modelscope.hub.repository import Repository
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)
TEST_MODEL_ORG, test_level)
logger = get_logger()
@@ -58,21 +58,25 @@ class DownloadCMDTest(unittest.TestCase):
logger.warning(f'Error deleting model {self.model_id}: {e}')
super().tearDown()
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_download(self):
cmd = f'python -m modelscope.cli.cli download --model {self.model_id}'
stat, output = subprocess.getstatusoutput(cmd)
self.assertEqual(stat, 0)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_download_with_position_arg(self):
cmd = f'python -m modelscope.cli.cli download {self.model_id}'
stat, output = subprocess.getstatusoutput(cmd)
self.assertEqual(stat, 0)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_download_file(self):
cmd = f'python -m modelscope.cli.cli download --model {self.model_id} {download_model_file_name}'
stat, output = subprocess.getstatusoutput(cmd)
self.assertEqual(stat, 0)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_download_with_cache(self):
cmd = f'python -m modelscope.cli.cli download --model {self.model_id} --cache_dir {self.tmp_dir}'
stat, output = subprocess.getstatusoutput(cmd)
@@ -83,6 +87,7 @@ class DownloadCMDTest(unittest.TestCase):
osp.exists(
f'{self.tmp_dir}/{self.model_id}/{download_model_file_name}'))
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_download_with_revision(self):
cmd = f'python -m modelscope.cli.cli download --model {self.model_id} --revision {self.revision}'
stat, output = subprocess.getstatusoutput(cmd)
@@ -91,5 +96,58 @@ class DownloadCMDTest(unittest.TestCase):
self.assertEqual(stat, 0)
class DownloadCMDTokenTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.api = HubApi()
# Create private model repo
self.model_id = '%s/%s' % (TEST_MODEL_ORG, 'test_model_with_token')
self.api.create_repo(
repo_id=self.model_id,
repo_type='model',
visibility='private',
license=Licenses.APACHE_V2,
chinese_name=TEST_MODEL_CHINESE_NAME,
exist_ok=True,
)
# Create private dataset repo
self.dataset_id = '%s/%s' % (TEST_MODEL_ORG, 'test_dataset_with_token')
self.api.create_repo(
repo_id=self.dataset_id,
repo_type='dataset',
visibility='private',
license=Licenses.APACHE_V2,
exist_ok=True,
)
def tearDown(self):
try:
self.api.delete_model(model_id=self.model_id)
except Exception as e:
logger.warning(f'Error deleting model {self.model_id}: {e}')
try:
self.api.delete_dataset(dataset_id=self.dataset_id)
except Exception as e:
logger.warning(f'Error deleting dataset {self.dataset_id}: {e}')
super().tearDown()
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_download_model_with_token(self):
cmd = f'python -m modelscope.cli.cli download --model {self.model_id} --token {TEST_ACCESS_TOKEN1}'
stat, output = subprocess.getstatusoutput(cmd)
print(output)
self.assertEqual(stat, 0)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_download_dataset_with_token(self):
cmd = f'python -m modelscope.cli.cli download --dataset {self.dataset_id} --token {TEST_ACCESS_TOKEN1}'
stat, output = subprocess.getstatusoutput(cmd)
print(output)
self.assertEqual(stat, 0)
if __name__ == '__main__':
unittest.main()

View File

@@ -15,7 +15,7 @@ class MCPApiTest(unittest.TestCase):
self.api = MCPApi()
self.api.login(TEST_ACCESS_TOKEN1)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_list_mcp_servers(self):
"""Test list_mcp_servers functionality and validation."""
result = self.api.list_mcp_servers(total_count=5)
@@ -31,7 +31,7 @@ class MCPApiTest(unittest.TestCase):
for field in ['name', 'id', 'description']:
self.assertIn(field, server)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_list_operational_mcp_servers(self):
"""Test list_operational_mcp_servers functionality."""
result = self.api.list_operational_mcp_servers()
@@ -53,7 +53,7 @@ class MCPApiTest(unittest.TestCase):
self.assertIn('url', first_config)
self.assertTrue(first_config['url'].startswith('https://'))
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_get_mcp_server(self):
"""Test get_mcp_server functionality and validation."""
result = self.api.get_mcp_server('@modelcontextprotocol/fetch')