mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 13:15:06 +02:00
Feat/collections (#1656)
This commit is contained in:
@@ -13,6 +13,7 @@ from modelscope.cli.pipeline import PipelineCMD
|
||||
from modelscope.cli.plugins import PluginsCMD
|
||||
from modelscope.cli.scancache import ScanCacheCMD
|
||||
from modelscope.cli.server import ServerCMD
|
||||
from modelscope.cli.skills import SkillsCMD
|
||||
from modelscope.cli.upload import UploadCMD
|
||||
from modelscope.hub.constants import MODELSCOPE_ASCII
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -36,6 +37,7 @@ def run_cmd():
|
||||
|
||||
CreateCMD.define_args(subparsers)
|
||||
DownloadCMD.define_args(subparsers)
|
||||
SkillsCMD.define_args(subparsers)
|
||||
UploadCMD.define_args(subparsers)
|
||||
ClearCacheCMD.define_args(subparsers)
|
||||
PluginsCMD.define_args(subparsers)
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import logging
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.cli.utils import concurrent_download
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.constants import DEFAULT_MAX_WORKERS
|
||||
from modelscope.hub.constants import DEFAULT_MAX_WORKERS, DEFAULT_SKILLS_DIR
|
||||
from modelscope.hub.file_download import (dataset_file_download,
|
||||
model_file_download)
|
||||
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
|
||||
snapshot_download)
|
||||
from modelscope.hub.utils.utils import convert_patterns
|
||||
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(log_level=logging.WARNING)
|
||||
|
||||
|
||||
def subparser_func(args):
|
||||
@@ -41,6 +45,11 @@ class DownloadCMD(CLICommand):
|
||||
type=str,
|
||||
help='The id of the dataset to be downloaded. For download, '
|
||||
'the id of either a model or dataset must be provided.')
|
||||
group.add_argument(
|
||||
'--collection',
|
||||
type=str,
|
||||
default=None,
|
||||
help='The ID of the collection to download (skills only)')
|
||||
parser.add_argument(
|
||||
'repo_id',
|
||||
type=str,
|
||||
@@ -122,8 +131,8 @@ class DownloadCMD(CLICommand):
|
||||
else:
|
||||
raise Exception('Not support repo-type: %s'
|
||||
% self.args.repo_type)
|
||||
if not self.args.model and not self.args.dataset:
|
||||
raise Exception('Model or dataset must be set.')
|
||||
if not self.args.model and not self.args.dataset and not self.args.collection:
|
||||
raise Exception('Model, dataset, or collection must be set.')
|
||||
cookies = None
|
||||
if self.args.token is not None:
|
||||
api = HubApi()
|
||||
@@ -191,5 +200,54 @@ class DownloadCMD(CLICommand):
|
||||
print(
|
||||
f'\nSuccessfully Downloaded from dataset {self.args.dataset}.\n'
|
||||
)
|
||||
elif self.args.collection:
|
||||
api = HubApi(token=self.args.token)
|
||||
local_dir = self.args.local_dir or DEFAULT_SKILLS_DIR
|
||||
data = api.get_collection(self.args.collection, repo_type='skill')
|
||||
elements = data.get('CollectionElements',
|
||||
{}).get('CollectionElementVoList', [])
|
||||
|
||||
logger.info(
|
||||
f'Collection {self.args.collection} has {len(elements)} elements.'
|
||||
)
|
||||
|
||||
if not elements:
|
||||
print(f'No skill elements found in collection: '
|
||||
f'{self.args.collection}')
|
||||
return
|
||||
|
||||
# Validate elements have required fields
|
||||
valid_elements = []
|
||||
for elem in elements:
|
||||
if not elem.get('ElementPath') or not elem.get('ElementName'):
|
||||
logger.warning('Skipping malformed collection element: %s',
|
||||
elem)
|
||||
continue
|
||||
valid_elements.append(elem)
|
||||
|
||||
if not valid_elements:
|
||||
print(f'No valid skill elements found in collection: '
|
||||
f'{self.args.collection}')
|
||||
return
|
||||
|
||||
print(f'Found {len(valid_elements)} skill(s) in collection, '
|
||||
f'downloading...')
|
||||
|
||||
def _download_one_skill(element):
|
||||
element_path = element['ElementPath']
|
||||
element_name = element['ElementName']
|
||||
skill_id = f'{element_path}/{element_name}'
|
||||
try:
|
||||
skill_dir = api.download_skill(
|
||||
skill_id=skill_id, local_dir=local_dir)
|
||||
return (skill_id, skill_dir, None)
|
||||
except Exception as e:
|
||||
return (skill_id, None, str(e))
|
||||
|
||||
concurrent_download(
|
||||
_download_one_skill,
|
||||
valid_elements,
|
||||
max_workers=self.args.max_workers,
|
||||
item_name='skill')
|
||||
else:
|
||||
pass # noop
|
||||
|
||||
100
modelscope/cli/skills.py
Normal file
100
modelscope/cli/skills.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import logging
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.cli.utils import concurrent_download
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.constants import DEFAULT_SKILLS_DIR
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(log_level=logging.WARNING)
|
||||
|
||||
|
||||
def subparser_func(args):
|
||||
"""Function which will be called for a specific sub parser."""
|
||||
return SkillsCMD(args)
|
||||
|
||||
|
||||
class SkillsCMD(CLICommand):
|
||||
"""Command for managing skills."""
|
||||
|
||||
name = 'skills'
|
||||
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
@staticmethod
|
||||
def define_args(parsers: ArgumentParser):
|
||||
"""Define args for skills command."""
|
||||
parser = parsers.add_parser(SkillsCMD.name)
|
||||
subparsers = parser.add_subparsers(
|
||||
dest='skills_action', help='skills subcommands')
|
||||
|
||||
# 'add' subcommand
|
||||
add_parser = subparsers.add_parser(
|
||||
'add', help='Download and install skills')
|
||||
add_parser.add_argument(
|
||||
'skill_ids',
|
||||
type=str,
|
||||
nargs='+',
|
||||
help='Skill IDs to download, in format: <path>/<name>')
|
||||
add_parser.add_argument(
|
||||
'--token',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Access token for authentication')
|
||||
add_parser.add_argument(
|
||||
'--local_dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Target directory for skills (default: ~/.agents/skills)')
|
||||
add_parser.add_argument(
|
||||
'--max-workers',
|
||||
type=int,
|
||||
default=8,
|
||||
help='Maximum concurrent downloads (default: 8)')
|
||||
add_parser.set_defaults(func=subparser_func)
|
||||
|
||||
def execute(self):
|
||||
if not hasattr(self.args,
|
||||
'skills_action') or not self.args.skills_action:
|
||||
print('Usage: modelscope skills add <skill_id1> <skill_id2> ...')
|
||||
return
|
||||
|
||||
if not hasattr(self.args, 'skill_ids') or not self.args.skill_ids:
|
||||
print('No skill IDs provided. Usage: modelscope skills add '
|
||||
'<skill_id1> <skill_id2> ...')
|
||||
return
|
||||
|
||||
api = HubApi(token=self.args.token)
|
||||
local_dir = self.args.local_dir or DEFAULT_SKILLS_DIR
|
||||
|
||||
skill_ids = self.args.skill_ids
|
||||
print(f'Downloading {len(skill_ids)} skill(s)...')
|
||||
|
||||
if len(skill_ids) == 1:
|
||||
# Single skill download
|
||||
try:
|
||||
skill_dir = api.download_skill(
|
||||
skill_id=skill_ids[0], local_dir=local_dir)
|
||||
print(f'Skill downloaded to: {skill_dir}')
|
||||
except Exception as e:
|
||||
print(f'Failed to download skill {skill_ids[0]}: {e}')
|
||||
sys.exit(1)
|
||||
else:
|
||||
# Multiple skills - concurrent download
|
||||
def _download_one(skill_id):
|
||||
try:
|
||||
skill_dir = api.download_skill(
|
||||
skill_id=skill_id, local_dir=local_dir)
|
||||
return (skill_id, skill_dir, None)
|
||||
except Exception as e:
|
||||
return (skill_id, None, str(e))
|
||||
|
||||
concurrent_download(
|
||||
_download_one,
|
||||
skill_ids,
|
||||
max_workers=self.args.max_workers,
|
||||
item_name='skill')
|
||||
41
modelscope/cli/utils.py
Normal file
41
modelscope/cli/utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
|
||||
def concurrent_download(download_fn, items, max_workers=8, item_name='item'):
|
||||
"""Download multiple items concurrently with progress reporting.
|
||||
|
||||
Args:
|
||||
download_fn: Callable that takes an item and returns
|
||||
(identifier, result_path, error_string_or_None).
|
||||
items: List of items to download.
|
||||
max_workers (int): Maximum concurrent workers.
|
||||
item_name (str): Display name for the item type.
|
||||
|
||||
Returns:
|
||||
tuple: (succeeded_list, failed_list).
|
||||
"""
|
||||
succeeded = []
|
||||
failed = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = {executor.submit(download_fn, item): item for item in items}
|
||||
for future in as_completed(futures):
|
||||
identifier, result_path, error = future.result()
|
||||
if error:
|
||||
failed.append((identifier, error))
|
||||
print(f'Failed to download {item_name} {identifier}: {error}')
|
||||
else:
|
||||
succeeded.append((identifier, result_path))
|
||||
print(f'Downloaded {item_name} {identifier} -> {result_path}')
|
||||
|
||||
print(f'\nDownload complete: {len(succeeded)} succeeded, '
|
||||
f'{len(failed)} failed')
|
||||
if failed:
|
||||
print(f'Failed {item_name}s:')
|
||||
for identifier, error in failed:
|
||||
print(f' {identifier}: {error}')
|
||||
sys.exit(1)
|
||||
|
||||
return succeeded, failed
|
||||
@@ -13,6 +13,7 @@ import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
import warnings
|
||||
import zipfile
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from http.cookiejar import CookieJar
|
||||
@@ -3043,6 +3044,111 @@ class HubApi:
|
||||
|
||||
return resp
|
||||
|
||||
# ============= Collection API =============
|
||||
def get_collection(self,
|
||||
collection_id: str,
|
||||
repo_type: str = 'skill',
|
||||
page_number: int = 1,
|
||||
page_size: int = 50) -> dict:
|
||||
"""Get collection details and its elements.
|
||||
|
||||
Args:
|
||||
collection_id (str): The collection ID (Fid).
|
||||
repo_type (str): Element type filter, only 'skill' is supported currently.
|
||||
page_number (int): Page number for pagination.
|
||||
page_size (int): Page size for pagination.
|
||||
|
||||
Returns:
|
||||
dict: Collection details including elements.
|
||||
|
||||
Raises:
|
||||
ValueError: If repo_type is not 'skill'.
|
||||
RequestError: If the API request fails.
|
||||
"""
|
||||
if repo_type != 'skill':
|
||||
raise ValueError(
|
||||
f'repo_type={repo_type} is not supported, '
|
||||
'only "skill" is currently supported.')
|
||||
cookies = self.get_cookies()
|
||||
path = f'{self.endpoint}/api/v1/collections'
|
||||
params = {
|
||||
'Fid': collection_id,
|
||||
'ElementType': repo_type,
|
||||
'PageNumber': page_number,
|
||||
'PageSize': page_size,
|
||||
}
|
||||
r = self.session.get(path, params=params, cookies=cookies,
|
||||
headers=self.builder_headers(self.headers))
|
||||
raise_for_http_status(r)
|
||||
d = r.json()
|
||||
raise_on_error(d)
|
||||
return d[API_RESPONSE_FIELD_DATA]
|
||||
|
||||
def download_skill(self, skill_id: str,
|
||||
local_dir: Optional[str] = None) -> str:
|
||||
"""Download a single skill archive and extract it.
|
||||
|
||||
Args:
|
||||
skill_id (str): The skill identifier in format '<path>/<name>'.
|
||||
local_dir (Optional[str]): Target directory for extraction.
|
||||
Defaults to current directory.
|
||||
|
||||
Returns:
|
||||
str: Path to the extracted skill directory.
|
||||
|
||||
Raises:
|
||||
ValueError: If skill_id format is invalid.
|
||||
RequestError: If the download request fails.
|
||||
"""
|
||||
element_path, element_name = RepoUtils.validate_repo_id(skill_id)
|
||||
|
||||
cookies = self.get_cookies()
|
||||
url = f'{self.endpoint}/api/v1/skills/{element_path}/{element_name}/archive/zip/master'
|
||||
|
||||
if local_dir is None:
|
||||
local_dir = os.getcwd()
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
|
||||
# Build skill directory name: <element_path>__<element_name>__master
|
||||
skill_dir_name = f'{element_path}__{element_name}__master'
|
||||
skill_dir = os.path.join(local_dir, skill_dir_name)
|
||||
|
||||
r = self.session.get(url, stream=True, cookies=cookies,
|
||||
headers=self.builder_headers(self.headers))
|
||||
raise_for_http_status(r)
|
||||
|
||||
# Save to temp zip file then extract
|
||||
zip_path = os.path.join(local_dir, f'{element_name}.zip')
|
||||
try:
|
||||
with open(zip_path, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
|
||||
# Clean existing directory to avoid corrupted state
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
os.makedirs(skill_dir, exist_ok=True)
|
||||
with zipfile.ZipFile(zip_path, 'r') as zf:
|
||||
zf.extractall(skill_dir)
|
||||
|
||||
# Flatten if zip contains a single top-level directory
|
||||
entries = os.listdir(skill_dir)
|
||||
if len(entries) == 1:
|
||||
nested_dir = os.path.join(skill_dir, entries[0])
|
||||
if os.path.isdir(nested_dir):
|
||||
for item in os.listdir(nested_dir):
|
||||
shutil.move(
|
||||
os.path.join(nested_dir, item),
|
||||
os.path.join(skill_dir, item))
|
||||
os.rmdir(nested_dir)
|
||||
finally:
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
|
||||
logger.info(f'Skill {element_path}/{element_name} downloaded to {skill_dir}')
|
||||
return skill_dir
|
||||
|
||||
|
||||
class ModelScopeConfig:
|
||||
path_credential = expanduser(MODELSCOPE_CREDENTIALS_PATH)
|
||||
|
||||
@@ -41,6 +41,7 @@ TEMPORARY_FOLDER_NAME = '._____temp'
|
||||
DEFAULT_MAX_WORKERS = int(
|
||||
os.getenv('DEFAULT_MAX_WORKERS', min(8,
|
||||
os.cpu_count() + 4)))
|
||||
DEFAULT_SKILLS_DIR = os.path.join(os.path.expanduser('~'), '.agents', 'skills')
|
||||
|
||||
# Upload check env
|
||||
UPLOAD_MAX_FILE_SIZE = int(
|
||||
|
||||
@@ -219,6 +219,28 @@ class RepoUtils:
|
||||
return pattern + '*'
|
||||
return pattern
|
||||
|
||||
@staticmethod
|
||||
def validate_repo_id(repo_id: str) -> tuple:
|
||||
"""Validate and parse a repo_id in '<owner>/<name>' format.
|
||||
|
||||
Args:
|
||||
repo_id (str): The repo identifier, e.g. 'MiniMax-AI/minimax-pdf'.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple of (owner, name).
|
||||
|
||||
Raises:
|
||||
ValueError: If repo_id format is invalid.
|
||||
"""
|
||||
if not repo_id or '/' not in repo_id:
|
||||
raise ValueError(
|
||||
f'Invalid repo_id: {repo_id}, expected format: <owner>/<name>')
|
||||
parts = repo_id.split('/', 1)
|
||||
if len(parts) != 2 or not parts[0].strip() or not parts[1].strip():
|
||||
raise ValueError(
|
||||
f'Invalid repo_id: {repo_id}, expected format: <owner>/<name>')
|
||||
return parts[0].strip(), parts[1].strip()
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommitInfo:
|
||||
|
||||
@@ -25,6 +25,7 @@ Homepage = "https://github.com/modelscope/modelscope"
|
||||
|
||||
[project.scripts]
|
||||
modelscope = "modelscope.cli.cli:run_cmd"
|
||||
ms = "modelscope.cli.cli:run_cmd"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=69", "wheel"]
|
||||
|
||||
Reference in New Issue
Block a user