mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
Merge remote-tracking branch 'origin' into feat/fill_mask
Conflicts: modelscope/utils/constant.py
This commit is contained in:
0
modelscope/hub/__init__.py
Normal file
0
modelscope/hub/__init__.py
Normal file
265
modelscope/hub/api.py
Normal file
265
modelscope/hub/api.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import imp
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
from http.cookiejar import CookieJar
|
||||
from os.path import expanduser
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import requests
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .constants import LOGGER_NAME
|
||||
from .errors import NotExistError, is_ok, raise_on_error
|
||||
from .utils.utils import get_endpoint, model_id_to_group_owner_name
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class HubApi:
    """Client for the ModelScope hub HTTP API (login, model CRUD, file listing)."""

    def __init__(self, endpoint=None):
        """
        Args:
            endpoint(`str`, *optional*):
                The hub server address; defaults to the value of `get_endpoint()`.
        """
        self.endpoint = endpoint if endpoint is not None else get_endpoint()

    def login(
        self,
        user_name: str,
        password: str,
    ) -> Tuple[str, CookieJar]:
        """
        Login with username and password

        Args:
            user_name(`str`): user name on modelscope
            password(`str`): password

        Returns:
            A tuple of (access token, cookies):
            cookies: to authenticate yourself to ModelScope open-api
            gitlab token: to access private repos

        <Tip>
        You only have to login once within 30 days.
        </Tip>

        TODO: handle cookies expire
        """
        path = f'{self.endpoint}/api/v1/login'
        r = requests.post(
            path, json={
                'username': user_name,
                'password': password
            })
        r.raise_for_status()
        d = r.json()
        raise_on_error(d)

        token = d['Data']['AccessToken']
        cookies = r.cookies

        # save token and cookie so later calls are authenticated automatically
        ModelScopeConfig.save_token(token)
        ModelScopeConfig.save_cookies(cookies)
        ModelScopeConfig.write_to_git_credential(user_name, password)

        return token, cookies

    def create_model(self, model_id: str, chinese_name: str, visibility: int,
                     license: str) -> str:
        """
        Create model repo at ModelScopeHub

        Args:
            model_id(`str`): The model id
            chinese_name(`str`): chinese name of the model
            visibility(`int`): visibility of the model(1-private, 3-internal, 5-public)
            license(`str`): license of the model, candidates can be found at: TBA

        Returns:
            name of the model created

        Raises:
            ValueError: if not logged in (no saved cookies).

        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')

        path = f'{self.endpoint}/api/v1/models'
        owner_or_group, name = model_id_to_group_owner_name(model_id)
        r = requests.post(
            path,
            json={
                'Path': owner_or_group,
                'Name': name,
                'ChineseName': chinese_name,
                'Visibility': visibility,
                'License': license
            },
            cookies=cookies)
        r.raise_for_status()
        raise_on_error(r.json())
        d = r.json()
        return d['Data']['Name']

    def delete_model(self, model_id: str):
        """
        Delete a model repo at ModelScopeHub.

        Args:
            model_id (str): The model id.

        Raises:
            ValueError: if not logged in (no saved cookies).

        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        cookies = ModelScopeConfig.get_cookies()
        # Consistent with create_model: deletion requires authentication.
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')
        path = f'{self.endpoint}/api/v1/models/{model_id}'

        r = requests.delete(path, cookies=cookies)
        r.raise_for_status()
        raise_on_error(r.json())

    def get_model_url(self, model_id):
        """Return the git clone url for `model_id` on this endpoint."""
        return f'{self.endpoint}/api/v1/models/{model_id}.git'

    def get_model(
        self,
        model_id: str,
        revision: str = 'master',
    ) -> str:
        """
        Get model information at modelscope_hub

        Args:
            model_id(`str`): The model id.
            revision(`str`): revision of model
        Returns:
            The model details information.
        Raises:
            NotExistError: If the model is not exist, will throw NotExistError
        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        cookies = ModelScopeConfig.get_cookies()
        owner_or_group, name = model_id_to_group_owner_name(model_id)
        # BUGFIX: the query string previously lacked the 'Revision=' key,
        # so the requested revision was silently ignored by the server.
        path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}'

        r = requests.get(path, cookies=cookies)
        if r.status_code == 200:
            if is_ok(r.json()):
                return r.json()['Data']
            else:
                raise NotExistError(r.json()['Message'])
        else:
            r.raise_for_status()

    def get_model_branches_and_tags(
        self,
        model_id: str,
    ) -> Tuple[List[str], List[str]]:
        """Return ([branch names], [tag names]) for the given model repo."""
        cookies = ModelScopeConfig.get_cookies()

        path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
        r = requests.get(path, cookies=cookies)
        r.raise_for_status()
        d = r.json()
        raise_on_error(d)
        info = d['Data']
        branches = [x['Revision'] for x in info['RevisionMap']['Branches']
                    ] if info['RevisionMap']['Branches'] else []
        tags = [x['Revision'] for x in info['RevisionMap']['Tags']
                ] if info['RevisionMap']['Tags'] else []
        return branches, tags

    def get_model_files(
            self,
            model_id: str,
            revision: Optional[str] = 'master',
            root: Optional[str] = None,
            recursive: Optional[bool] = False,
            use_cookies: Union[bool, CookieJar] = False) -> List[dict]:
        """
        List files of a model repo.

        Args:
            model_id(`str`): The model id.
            revision(`str`, *optional*): branch, tag or commit, default `master`.
            root(`str`, *optional*): sub-directory to list, default repo root.
            recursive(`bool`, *optional*): whether to list files recursively.
            use_cookies(`bool` or `CookieJar`): pass a CookieJar to use it
                directly; pass True to load saved cookies (raises if absent);
                False sends the request unauthenticated.

        Returns:
            The file info dicts, excluding `.gitignore`/`.gitattributes`.
        """
        cookies = None
        if isinstance(use_cookies, CookieJar):
            cookies = use_cookies
        elif use_cookies:
            cookies = ModelScopeConfig.get_cookies()
            if cookies is None:
                raise ValueError('Token does not exist, please login first.')

        path = f'{self.endpoint}/api/v1/models/{model_id}/repo/files?Revision={revision}&Recursive={recursive}'
        if root is not None:
            path = path + f'&Root={root}'

        r = requests.get(path, cookies=cookies)

        r.raise_for_status()
        d = r.json()
        raise_on_error(d)

        files = []
        for file in d['Data']['Files']:
            # git bookkeeping files are never useful to callers
            if file['Name'] == '.gitignore' or file['Name'] == '.gitattributes':
                continue

            files.append(file)
        return files
|
||||
|
||||
|
||||
class ModelScopeConfig:
    """Persists hub credentials (cookies + access token) under the user home."""

    # directory holding the 'cookies' and 'token' credential files
    path_credential = expanduser('~/.modelscope/credentials')
    os.makedirs(path_credential, exist_ok=True)

    @classmethod
    def save_cookies(cls, cookies: CookieJar):
        """Pickle the authentication cookies into the credential directory."""
        with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f:
            pickle.dump(cookies, f)

    @classmethod
    def get_cookies(cls):
        """Load saved cookies; return None (and warn) if no cookie file exists."""
        try:
            with open(os.path.join(cls.path_credential, 'cookies'), 'rb') as f:
                return pickle.load(f)
        except FileNotFoundError:
            # logger.warn is deprecated; use logger.warning
            logger.warning("Auth token does not exist, you'll get authentication \
            error when downloading private model files. Please login first"
                           )

    @classmethod
    def save_token(cls, token: str):
        """Write the access token into the credential directory."""
        with open(os.path.join(cls.path_credential, 'token'), 'w+') as f:
            f.write(token)

    @classmethod
    def get_token(cls) -> Optional[str]:
        """
        Get token or None if not existent.

        Returns:
            `str` or `None`: The token, `None` if it doesn't exist.

        """
        token = None
        try:
            with open(os.path.join(cls.path_credential, 'token'), 'r') as f:
                token = f.read()
        except FileNotFoundError:
            pass
        return token

    @staticmethod
    def write_to_git_credential(username: str, password: str):
        """Feed the credentials to `git credential-store` so git clones work.

        Args:
            username(`str`): user name (lower-cased before storing).
            password(`str`): password.
        """
        with subprocess.Popen(
                'git credential-store store'.split(),
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
        ) as process:
            input_username = f'username={username.lower()}'
            input_password = f'password={password}'

            process.stdin.write(
                f'url={get_endpoint()}\n{input_username}\n{input_password}\n\n'
                .encode('utf-8'))
            process.stdin.flush()
|
||||
8
modelscope/hub/constants.py
Normal file
8
modelscope/hub/constants.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# URL scheme used when composing hub / gitlab addresses.
MODELSCOPE_URL_SCHEME = 'http://'
# Default hub API and gitlab domains (test-environment addresses).
DEFAULT_MODELSCOPE_DOMAIN = '101.201.119.157:32330'
DEFAULT_MODELSCOPE_GITLAB_DOMAIN = '101.201.119.157:31102'

# Group assumed when a model id carries no explicit '{owner}/' prefix.
DEFAULT_MODELSCOPE_GROUP = 'damo'
MODEL_ID_SEPARATOR = '/'

# Name under which hub components log.
LOGGER_NAME = 'ModelScopeHub'
|
||||
30
modelscope/hub/errors.py
Normal file
30
modelscope/hub/errors.py
Normal file
@@ -0,0 +1,30 @@
|
||||
class NotExistError(Exception):
    """Raised when a requested model / revision / file does not exist."""
    pass


class RequestError(Exception):
    """Raised when the hub server reports a non-success response."""
    pass


def is_ok(rsp):
    """ Check the request is ok

    Args:
        rsp (_type_): The request response body
        Failed: {'Code': 10010101004, 'Message': 'get model info failed, err: unauthorized permission',
                 'RequestId': '', 'Success': False}
        Success: {'Code': 200, 'Data': {}, 'Message': 'success', 'RequestId': '', 'Success': True}
    """
    return rsp['Code'] == 200 and rsp['Success']


def raise_on_error(rsp):
    """If response error, raise exception

    Args:
        rsp (_type_): The server response

    Raises:
        RequestError: when the response carries a non-success status.
    """
    # De Morgan of the original success check: fail fast on error.
    if rsp['Code'] != 200 or not rsp['Success']:
        raise RequestError(rsp['Message'])
    return True
|
||||
254
modelscope/hub/file_download.py
Normal file
254
modelscope/hub/file_download.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import copy
|
||||
import fnmatch
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from functools import partial
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import json
|
||||
import requests
|
||||
from filelock import FileLock
|
||||
from requests.exceptions import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope import __version__
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .api import HubApi, ModelScopeConfig
|
||||
from .constants import (DEFAULT_MODELSCOPE_GROUP, LOGGER_NAME,
|
||||
MODEL_ID_SEPARATOR)
|
||||
from .errors import NotExistError, RequestError, raise_on_error
|
||||
from .utils.caching import ModelFileSystemCache
|
||||
from .utils.utils import (get_cache_dir, get_endpoint,
|
||||
model_id_to_group_owner_name)
|
||||
|
||||
SESSION_ID = uuid4().hex
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def model_file_download(
    model_id: str,
    file_path: str,
    revision: Optional[str] = 'master',
    cache_dir: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    local_files_only: Optional[bool] = False,
) -> Optional[str]:  # pragma: no cover
    """
    Download from a given URL and cache it if it's not already present in the
    local cache.

    Given a URL, this function looks for the corresponding file in the local
    cache. If it's not there, download it. Then return the path to the cached
    file.

    Args:
        model_id (`str`):
            The model to whom the file to be downloaded belongs.
        file_path(`str`):
            Path of the file to be downloaded, relative to the root of model repo
        revision(`str`, *optional*):
            revision of the model file to be downloaded.
            Can be any of a branch, tag or commit hash, default to `master`
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        user_agent (`dict`, `str`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists. if `False`, download the file anyway
            even it exists.

    Returns:
        Local path (string) of file or if networking is off, last version of
        file cached on disk.

    Raises:
        ValueError: `local_files_only=True` but nothing is cached.
        NotExistError: the file does not exist at the requested revision.
    """
    if cache_dir is None:
        cache_dir = get_cache_dir()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    group_or_owner, name = model_id_to_group_owner_name(model_id)
    cache = ModelFileSystemCache(cache_dir, group_or_owner, name)

    # Offline mode: serve from cache or fail — never touch the network.
    if local_files_only:
        cached_file_path = cache.get_file_by_path(file_path)
        if cached_file_path is None:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable model look-ups and downloads'
                " online, set 'local_files_only' to False.")
        logger.warning(
            "File exists in local cache, but we're not sure it's up to date")
        return cached_file_path

    api = HubApi()
    headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
    branches, tags = api.get_model_branches_and_tags(model_id)
    file_to_download_info = None
    is_commit_id = False
    if revision in branches or revision in tags:
        # Branch/tag revision: a branch head may have moved, so fetch the
        # current file list and compare against the cache entry.
        model_files = api.get_model_files(
            model_id=model_id,
            revision=revision,
            recursive=True,
        )

        for model_file in model_files:
            if model_file['Type'] == 'tree':
                continue

            if model_file['Path'] == file_path:
                model_file['Branch'] = revision
                if cache.exists(model_file):
                    return cache.get_file_by_info(model_file)
                else:
                    file_to_download_info = model_file

        if file_to_download_info is None:
            raise NotExistError('The file path: %s not exist in: %s' %
                                (file_path, model_id))
    else:
        # Commit-id revision: content is immutable, the cache can be trusted.
        cached_file_path = cache.get_file_by_path_and_commit_id(
            file_path, revision)
        if cached_file_path is not None:
            logger.info('The specified file is in cache, skip downloading!')
            return cached_file_path
        is_commit_id = True

    # Not cached (or stale) — download again.
    # TODO: skip using JWT for authorization, use cookie instead
    cookies = ModelScopeConfig.get_cookies()
    url_to_download = get_file_download_url(model_id, file_path, revision)
    file_to_download_info = {
        'Path': file_path,
        'Revision':
        revision if is_commit_id else file_to_download_info['Revision']
    }
    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache.get_root_location() + '.lock'

    with FileLock(lock_path):
        temp_file_name = next(tempfile._get_candidate_names())
        http_get_file(
            url_to_download,
            cache_dir,
            temp_file_name,
            headers=headers,
            cookies=None if cookies is None else cookies.get_dict())
        return cache.put_file(file_to_download_info,
                              os.path.join(cache_dir, temp_file_name))
|
||||
|
||||
|
||||
def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
    """Formats a user-agent string with basic info about a request.

    Args:
        user_agent (`str`, `dict`, *optional*):
            The user agent info in the form of a dictionary or a single string.

    Returns:
        The formatted user-agent string.
    """
    # Caller-supplied values win; otherwise fall back to library defaults.
    if isinstance(user_agent, str):
        return user_agent
    if isinstance(user_agent, dict):
        return '; '.join(f'{k}/{v}' for k, v in user_agent.items())
    return f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}'
|
||||
|
||||
|
||||
def get_file_download_url(model_id: str, file_path: str, revision: str):
    """
    Format file download url according to `model_id`, `revision` and `file_path`.
    e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
    the resulted download url is: https://maas.co/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
    """
    template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
    return template.format(
        endpoint=get_endpoint(),
        model_id=model_id,
        revision=revision,
        file_path=file_path,
    )
|
||||
|
||||
|
||||
def http_get_file(
    url: str,
    local_dir: str,
    file_name: str,
    cookies: Dict[str, str],
    headers: Optional[Dict[str, str]] = None,
):
    """
    Download remote file. Do not gobble up errors.
    This method is only used by snapshot_download, since the behavior is quite different with single file download
    TODO: consolidate with model_file_download() to avoid duplicate code

    Args:
        url(`str`):
            actual download url of the file
        local_dir(`str`):
            local directory where the downloaded file stores
        file_name(`str`):
            name of the file stored in `local_dir`
        cookies(`Dict[str, str]`):
            cookies used to authentication the user, which is used for downloading private repos
        headers(`Optional[Dict[str, str]] = None`):
            http headers to carry necessary info when requesting the remote file
    """
    # Download to a temp file first, then atomically move into place, so a
    # partially-downloaded file never appears under its final name.
    make_temp_file = partial(
        tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)

    with make_temp_file() as temp_file:
        logger.info('downloading %s to %s', url, temp_file.name)
        # deep-copy: the dict may be shared by concurrent requests
        headers = copy.deepcopy(headers)

        r = requests.get(url, stream=True, headers=headers, cookies=cookies)
        r.raise_for_status()

        content_length = r.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None

        progress = tqdm(
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
            total=total,
            initial=0,
            desc='Downloading',
        )
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                progress.update(len(chunk))
                temp_file.write(chunk)
        progress.close()

    logger.info('storing %s in cache at %s', url, local_dir)
    os.replace(temp_file.name, os.path.join(local_dir, file_name))
|
||||
82
modelscope/hub/git.py
Normal file
82
modelscope/hub/git.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from threading import local
|
||||
from tkinter.messagebox import NO
|
||||
from typing import Union
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .constants import LOGGER_NAME
|
||||
from .utils._subprocess import run_subprocess
|
||||
|
||||
# BUGFIX: get_logger was assigned without being called, so `logger` was the
# factory function itself and `logger.error(...)` in git_push would raise.
logger = get_logger()
|
||||
|
||||
|
||||
def git_clone(
    local_dir: str,
    repo_url: str,
):
    """Clone `repo_url` into `local_dir` via the git CLI."""
    # TODO: use "git clone" or "git lfs clone" according to git version
    # TODO: print stderr when subprocess fails
    argv = ('git clone ' + repo_url).split()
    run_subprocess(argv, local_dir, True)
|
||||
|
||||
|
||||
def git_checkout(
    local_dir: str,
    revision: str,
):
    """Checkout `revision` (branch, tag or commit hash) in the repo at `local_dir`.

    Args:
        local_dir(`str`): local model repo directory.
        revision(`str`): branch, tag or commit hash to check out.
    """
    # fixed parameter typo: 'revsion' -> 'revision'
    run_subprocess(f'git checkout {revision}'.split(), local_dir)
|
||||
|
||||
|
||||
def git_add(local_dir: str, ):
    """Stage every change in the repo at `local_dir` (`git add .`)."""
    run_subprocess(['git', 'add', '.'], local_dir, True)
|
||||
|
||||
|
||||
def git_commit(local_dir: str, commit_message: str):
    """Commit staged changes in `local_dir` with the given message."""
    # message is passed as its own argv element so spaces survive intact
    argv = ['git', 'commit', '-v', '-m', commit_message]
    run_subprocess(argv, local_dir, True)
|
||||
|
||||
|
||||
def git_push(local_dir: str, branch: str):
    """Push local commits on `branch` to origin.

    Refuses to push (logs an error and returns) when the currently
    checked-out branch differs from `branch`.
    """
    if git_current_branch(local_dir) != branch:
        logger.error(
            "You're trying to push to a different branch, please double check")
        return

    run_subprocess(
        f'git push origin {branch}'.split(),
        local_dir,
        True,
    )
|
||||
|
||||
|
||||
def git_current_branch(local_dir: str) -> Union[str, None]:
    """
    Get current branch name

    Args:
        local_dir(`str`): local model repo directory

    Returns
        branch name you're currently on
    """
    # The original wrapped this in `try/except Exception as e: raise e`,
    # which is a no-op — any subprocess error propagates either way.
    proc = run_subprocess(
        'git rev-parse --abbrev-ref HEAD'.split(),
        local_dir,
        True,
    )
    return str(proc.stdout).strip()
|
||||
173
modelscope/hub/repository.py
Normal file
173
modelscope/hub/repository.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .api import ModelScopeConfig
|
||||
from .constants import MODELSCOPE_URL_SCHEME
|
||||
from .git import git_add, git_checkout, git_clone, git_commit, git_push
|
||||
from .utils._subprocess import run_subprocess
|
||||
from .utils.utils import get_gitlab_domain
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class Repository:
    """A local working copy of a ModelScope hub model repository."""

    def __init__(
        self,
        local_dir: str,
        clone_from: Optional[str] = None,
        auth_token: Optional[str] = None,
        private: Optional[bool] = False,
        revision: Optional[str] = 'master',
    ):
        """
        Instantiate a Repository object by cloning the remote ModelScopeHub repo
        Args:
            local_dir(`str`):
                local directory to store the model files
            clone_from(`Optional[str] = None`):
                model id in ModelScope-hub from which git clone
                You should ignore this parameter when `local_dir` is already a git repo
            auth_token(`Optional[str]`):
                token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter
                as the token is already saved when you login the first time
            private(`Optional[bool]`):
                whether the model is private, default to False
            revision(`Optional[str]`):
                revision of the model you want to clone from. Can be any of a branch, tag or commit hash

        Raises:
            EnvironmentError: if git/git-lfs are not installed, or no token is available.
            ValueError: if `clone_from` is omitted and `local_dir` is not a git repo.
        """
        logger.info('Instantiating Repository object...')

        # Create local directory if not exist
        os.makedirs(local_dir, exist_ok=True)
        self.local_dir = os.path.join(os.getcwd(), local_dir)

        self.private = private

        # Check git and git-lfs installation
        self.check_git_versions()

        # Retrieve auth token; it is only meaningful for private repos.
        if not private and isinstance(auth_token, str):
            logger.warning(
                'cloning a public repo with a token, which will be ignored')
            self.token = None
        else:
            if isinstance(auth_token, str):
                self.token = auth_token
            else:
                self.token = ModelScopeConfig.get_token()

            if self.token is None:
                raise EnvironmentError(
                    'Token does not exist, the clone will fail for private repo.'
                    'Please login first.')

        # git clone
        if clone_from is not None:
            self.model_id = clone_from
            logger.info('cloning model repo to %s ...', self.local_dir)
            git_clone(self.local_dir, self.get_repo_url())
        else:
            if is_git_repo(self.local_dir):
                logger.debug('[Repository] is a valid git repo')
            else:
                raise ValueError(
                    'If not specifying `clone_from`, you need to pass Repository a'
                    ' valid git clone.')

        # git checkout
        if isinstance(revision, str) and revision != 'master':
            # BUGFIX: git_checkout takes the repo directory as its first
            # argument; it was previously called with the revision only.
            git_checkout(self.local_dir, revision)

    def push_to_hub(self,
                    commit_message: str,
                    revision: Optional[str] = 'master'):
        """
        Push changes changes to hub

        Args:
            commit_message(`str`):
                commit message describing the changes, it's mandatory
            revision(`Optional[str]`):
                remote branch you want to push to, default to `master`

        <Tip>
        The function complains when local and remote branch are different, please be careful
        </Tip>
        """
        git_add(self.local_dir)
        git_commit(self.local_dir, commit_message)

        logger.info('Pushing changes to repo...')
        git_push(self.local_dir, revision)

        # TODO: if git push fails, how to retry?

    def check_git_versions(self):
        """
        Checks that `git` and `git-lfs` can be run.

        Raises:
            `EnvironmentError`: if `git` or `git-lfs` are not installed.
        """
        try:
            git_version = run_subprocess('git --version'.split(),
                                         self.local_dir).stdout.strip()
        except FileNotFoundError:
            raise EnvironmentError(
                'Looks like you do not have git installed, please install.')

        try:
            lfs_version = run_subprocess('git-lfs --version'.split(),
                                         self.local_dir).stdout.strip()
        except FileNotFoundError:
            raise EnvironmentError(
                'Looks like you do not have git-lfs installed, please install.'
                ' You can install from https://git-lfs.github.com/.'
                ' Then run `git lfs install` (you only have to do this once).')
        logger.info(git_version + '\n' + lfs_version)

    def get_repo_url(self) -> str:
        """
        Get repo url to clone, according whether the repo is private or not
        """
        url = None

        if self.private:
            # embed the oauth token so git can authenticate non-interactively
            url = f'{MODELSCOPE_URL_SCHEME}oauth2:{self.token}@{get_gitlab_domain()}/{self.model_id}'
        else:
            url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{self.model_id}'

        if not url:
            raise ValueError(
                'Empty repo url, please check clone_from parameter')

        logger.debug('url to clone: %s', str(url))

        return url
|
||||
|
||||
|
||||
def is_git_repo(folder: Union[str, Path]) -> bool:
    """
    Check if the folder is the root or part of a git repository

    Args:
        folder (`str`):
            The folder in which to run the command.

    Returns:
        `bool`: `True` if the repository is part of a repository, `False`
        otherwise.
    """
    has_git_dir = os.path.exists(os.path.join(folder, '.git'))
    branch_probe = subprocess.run(
        ['git', 'branch'],
        cwd=folder,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    return has_git_dir and branch_probe.returncode == 0
|
||||
125
modelscope/hub/snapshot_download.py
Normal file
125
modelscope/hub/snapshot_download.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import os
|
||||
import tempfile
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .api import HubApi, ModelScopeConfig
|
||||
from .constants import DEFAULT_MODELSCOPE_GROUP, MODEL_ID_SEPARATOR
|
||||
from .errors import NotExistError, RequestError, raise_on_error
|
||||
from .file_download import (get_file_download_url, http_get_file,
|
||||
http_user_agent)
|
||||
from .utils.caching import ModelFileSystemCache
|
||||
from .utils.utils import get_cache_dir, model_id_to_group_owner_name
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def snapshot_download(model_id: str,
                      revision: Optional[str] = 'master',
                      cache_dir: Union[str, Path, None] = None,
                      user_agent: Optional[Union[Dict, str]] = None,
                      local_files_only: Optional[bool] = False,
                      private: Optional[bool] = False) -> str:
    """Download all files of a repo.
    Downloads a whole snapshot of a repo's files at the specified revision. This
    is useful when you want all files from a repo, because you don't know which
    ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder.

    An alternative would be to just clone a repo but this would require that the
    user always has git and git-lfs installed, and properly configured.
    Args:
        model_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash. NOTE: currently only branch and tag name is supported
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        user_agent (`str`, `dict`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        private (`bool`, *optional*, defaults to `False`):
            Whether the repo is private; if so, saved login cookies are used.
    Returns:
        Local folder path (string) of repo snapshot

    Raises:
        ValueError: `local_files_only=True` but the cache is empty.
        NotExistError: the requested branch/tag does not exist.
    """

    if cache_dir is None:
        cache_dir = get_cache_dir()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    group_or_owner, name = model_id_to_group_owner_name(model_id)

    cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
    if local_files_only:
        if len(cache.cached_files) == 0:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable model look-ups and downloads'
                " online, set 'local_files_only' to False.")
        # logger.warn is deprecated; use logger.warning with lazy args
        logger.warning('We can not confirm the cached file is for revision: %s',
                       revision)
        return cache.get_root_location(
        )  # we can not confirm the cached file is for snapshot 'revision'
    else:
        # make headers
        headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
        _api = HubApi()
        # get file list from model repo
        branches, tags = _api.get_model_branches_and_tags(model_id)
        if revision not in branches and revision not in tags:
            raise NotExistError('The specified branch or tag : %s not exist!'
                                % revision)

        model_files = _api.get_model_files(
            model_id=model_id,
            revision=revision,
            recursive=True,
            use_cookies=private)

        cookies = None
        if private:
            cookies = ModelScopeConfig.get_cookies()

        for model_file in model_files:
            if model_file['Type'] == 'tree':
                continue
            # check model_file is exist in cache, if exist, skip download, otherwise download
            if cache.exists(model_file):
                logger.info(
                    'The specified file is in cache, skip downloading!')
                continue

            # get download url
            url = get_file_download_url(
                model_id=model_id,
                file_path=model_file['Path'],
                revision=revision)

            # First download to /tmp
            http_get_file(
                url=url,
                local_dir=tempfile.gettempdir(),
                file_name=model_file['Name'],
                headers=headers,
                cookies=None if cookies is None else cookies.get_dict())
            # put file to cache
            cache.put_file(
                model_file,
                os.path.join(tempfile.gettempdir(), model_file['Name']))

        return os.path.join(cache.get_root_location())
|
||||
0
modelscope/hub/utils/__init__.py
Normal file
0
modelscope/hub/utils/__init__.py
Normal file
40
modelscope/hub/utils/_subprocess.py
Normal file
40
modelscope/hub/utils/_subprocess.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import subprocess
|
||||
from typing import List
|
||||
|
||||
|
||||
def run_subprocess(command: List[str],
                   folder: str,
                   check=True,
                   **kwargs) -> subprocess.CompletedProcess:
    """Run an external command and capture its output.

    Both ``stderr`` and ``stdout`` are captured; call ``subprocess.run``
    directly if you do not want them captured.

    Args:
        command (`List[str]`):
            The command to execute as a list of strings.
        folder (`str`):
            The working directory in which to run the command.
        check (`bool`, *optional*, defaults to `True`):
            When `True`, a non-zero exit code raises
            `subprocess.CalledProcessError`.
        kwargs (`Dict[str]`):
            Extra keyword arguments forwarded to `subprocess.run`.

    Returns:
        `subprocess.CompletedProcess`: The completed process.
    """
    # A bare string would be treated as a single executable name;
    # insist on the list form to avoid shell-quoting surprises.
    if isinstance(command, str):
        raise ValueError(
            '`run_subprocess` should be called with a list of strings.')

    return subprocess.run(
        command,
        cwd=folder,
        check=check,
        encoding='utf-8',
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        **kwargs,
    )
|
||||
294
modelscope/hub/utils/caching.py
Normal file
294
modelscope/hub/utils/caching.py
Normal file
@@ -0,0 +1,294 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import time
|
||||
from shutil import move, rmtree
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class FileSystemCache(object):
    """Simple file-system backed cache.

    The cached-file index (a list of key dicts) is persisted with pickle
    in a file named ``.msc`` inside the cache root; cached payload files
    live alongside it under the same root directory.
    """
    # Name of the pickle file holding the cached-file index.
    KEY_FILE_NAME = '.msc'

    def __init__(
        self,
        cache_root_location: str,
        **kwargs,
    ):
        """
        Args:
            cache_root_location (str): The root directory used to store
                cached files; created if it does not yet exist.
        """
        os.makedirs(cache_root_location, exist_ok=True)
        self.cache_root_location = cache_root_location
        self.load_cache()

    def get_root_location(self):
        """Return the root directory of this cache."""
        return self.cache_root_location

    def load_cache(self):
        """Load the cached-file index from disk.

        Initializes ``self.cached_files`` to the list persisted in the
        index file, or to an empty list when no index exists yet.
        """
        self.cached_files = []
        cache_keys_file_path = os.path.join(self.cache_root_location,
                                            FileSystemCache.KEY_FILE_NAME)
        if os.path.exists(cache_keys_file_path):
            with open(cache_keys_file_path, 'rb') as f:
                self.cached_files = pickle.load(f)

    def save_cached_files(self):
        """Persist the cached-file index.

        The index is written to a temporary file first and then moved over
        the real index file, so a crash mid-write cannot corrupt it.
        """
        cache_keys_file_path = os.path.join(self.cache_root_location,
                                            FileSystemCache.KEY_FILE_NAME)
        # TODO: Sync file write
        fd, fn = tempfile.mkstemp()
        with open(fd, 'wb') as f:
            pickle.dump(self.cached_files, f)
        move(fn, cache_keys_file_path)

    def get_file(self, key):
        """Return the cached file for ``key``, or ``None`` if absent.

        Args:
            key (str): The cache key.

        Returns:
            The cached file location if present, otherwise None.
        """
        pass  # implemented by subclasses

    def put_file(self, key, location):
        """Move the file at ``location`` into the cache under ``key``.

        Args:
            key (str): The cache key.
            location (str): Current location of the file; the file is
                moved into the cache.

        Returns:
            The cached file path of the file.
        """
        pass  # implemented by subclasses

    def remove_key(self, key):
        """Remove ``key`` from the index and persist the change.

        Note: the payload file itself is removed by the caller.

        Args:
            key (dict): The cache key.
        """
        self.cached_files.remove(key)
        self.save_cached_files()

    def exists(self, key):
        """Return True when ``key`` is present in the index."""
        return any(cached_file == key for cached_file in self.cached_files)

    def clear_cache(self):
        """Remove all files and metadata from the cache.

        In the case of multiple cache locations, this clears only the last
        one, which is assumed to be the read/write one.
        """
        rmtree(self.cache_root_location)
        # recreate the root so the cache stays usable after clearing
        # (the original left it missing, breaking the next save)
        os.makedirs(self.cache_root_location, exist_ok=True)
        self.load_cache()

    def hash_name(self, key):
        """Return the hex SHA-256 digest of ``key``."""
        return hashlib.sha256(key.encode()).hexdigest()
|
||||
|
||||
|
||||
class ModelFileSystemCache(FileSystemCache):
    """Per-model local file cache.

    Layout::

        cache_root/owner/model_name/  -> individual cached files
                                         plus the cache index file

    Only one revision of each file is kept at a time.
    """

    def __init__(self, cache_root, owner, name):
        """Create (or open) the cache for a single model.

        Args:
            cache_root (str): The modelscope local cache root
                (default: ~/.modelscope/cache/models/).
            owner (str): The model owner (individual or group).
            name (str): The name of the model.

        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        super().__init__(os.path.join(cache_root, owner, name))

    def get_file_by_path(self, file_path):
        """Return the cached location of the file at ``file_path``.

        Stale index entries (whose payload was deleted from disk) are
        purged along the way.

        Args:
            file_path (str): The file path inside the model repo.

        Returns:
            str: The full local path of the file, or None when not cached.
        """
        # iterate over a snapshot: remove_key mutates self.cached_files
        for cached_file in list(self.cached_files):
            if file_path == cached_file['Path']:
                cached_file_path = os.path.join(self.cache_root_location,
                                                cached_file['Path'])
                if os.path.exists(cached_file_path):
                    return cached_file_path
                # payload removed manually; purge the stale index entry
                self.remove_key(cached_file)
        return None

    def get_file_by_path_and_commit_id(self, file_path, commit_id):
        """Return the cached location of ``file_path`` at ``commit_id``.

        Revisions are matched by prefix in either direction so short and
        full commit ids compare equal.

        Args:
            file_path (str): The file path inside the model repo.
            commit_id (str): The commit id of the file.

        Returns:
            str: The full local path of the file, or None when not cached.
        """
        for cached_file in list(self.cached_files):
            if file_path == cached_file['Path'] and \
                    (cached_file['Revision'].startswith(commit_id)
                     or commit_id.startswith(cached_file['Revision'])):
                cached_file_path = os.path.join(self.cache_root_location,
                                                cached_file['Path'])
                if os.path.exists(cached_file_path):
                    return cached_file_path
                # payload removed manually; purge the stale index entry
                self.remove_key(cached_file)
        return None

    def get_file_by_info(self, model_file_info):
        """Return the cached location for an exact file-info match.

        Args:
            model_file_info (dict): File information as returned by
                ``get_model_files`` (must carry ``Path`` and ``Revision``).

        Returns:
            str: The full local path of the file, or None when not cached.
        """
        cache_key = self.__get_cache_key(model_file_info)
        for cached_file in list(self.cached_files):
            if cached_file == cache_key:
                orig_path = os.path.join(self.cache_root_location,
                                         cached_file['Path'])
                if os.path.exists(orig_path):
                    return orig_path
                # payload removed manually; purge the stale index entry
                self.remove_key(cached_file)
        return None

    def __get_cache_key(self, model_file_info):
        # Only Path and Revision (the commit id) identify a cached file;
        # the rest of the server-side file info is deliberately dropped.
        return {
            'Path': model_file_info['Path'],
            'Revision': model_file_info['Revision'],
        }

    def exists(self, model_file_info):
        """Check whether the described file revision is cached on disk.

        Args:
            model_file_info (dict): The file info (``Path``/``Revision``).

        Returns:
            bool: True when a matching revision is cached and on disk.
        """
        key = self.__get_cache_key(model_file_info)
        for cached_key in list(self.cached_files):
            if cached_key['Path'] == key['Path'] and (
                    cached_key['Revision'].startswith(key['Revision'])
                    or key['Revision'].startswith(cached_key['Revision'])):
                file_path = os.path.join(self.cache_root_location,
                                         key['Path'])
                if os.path.exists(file_path):
                    return True
                # Someone may have manually deleted the payload: drop the
                # stale entry. Remove the indexed key, not the raw server
                # info dict, which is never stored in the index (the
                # original code passed model_file_info here, which would
                # raise ValueError in list.remove).
                self.remove_key(cached_key)
                return False
        return False

    def remove_if_exists(self, model_file_info):
        """Remove any cached revision of the given file (index and payload).

        Args:
            model_file_info (dict): The model file information from server.
        """
        # iterate over a snapshot: remove_key mutates self.cached_files
        for cached_file in list(self.cached_files):
            if cached_file['Path'] == model_file_info['Path']:
                self.remove_key(cached_file)
                file_path = os.path.join(self.cache_root_location,
                                         cached_file['Path'])
                if os.path.exists(file_path):
                    os.remove(file_path)

    def put_file(self, model_file_info, model_file_location):
        """Move a downloaded file into the cache.

        The file is first downloaded to a temporary location (e.g. /tmp)
        and then moved into the cache; any previously cached revision of
        the same path is removed first.

        Args:
            model_file_info (dict): The file description returned by
                get_model_files, e.g.::

                    {
                        "CommitMessage": "add model\\n",
                        "CommittedDate": 1654857567,
                        "CommitterName": "mulin.lyh",
                        "IsLFS": false,
                        "Mode": "100644",
                        "Name": "resnet18.pth",
                        "Path": "resnet18.pth",
                        "Revision": "09b68012b27de0048ba74003690a890af7aff192",
                        "Size": 46827520,
                        "Type": "blob"
                    }

            model_file_location (str): The location of the temporary file.

        Returns:
            str: The location of the cached file.
        """
        self.remove_if_exists(model_file_info)  # drop any old revision
        cache_key = self.__get_cache_key(model_file_info)
        cache_full_path = os.path.join(self.cache_root_location,
                                       cache_key['Path'])
        cache_file_dir = os.path.dirname(cache_full_path)
        os.makedirs(cache_file_dir, exist_ok=True)
        # NOTE: the move and the index update are not transactional; a
        # crash between them leaves a payload without an index entry.
        move(model_file_location, cache_full_path)
        self.cached_files.append(cache_key)
        self.save_cached_files()
        return cache_full_path
|
||||
39
modelscope/hub/utils/utils.py
Normal file
39
modelscope/hub/utils/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import os
|
||||
|
||||
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
|
||||
DEFAULT_MODELSCOPE_GITLAB_DOMAIN,
|
||||
DEFAULT_MODELSCOPE_GROUP,
|
||||
MODEL_ID_SEPARATOR,
|
||||
MODELSCOPE_URL_SCHEME)
|
||||
|
||||
|
||||
def model_id_to_group_owner_name(model_id):
    """Split a model id into its group/owner part and its model name.

    A model id has the form ``{owner_or_group}{sep}{name}``; an id without
    a separator belongs to the default ModelScope group.

    Args:
        model_id (str): The model id, e.g. ``damo/cv_unet_image-matting``.

    Returns:
        tuple: ``(group_or_owner, name)``.
    """
    if MODEL_ID_SEPARATOR in model_id:
        # Split only on the first separator so a name that itself contains
        # the separator is preserved intact (the original double split
        # truncated it to the second segment only).
        group_or_owner, name = model_id.split(MODEL_ID_SEPARATOR, 1)
    else:
        group_or_owner = DEFAULT_MODELSCOPE_GROUP
        name = model_id
    return group_or_owner, name
|
||||
|
||||
|
||||
def get_cache_dir(cache_dir: str = None):
    """Return the root cache directory for the model hub.

    Precedence: explicit ``cache_dir`` argument > ``MODELSCOPE_CACHE``
    environment variable > ``~/.cache/modelscope/hub``.  (The original
    docstring promised function-parameter precedence but accepted no
    parameter; the optional argument implements it backward-compatibly.)

    Args:
        cache_dir (str, optional): Explicit cache directory; returned
            as-is when provided.

    Returns:
        str: The cache directory to use.
    """
    if cache_dir is not None:
        return cache_dir
    default_cache_dir = os.path.expanduser(
        os.path.join('~/.cache', 'modelscope'))
    return os.getenv('MODELSCOPE_CACHE',
                     os.path.join(default_cache_dir, 'hub'))
|
||||
|
||||
|
||||
def get_endpoint():
    """Return the ModelScope API endpoint URL.

    The domain may be overridden with the ``MODELSCOPE_DOMAIN``
    environment variable; otherwise the default domain is used.
    """
    domain = os.getenv('MODELSCOPE_DOMAIN', DEFAULT_MODELSCOPE_DOMAIN)
    return MODELSCOPE_URL_SCHEME + domain
|
||||
|
||||
|
||||
def get_gitlab_domain():
    """Return the ModelScope GitLab domain.

    Honors the ``MODELSCOPE_GITLAB_DOMAIN`` environment variable when set.
    """
    override = os.getenv('MODELSCOPE_GITLAB_DOMAIN')
    return DEFAULT_MODELSCOPE_GITLAB_DOMAIN if override is None else override
|
||||
94
modelscope/metainfo.py
Normal file
94
modelscope/metainfo.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
|
||||
class Models(object):
    """ Names for different models.

    Holds the standard model names used to identify different models.
    These names should be used when registering models.

    A model name should only contain model info, not task info.
    """
    # vision models

    # nlp models
    bert = 'bert'
    palm2_0 = 'palm2.0'
    structbert = 'structbert'

    # audio models
    sambert_hifi_16k = 'sambert-hifi-16k'
    generic_tts_frontend = 'generic-tts-frontend'
    hifigan16k = 'hifigan16k'

    # multi-modal models
    ofa = 'ofa'
|
||||
|
||||
|
||||
class Pipelines(object):
    """ Names for different pipelines.

    Holds the standard pipeline names used to identify different pipelines.
    These names should be used when registering pipelines.

    For a pipeline that supports multiple models and implements common
    functionality, use the task name as the pipeline name.
    For a pipeline that supports only one model, use ${Model}-${Task}
    as its name.
    """
    # vision tasks
    image_matting = 'unet-image-matting'
    person_image_cartoon = 'unet-person-image-cartoon'
    ocr_detection = 'resnet18-ocr-detection'

    # nlp tasks
    sentence_similarity = 'sentence-similarity'
    word_segmentation = 'word-segmentation'
    text_generation = 'text-generation'
    sentiment_analysis = 'sentiment-analysis'

    # audio tasks
    sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
    speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'

    # multi-modal tasks
    image_caption = 'image-caption'
|
||||
|
||||
|
||||
class Trainers(object):
    """ Names for different trainers.

    Holds the standard trainer names used to identify different trainers.
    These names should be used when registering trainers.

    For a general Trainer, you can use easynlp-trainer/ofa-trainer/sofa-trainer.
    For a model-specific Trainer, you can use ${ModelName}-${Task}-trainer.
    """

    default = 'Trainer'
|
||||
|
||||
|
||||
class Preprocessors(object):
    """ Names for different preprocessors.

    Holds the standard preprocessor names used to identify different
    preprocessors. These names should be used when registering preprocessors.

    For a general preprocessor, just use the function name as the
    preprocessor name, such as resize-image, random-crop.
    For a model-specific preprocessor, use ${modelname}-${function}.
    """

    # cv preprocessor
    load_image = 'load-image'

    # nlp preprocessor
    bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
    palm_text_gen_tokenizer = 'palm-text-gen-tokenizer'
    sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'

    # audio preprocessor
    linear_aec_fbank = 'linear-aec-fbank'
    text_to_tacotron_symbols = 'text-to-tacotron-symbols'

    # multi-modal
    ofa_image_caption = 'ofa-image-caption'
|
||||
@@ -6,6 +6,7 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
from sklearn.preprocessing import MultiLabelBinarizer
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
@@ -26,7 +27,8 @@ def multi_label_symbol_to_sequence(my_classes, my_symbol):
|
||||
return one_hot.fit_transform(sequences)
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.text_to_speech, module_name=r'sambert_hifi_16k')
|
||||
@MODELS.register_module(
|
||||
Tasks.text_to_speech, module_name=Models.sambert_hifi_16k)
|
||||
class SambertNetHifi16k(Model):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import zipfile
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.audio.tts_exceptions import (
|
||||
@@ -13,7 +14,7 @@ __all__ = ['GenericTtsFrontend']
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.text_to_speech, module_name=r'generic_tts_frontend')
|
||||
Tasks.text_to_speech, module_name=Models.generic_tts_frontend)
|
||||
class GenericTtsFrontend(Model):
|
||||
|
||||
def __init__(self, model_dir='.', lang_type='pinyin', *args, **kwargs):
|
||||
|
||||
@@ -10,6 +10,7 @@ import numpy as np
|
||||
import torch
|
||||
from scipy.io.wavfile import write
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.audio.tts_exceptions import \
|
||||
@@ -36,7 +37,7 @@ class AttrDict(dict):
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.text_to_speech, module_name=r'hifigan16k')
|
||||
@MODELS.register_module(Tasks.text_to_speech, module_name=Models.hifigan16k)
|
||||
class Hifigan16k(Model):
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
|
||||
@@ -4,12 +4,13 @@ import os.path as osp
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Union
|
||||
|
||||
from maas_hub.snapshot_download import snapshot_download
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models.builder import build_model
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.hub import get_model_cache_dir
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
Tensor = Union['torch.Tensor', 'tf.Tensor']
|
||||
|
||||
@@ -47,21 +48,25 @@ class Model(ABC):
|
||||
if osp.exists(model_name_or_path):
|
||||
local_model_dir = model_name_or_path
|
||||
else:
|
||||
cache_path = get_model_cache_dir(model_name_or_path)
|
||||
local_model_dir = cache_path if osp.exists(
|
||||
cache_path) else snapshot_download(model_name_or_path)
|
||||
# else:
|
||||
# raise ValueError(
|
||||
# 'Remote model repo {model_name_or_path} does not exists')
|
||||
|
||||
local_model_dir = snapshot_download(model_name_or_path)
|
||||
logger.info(f'initialize model from {local_model_dir}')
|
||||
cfg = Config.from_file(
|
||||
osp.join(local_model_dir, ModelFile.CONFIGURATION))
|
||||
task_name = cfg.task
|
||||
model_cfg = cfg.model
|
||||
assert hasattr(
|
||||
cfg, 'pipeline'), 'pipeline config is missing from config file.'
|
||||
pipeline_cfg = cfg.pipeline
|
||||
# TODO @wenmeng.zwm may should manually initialize model after model building
|
||||
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
|
||||
model_cfg.type = model_cfg.model_type
|
||||
|
||||
model_cfg.model_dir = local_model_dir
|
||||
|
||||
for k, v in kwargs.items():
|
||||
model_cfg.k = v
|
||||
return build_model(model_cfg, task_name)
|
||||
model = build_model(model_cfg, task_name)
|
||||
|
||||
# dynamically add pipeline info to model for pipeline inference
|
||||
model.pipeline = pipeline_cfg
|
||||
return model
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from ..base import Model
|
||||
from ..builder import MODELS
|
||||
@@ -10,8 +11,7 @@ from ..builder import MODELS
|
||||
__all__ = ['OfaForImageCaptioning']
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.image_captioning, module_name=r'ofa-image-captioning')
|
||||
@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa)
|
||||
class OfaForImageCaptioning(Model):
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Dict
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.utils.constant import Tasks
|
||||
from ..base import Model
|
||||
from ..builder import MODELS
|
||||
@@ -11,8 +12,7 @@ from ..builder import MODELS
|
||||
__all__ = ['BertForSequenceClassification']
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.text_classification, module_name=r'bert-sentiment-analysis')
|
||||
@MODELS.register_module(Tasks.text_classification, module_name=Models.bert)
|
||||
class BertForSequenceClassification(Model):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Dict
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.utils.constant import Tasks
|
||||
from ..base import Model, Tensor
|
||||
from ..builder import MODELS
|
||||
@@ -7,7 +8,7 @@ from ..builder import MODELS
|
||||
__all__ = ['PalmForTextGeneration']
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.text_generation, module_name=r'palm2.0')
|
||||
@MODELS.register_module(Tasks.text_generation, module_name=Models.palm2_0)
|
||||
class PalmForTextGeneration(Model):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
|
||||
@@ -8,6 +8,7 @@ from sofa import SbertModel
|
||||
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
|
||||
from torch import nn
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.utils.constant import Tasks
|
||||
from ..base import Model, Tensor
|
||||
from ..builder import MODELS
|
||||
@@ -38,8 +39,7 @@ class SbertTextClassifier(SbertPreTrainedModel):
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.sentence_similarity,
|
||||
module_name=r'sbert-base-chinese-sentence-similarity')
|
||||
Tasks.sentence_similarity, module_name=Models.structbert)
|
||||
class SbertForSentenceSimilarity(Model):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
|
||||
@@ -4,6 +4,7 @@ import numpy as np
|
||||
import torch
|
||||
from sofa import SbertConfig, SbertForTokenClassification
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.utils.constant import Tasks
|
||||
from ..base import Model, Tensor
|
||||
from ..builder import MODELS
|
||||
@@ -11,9 +12,7 @@ from ..builder import MODELS
|
||||
__all__ = ['StructBertForTokenClassification']
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.word_segmentation,
|
||||
module_name=r'structbert-chinese-word-segmentation')
|
||||
@MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert)
|
||||
class StructBertForTokenClassification(Model):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
|
||||
@@ -7,6 +7,7 @@ import scipy.io.wavfile as wav
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.preprocessors.audio import LinearAECAndFbank
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from ..base import Pipeline
|
||||
@@ -39,7 +40,8 @@ def initialize_config(module_cfg):
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.speech_signal_process, module_name=r'speech_dfsmn_aec_psm_16k')
|
||||
Tasks.speech_signal_process,
|
||||
module_name=Pipelines.speech_dfsmn_aec_psm_16k)
|
||||
class LinearAECPipeline(Pipeline):
|
||||
r"""AEC Inference Pipeline only support 16000 sample rate.
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.audio.tts.am import SambertNetHifi16k
|
||||
from modelscope.models.audio.tts.vocoder import Hifigan16k
|
||||
@@ -15,7 +16,7 @@ __all__ = ['TextToSpeechSambertHifigan16kPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.text_to_speech, module_name=r'tts-sambert-hifigan-16k')
|
||||
Tasks.text_to_speech, module_name=Pipelines.sambert_hifigan_16k_tts)
|
||||
class TextToSpeechSambertHifigan16kPipeline(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -4,16 +4,14 @@ import os.path as osp
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generator, List, Union
|
||||
|
||||
from maas_hub.snapshot_download import snapshot_download
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.pydatasets import PyDataset
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.hub import get_model_cache_dir
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .outputs import TASK_OUTPUTS
|
||||
from .util import is_model_name
|
||||
from .util import is_model, is_official_hub_path
|
||||
|
||||
Tensor = Union['torch.Tensor', 'tf.Tensor']
|
||||
Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
|
||||
@@ -29,14 +27,10 @@ class Pipeline(ABC):
|
||||
|
||||
def initiate_single_model(self, model):
|
||||
logger.info(f'initiate model from {model}')
|
||||
# TODO @wenmeng.zwm replace model.startswith('damo/') with get_model
|
||||
if isinstance(model, str) and model.startswith('damo/'):
|
||||
if not osp.exists(model):
|
||||
cache_path = get_model_cache_dir(model)
|
||||
model = cache_path if osp.exists(
|
||||
cache_path) else snapshot_download(model)
|
||||
return Model.from_pretrained(model) if is_model_name(
|
||||
model) else model
|
||||
if isinstance(model, str) and is_official_hub_path(model):
|
||||
model = snapshot_download(
|
||||
model) if not osp.exists(model) else model
|
||||
return Model.from_pretrained(model) if is_model(model) else model
|
||||
elif isinstance(model, Model):
|
||||
return model
|
||||
else:
|
||||
|
||||
@@ -3,32 +3,39 @@
|
||||
import os.path as osp
|
||||
from typing import List, Union
|
||||
|
||||
from attr import has
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.utils.config import Config, ConfigDict
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.hub import read_config
|
||||
from modelscope.utils.registry import Registry, build_from_cfg
|
||||
from .base import Pipeline
|
||||
from .util import is_official_hub_path
|
||||
|
||||
PIPELINES = Registry('pipelines')
|
||||
|
||||
DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
# TaskName: (pipeline_module_name, model_repo)
|
||||
Tasks.word_segmentation:
|
||||
('structbert-chinese-word-segmentation',
|
||||
(Pipelines.word_segmentation,
|
||||
'damo/nlp_structbert_word-segmentation_chinese-base'),
|
||||
Tasks.sentence_similarity:
|
||||
('sbert-base-chinese-sentence-similarity',
|
||||
(Pipelines.sentence_similarity,
|
||||
'damo/nlp_structbert_sentence-similarity_chinese-base'),
|
||||
Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
|
||||
Tasks.text_classification:
|
||||
('bert-sentiment-analysis', 'damo/bert-base-sst2'),
|
||||
Tasks.text_generation: ('palm2.0',
|
||||
Tasks.image_matting:
|
||||
(Pipelines.image_matting, 'damo/cv_unet_image-matting'),
|
||||
Tasks.text_classification: (Pipelines.sentiment_analysis,
|
||||
'damo/bert-base-sst2'),
|
||||
Tasks.text_generation: (Pipelines.text_generation,
|
||||
'damo/nlp_palm2.0_text-generation_chinese-base'),
|
||||
Tasks.image_captioning: ('ofa', 'damo/ofa_image-caption_coco_large_en'),
|
||||
Tasks.image_captioning: (Pipelines.image_caption,
|
||||
'damo/ofa_image-caption_coco_large_en'),
|
||||
Tasks.image_generation:
|
||||
('person-image-cartoon',
|
||||
(Pipelines.person_image_cartoon,
|
||||
'damo/cv_unet_person-image-cartoon_compound-models'),
|
||||
Tasks.ocr_detection: ('ocr-detection',
|
||||
Tasks.ocr_detection: (Pipelines.ocr_detection,
|
||||
'damo/cv_resnet18_ocr-detection-line-level_damo'),
|
||||
Tasks.fill_mask: ('veco', 'damo/nlp_veco_fill-mask_large')
|
||||
}
|
||||
@@ -87,30 +94,40 @@ def pipeline(task: str = None,
|
||||
if task is None and pipeline_name is None:
|
||||
raise ValueError('task or pipeline_name is required')
|
||||
|
||||
assert isinstance(model, (type(None), str, Model, list)), \
|
||||
f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}'
|
||||
|
||||
if pipeline_name is None:
|
||||
# get default pipeline for this task
|
||||
if isinstance(model, str) \
|
||||
or (isinstance(model, list) and isinstance(model[0], str)):
|
||||
|
||||
# if is_model_name(model):
|
||||
if (isinstance(model, str) and model.startswith('damo/')) \
|
||||
or (isinstance(model, list) and model[0].startswith('damo/')) \
|
||||
or (isinstance(model, str) and osp.exists(model)):
|
||||
# TODO @wenmeng.zwm add support when model is a str of modelhub address
|
||||
# read pipeline info from modelhub configuration file.
|
||||
pipeline_name, default_model_repo = get_default_pipeline_info(
|
||||
task)
|
||||
if is_official_hub_path(model):
|
||||
# read config file from hub and parse
|
||||
cfg = read_config(model) if isinstance(
|
||||
model, str) else read_config(model[0])
|
||||
assert hasattr(
|
||||
cfg,
|
||||
'pipeline'), 'pipeline config is missing from config file.'
|
||||
pipeline_name = cfg.pipeline.type
|
||||
else:
|
||||
# used for test case, when model is str and is not hub path
|
||||
pipeline_name = get_pipeline_by_model_name(task, model)
|
||||
elif isinstance(model, Model) or \
|
||||
(isinstance(model, list) and isinstance(model[0], Model)):
|
||||
# get pipeline info from Model object
|
||||
first_model = model[0] if isinstance(model, list) else model
|
||||
if not hasattr(first_model, 'pipeline'):
|
||||
# model is instantiated by user, we should parse config again
|
||||
cfg = read_config(first_model.model_dir)
|
||||
assert hasattr(
|
||||
cfg,
|
||||
'pipeline'), 'pipeline config is missing from config file.'
|
||||
first_model.pipeline = cfg.pipeline
|
||||
pipeline_name = first_model.pipeline.type
|
||||
else:
|
||||
pipeline_name, default_model_repo = get_default_pipeline_info(task)
|
||||
|
||||
if model is None:
|
||||
model = default_model_repo
|
||||
|
||||
assert isinstance(model, (type(None), str, Model, list)), \
|
||||
f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}'
|
||||
|
||||
cfg = ConfigDict(type=pipeline_name, model=model)
|
||||
|
||||
if kwargs:
|
||||
|
||||
@@ -6,6 +6,7 @@ import numpy as np
|
||||
import PIL
|
||||
import tensorflow as tf
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.cartoon.facelib.facer import FaceAna
|
||||
from modelscope.models.cv.cartoon.mtcnn_pytorch.src.align_trans import (
|
||||
get_reference_facial_points, warp_and_crop_face)
|
||||
@@ -25,7 +26,7 @@ logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_generation, module_name='person-image-cartoon')
|
||||
Tasks.image_generation, module_name=Pipelines.person_image_cartoon)
|
||||
class ImageCartoonPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: str):
|
||||
|
||||
@@ -5,6 +5,7 @@ import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.pipelines.base import Input
|
||||
from modelscope.preprocessors import load_image
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
@@ -16,7 +17,7 @@ logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_matting, module_name=Tasks.image_matting)
|
||||
Tasks.image_matting, module_name=Pipelines.image_matting)
|
||||
class ImageMattingPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: str):
|
||||
|
||||
@@ -10,6 +10,7 @@ import PIL
|
||||
import tensorflow as tf
|
||||
import tf_slim as slim
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.pipelines.base import Input
|
||||
from modelscope.preprocessors import load_image
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
@@ -38,7 +39,7 @@ tf.app.flags.DEFINE_float('link_threshold', 0.6,
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.ocr_detection, module_name=Tasks.ocr_detection)
|
||||
Tasks.ocr_detection, module_name=Pipelines.ocr_detection)
|
||||
class OCRDetectionPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: str):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.preprocessors import OfaImageCaptionPreprocessor, Preprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -9,7 +10,8 @@ from ..builder import PIPELINES
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(Tasks.image_captioning, module_name='ofa')
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_captioning, module_name=Pipelines.image_caption)
|
||||
class ImageCaptionPipeline(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Any, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.nlp import SbertForSentenceSimilarity
|
||||
from modelscope.preprocessors import SequenceClassificationPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
@@ -13,8 +14,7 @@ __all__ = ['SentenceSimilarityPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.sentence_similarity,
|
||||
module_name=r'sbert-base-chinese-sentence-similarity')
|
||||
Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity)
|
||||
class SentenceSimilarityPipeline(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Any, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.nlp import BertForSequenceClassification
|
||||
from modelscope.preprocessors import SequenceClassificationPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
@@ -13,7 +14,7 @@ __all__ = ['SequenceClassificationPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.text_classification, module_name=r'bert-sentiment-analysis')
|
||||
Tasks.text_classification, module_name=Pipelines.sentiment_analysis)
|
||||
class SequenceClassificationPipeline(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import PalmForTextGeneration
|
||||
from modelscope.preprocessors import TextGenerationPreprocessor
|
||||
@@ -10,7 +11,8 @@ from ..builder import PIPELINES
|
||||
__all__ = ['TextGenerationPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(Tasks.text_generation, module_name=r'palm2.0')
|
||||
@PIPELINES.register_module(
|
||||
Tasks.text_generation, module_name=Pipelines.text_generation)
|
||||
class TextGenerationPipeline(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import StructBertForTokenClassification
|
||||
from modelscope.preprocessors import TokenClassifcationPreprocessor
|
||||
@@ -11,8 +12,7 @@ __all__ = ['WordSegmentationPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.word_segmentation,
|
||||
module_name=r'structbert-chinese-word-segmentation')
|
||||
Tasks.word_segmentation, module_name=Pipelines.word_segmentation)
|
||||
class WordSegmentationPipeline(Pipeline):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
import os.path as osp
|
||||
from typing import List, Union
|
||||
|
||||
from maas_hub.file_download import model_file_download
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -20,31 +20,63 @@ def is_config_has_model(cfg_file):
|
||||
return False
|
||||
|
||||
|
||||
def is_model_name(model: Union[str, List]):
|
||||
""" whether model is a valid modelhub path
|
||||
def is_official_hub_path(path: Union[str, List]):
|
||||
""" Whether path is a official hub name or a valid local
|
||||
path to official hub directory.
|
||||
"""
|
||||
|
||||
def is_model_name_impl(model):
|
||||
if osp.exists(model):
|
||||
cfg_file = osp.join(model, ModelFile.CONFIGURATION)
|
||||
def is_official_hub_impl(path):
|
||||
if osp.exists(path):
|
||||
cfg_file = osp.join(path, ModelFile.CONFIGURATION)
|
||||
return osp.exists(cfg_file)
|
||||
else:
|
||||
try:
|
||||
_ = HubApi().get_model(path)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if isinstance(path, str):
|
||||
return is_official_hub_impl(path)
|
||||
else:
|
||||
results = [is_official_hub_impl(m) for m in path]
|
||||
all_true = all(results)
|
||||
any_true = any(results)
|
||||
if any_true and not all_true:
|
||||
raise ValueError(
|
||||
f'some model are hub address, some are not, model list: {path}'
|
||||
)
|
||||
|
||||
return all_true
|
||||
|
||||
|
||||
def is_model(path: Union[str, List]):
|
||||
""" whether path is a valid modelhub path and containing model config
|
||||
"""
|
||||
|
||||
def is_modelhub_path_impl(path):
|
||||
if osp.exists(path):
|
||||
cfg_file = osp.join(path, ModelFile.CONFIGURATION)
|
||||
if osp.exists(cfg_file):
|
||||
return is_config_has_model(cfg_file)
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
cfg_file = model_file_download(model, ModelFile.CONFIGURATION)
|
||||
cfg_file = model_file_download(path, ModelFile.CONFIGURATION)
|
||||
return is_config_has_model(cfg_file)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if isinstance(model, str):
|
||||
return is_model_name_impl(model)
|
||||
if isinstance(path, str):
|
||||
return is_modelhub_path_impl(path)
|
||||
else:
|
||||
results = [is_model_name_impl(m) for m in model]
|
||||
results = [is_modelhub_path_impl(m) for m in path]
|
||||
all_true = all(results)
|
||||
any_true = any(results)
|
||||
if any_true and not all_true:
|
||||
raise ValueError('some model are hub address, some are not')
|
||||
raise ValueError(
|
||||
f'some models are hub address, some are not, model list: {path}'
|
||||
)
|
||||
|
||||
return all_true
|
||||
|
||||
@@ -5,11 +5,12 @@ from typing import Dict, Union
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from modelscope.fileio import File
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.utils.constant import Fields
|
||||
from .builder import PREPROCESSORS
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(Fields.cv)
|
||||
@PREPROCESSORS.register_module(Fields.cv, Preprocessors.load_image)
|
||||
class LoadImage:
|
||||
"""Load an image from file or url.
|
||||
Added or updated keys are "filename", "img", "img_shape",
|
||||
|
||||
@@ -4,11 +4,11 @@ from typing import Any, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from maas_hub.snapshot_download import snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.utils.constant import Fields, ModelFile
|
||||
from modelscope.utils.hub import get_model_cache_dir
|
||||
from modelscope.utils.type_assert import type_assert
|
||||
from .base import Preprocessor
|
||||
from .builder import PREPROCESSORS
|
||||
@@ -20,7 +20,7 @@ __all__ = [
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(
|
||||
Fields.multi_modal, module_name=r'ofa-image-caption')
|
||||
Fields.multi_modal, module_name=Preprocessors.ofa_image_caption)
|
||||
class OfaImageCaptionPreprocessor(Preprocessor):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
@@ -34,9 +34,7 @@ class OfaImageCaptionPreprocessor(Preprocessor):
|
||||
if osp.exists(model_dir):
|
||||
local_model_dir = model_dir
|
||||
else:
|
||||
cache_path = get_model_cache_dir(model_dir)
|
||||
local_model_dir = cache_path if osp.exists(
|
||||
cache_path) else snapshot_download(model_dir)
|
||||
local_model_dir = snapshot_download(model_dir)
|
||||
local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE)
|
||||
bpe_dir = local_model_dir
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, Union
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.utils.constant import Fields, InputFields
|
||||
from modelscope.utils.type_assert import type_assert
|
||||
from .base import Preprocessor
|
||||
@@ -32,7 +33,7 @@ class Tokenize(Preprocessor):
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(
|
||||
Fields.nlp, module_name=r'bert-sequence-classification')
|
||||
Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer)
|
||||
class SequenceClassificationPreprocessor(Preprocessor):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
@@ -125,7 +126,8 @@ class SequenceClassificationPreprocessor(Preprocessor):
|
||||
return rst
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm2.0')
|
||||
@PREPROCESSORS.register_module(
|
||||
Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer)
|
||||
class TextGenerationPreprocessor(Preprocessor):
|
||||
|
||||
def __init__(self, model_dir: str, tokenizer, *args, **kwargs):
|
||||
@@ -236,7 +238,7 @@ class FillMaskPreprocessor(Preprocessor):
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(
|
||||
Fields.nlp, module_name=r'bert-token-classification')
|
||||
Fields.nlp, module_name=Preprocessors.sbert_token_cls_tokenizer)
|
||||
class TokenClassifcationPreprocessor(Preprocessor):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
|
||||
@@ -3,6 +3,7 @@ import io
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from modelscope.fileio import File
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.models.audio.tts.frontend import GenericTtsFrontend
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.utils.audio.tts_exceptions import * # noqa F403
|
||||
@@ -10,11 +11,11 @@ from modelscope.utils.constant import Fields
|
||||
from .base import Preprocessor
|
||||
from .builder import PREPROCESSORS
|
||||
|
||||
__all__ = ['TextToTacotronSymbols', 'text_to_tacotron_symbols']
|
||||
__all__ = ['TextToTacotronSymbols']
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(
|
||||
Fields.audio, module_name=r'text_to_tacotron_symbols')
|
||||
Fields.audio, module_name=Preprocessors.text_to_tacotron_symbols)
|
||||
class TextToTacotronSymbols(Preprocessor):
|
||||
"""extract tacotron symbols from text.
|
||||
|
||||
|
||||
@@ -1,14 +1,49 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
from typing import List, Union
|
||||
|
||||
from maas_hub.constants import MODEL_ID_SEPARATOR
|
||||
from numpy import deprecate
|
||||
|
||||
from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.hub.utils.utils import get_cache_dir
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile
|
||||
|
||||
|
||||
# temp solution before the hub-cache is in place
|
||||
def get_model_cache_dir(model_id: str, branch: str = 'master'):
|
||||
model_id_expanded = model_id.replace('/',
|
||||
MODEL_ID_SEPARATOR) + '.' + branch
|
||||
default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
|
||||
return os.getenv('MAAS_CACHE',
|
||||
os.path.join(default_cache_dir, 'hub', model_id_expanded))
|
||||
@deprecate
|
||||
def get_model_cache_dir(model_id: str):
|
||||
return os.path.join(get_cache_dir(), model_id)
|
||||
|
||||
|
||||
def read_config(model_id_or_path: str):
|
||||
""" Read config from hub or local path
|
||||
|
||||
Args:
|
||||
model_id_or_path (str): Model repo name or local directory path.
|
||||
|
||||
Return:
|
||||
config (:obj:`Config`): config object
|
||||
"""
|
||||
if not os.path.exists(model_id_or_path):
|
||||
local_path = model_file_download(model_id_or_path,
|
||||
ModelFile.CONFIGURATION)
|
||||
else:
|
||||
local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION)
|
||||
|
||||
return Config.from_file(local_path)
|
||||
|
||||
|
||||
def auto_load(model: Union[str, List[str]]):
|
||||
if isinstance(model, str):
|
||||
if not osp.exists(model):
|
||||
model = snapshot_download(model)
|
||||
else:
|
||||
model = [
|
||||
snapshot_download(m) if not osp.exists(m) else m for m in model
|
||||
]
|
||||
|
||||
return model
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#tts
|
||||
h5py==2.10.0
|
||||
#https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp36-cp36m-linux_x86_64.whl
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp37-cp37m-linux_x86_64.whl
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp36-cp36m-linux_x86_64.whl; python_version=='3.6'
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp37-cp37m-linux_x86_64.whl; python_version=='3.7'
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp38-cp38-linux_x86_64.whl; python_version=='3.8'
|
||||
https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp39-cp39-linux_x86_64.whl; python_version=='3.9'
|
||||
https://swap.oss-cn-hangzhou.aliyuncs.com/Jiaqi%2Fmaas%2Ftts%2Frequirements%2Fpytorch_wavelets-1.3.0-py3-none-any.whl?Expires=1685688388&OSSAccessKeyId=LTAI4Ffebq4d9jTVDwiSbY4L&Signature=jcQbg5EZ%2Bdys3%2F4BRn3srrKLdIg%3D
|
||||
#https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp38-cp38-linux_x86_64.whl
|
||||
#https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp39-cp39-linux_x86_64.whl
|
||||
inflect
|
||||
keras==2.2.4
|
||||
librosa
|
||||
@@ -12,7 +12,7 @@ lxml
|
||||
matplotlib
|
||||
nara_wpe
|
||||
numpy==1.18.*
|
||||
protobuf==3.20.*
|
||||
protobuf>3,<=3.20
|
||||
ptflops
|
||||
PyWavelets>=1.0.0
|
||||
scikit-learn==0.23.2
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
addict
|
||||
datasets
|
||||
easydict
|
||||
https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.4.dev0-py3-none-any.whl
|
||||
filelock>=3.3.0
|
||||
numpy
|
||||
opencv-python-headless
|
||||
Pillow>=6.2.0
|
||||
pyyaml
|
||||
requests
|
||||
requests==2.27.1
|
||||
scipy
|
||||
setuptools==58.0.4
|
||||
tokenizers<=0.10.3
|
||||
tqdm>=4.64.0
|
||||
transformers<=4.16.2
|
||||
yapf
|
||||
|
||||
0
tests/hub/__init__.py
Normal file
0
tests/hub/__init__.py
Normal file
157
tests/hub/test_hub_operation.py
Normal file
157
tests/hub/test_hub_operation.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import os.path as osp
|
||||
import subprocess
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.hub.utils.utils import get_gitlab_domain
|
||||
|
||||
USER_NAME = 'maasadmin'
|
||||
PASSWORD = '12345678'
|
||||
|
||||
model_chinese_name = '达摩卡通化模型'
|
||||
model_org = 'unittest'
|
||||
DEFAULT_GIT_PATH = 'git'
|
||||
|
||||
|
||||
class GitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# TODO make thest git operation to git library after merge code.
|
||||
def run_git_command(git_path, *args) -> subprocess.CompletedProcess:
|
||||
response = subprocess.run([git_path, *args], capture_output=True)
|
||||
try:
|
||||
response.check_returncode()
|
||||
return response.stdout.decode('utf8')
|
||||
except subprocess.CalledProcessError as error:
|
||||
raise GitError(error.stderr.decode('utf8'))
|
||||
|
||||
|
||||
# for public project, token can None, private repo, there must token.
|
||||
def clone(local_dir: str, token: str, url: str):
|
||||
url = url.replace('//', '//oauth2:%s@' % token)
|
||||
clone_args = '-C %s clone %s' % (local_dir, url)
|
||||
clone_args = clone_args.split(' ')
|
||||
stdout = run_git_command(DEFAULT_GIT_PATH, *clone_args)
|
||||
print('stdout: %s' % stdout)
|
||||
|
||||
|
||||
def push(local_dir: str, token: str, url: str):
|
||||
url = url.replace('//', '//oauth2:%s@' % token)
|
||||
push_args = '-C %s push %s' % (local_dir, url)
|
||||
push_args = push_args.split(' ')
|
||||
stdout = run_git_command(DEFAULT_GIT_PATH, *push_args)
|
||||
print('stdout: %s' % stdout)
|
||||
|
||||
|
||||
sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx'
|
||||
download_model_file_name = 'mnist-12.onnx'
|
||||
|
||||
|
||||
class HubOperationTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.old_cwd = os.getcwd()
|
||||
self.api = HubApi()
|
||||
# note this is temporary before official account management is ready
|
||||
self.api.login(USER_NAME, PASSWORD)
|
||||
self.model_name = uuid.uuid4().hex
|
||||
self.model_id = '%s/%s' % (model_org, self.model_name)
|
||||
self.api.create_model(
|
||||
model_id=self.model_id,
|
||||
chinese_name=model_chinese_name,
|
||||
visibility=5, # 1-private, 5-public
|
||||
license='apache-2.0')
|
||||
|
||||
def tearDown(self):
|
||||
os.chdir(self.old_cwd)
|
||||
self.api.delete_model(model_id=self.model_id)
|
||||
|
||||
def test_model_repo_creation(self):
|
||||
# change to proper model names before use
|
||||
try:
|
||||
info = self.api.get_model(model_id=self.model_id)
|
||||
assert info['Name'] == self.model_name
|
||||
except KeyError as ke:
|
||||
if ke.args[0] == 'name':
|
||||
print(f'model {self.model_name} already exists, ignore')
|
||||
else:
|
||||
raise
|
||||
|
||||
# Note that this can be done via git operation once model repo
|
||||
# has been created. Git-Op is the RECOMMENDED model upload approach
|
||||
def test_model_upload(self):
|
||||
url = f'http://{get_gitlab_domain()}/{self.model_id}'
|
||||
print(url)
|
||||
temporary_dir = tempfile.mkdtemp()
|
||||
os.chdir(temporary_dir)
|
||||
cmd_args = 'clone %s' % url
|
||||
cmd_args = cmd_args.split(' ')
|
||||
out = run_git_command('git', *cmd_args)
|
||||
print(out)
|
||||
repo_dir = os.path.join(temporary_dir, self.model_name)
|
||||
os.chdir(repo_dir)
|
||||
os.system('touch file1')
|
||||
os.system('git add file1')
|
||||
os.system("git commit -m 'Test'")
|
||||
token = ModelScopeConfig.get_token()
|
||||
push(repo_dir, token, url)
|
||||
|
||||
def test_download_single_file(self):
|
||||
url = f'http://{get_gitlab_domain()}/{self.model_id}'
|
||||
print(url)
|
||||
temporary_dir = tempfile.mkdtemp()
|
||||
os.chdir(temporary_dir)
|
||||
os.system('git clone %s' % url)
|
||||
repo_dir = os.path.join(temporary_dir, self.model_name)
|
||||
os.chdir(repo_dir)
|
||||
os.system('wget %s' % sample_model_url)
|
||||
os.system('git add .')
|
||||
os.system("git commit -m 'Add file'")
|
||||
token = ModelScopeConfig.get_token()
|
||||
push(repo_dir, token, url)
|
||||
assert os.path.exists(
|
||||
os.path.join(temporary_dir, self.model_name,
|
||||
download_model_file_name))
|
||||
downloaded_file = model_file_download(
|
||||
model_id=self.model_id, file_path=download_model_file_name)
|
||||
mdtime1 = os.path.getmtime(downloaded_file)
|
||||
# download again
|
||||
downloaded_file = model_file_download(
|
||||
model_id=self.model_id, file_path=download_model_file_name)
|
||||
mdtime2 = os.path.getmtime(downloaded_file)
|
||||
assert mdtime1 == mdtime2
|
||||
|
||||
def test_snapshot_download(self):
|
||||
url = f'http://{get_gitlab_domain()}/{self.model_id}'
|
||||
print(url)
|
||||
temporary_dir = tempfile.mkdtemp()
|
||||
os.chdir(temporary_dir)
|
||||
os.system('git clone %s' % url)
|
||||
repo_dir = os.path.join(temporary_dir, self.model_name)
|
||||
os.chdir(repo_dir)
|
||||
os.system('wget %s' % sample_model_url)
|
||||
os.system('git add .')
|
||||
os.system("git commit -m 'Add file'")
|
||||
token = ModelScopeConfig.get_token()
|
||||
push(repo_dir, token, url)
|
||||
snapshot_path = snapshot_download(model_id=self.model_id)
|
||||
downloaded_file_path = os.path.join(snapshot_path,
|
||||
download_model_file_name)
|
||||
assert os.path.exists(downloaded_file_path)
|
||||
mdtime1 = os.path.getmtime(downloaded_file_path)
|
||||
# download again
|
||||
snapshot_path = snapshot_download(model_id=self.model_id)
|
||||
mdtime2 = os.path.getmtime(downloaded_file_path)
|
||||
assert mdtime1 == mdtime2
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -10,7 +10,6 @@ from modelscope.fileio import File
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.pydatasets import PyDataset
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.hub import get_model_cache_dir
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
@@ -18,11 +17,6 @@ class ImageMattingTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.model_id = 'damo/cv_unet_image-matting'
|
||||
# switch to False if downloading everytime is not desired
|
||||
purge_cache = True
|
||||
if purge_cache:
|
||||
shutil.rmtree(
|
||||
get_model_cache_dir(self.model_id), ignore_errors=True)
|
||||
|
||||
@unittest.skip('deprecated, download model from model hub instead')
|
||||
def test_run_with_direct_file_download(self):
|
||||
@@ -66,7 +60,7 @@ class ImageMattingTest(unittest.TestCase):
|
||||
cv2.imwrite('result.png', result['output_png'])
|
||||
print(f'Output written to {osp.abspath("result.png")}')
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_with_modelscope_dataset(self):
|
||||
dataset = PyDataset.load('beans', split='train', target='image')
|
||||
img_matting = pipeline(Tasks.image_matting, model=self.model_id)
|
||||
|
||||
@@ -27,7 +27,7 @@ class OCRDetectionTest(unittest.TestCase):
|
||||
print('ocr detection results: ')
|
||||
print(result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_modelhub_default_model(self):
|
||||
ocr_detection = pipeline(Tasks.ocr_detection)
|
||||
self.pipeline_inference(ocr_detection, self.test_image)
|
||||
|
||||
@@ -2,14 +2,12 @@
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
from maas_hub.snapshot_download import snapshot_download
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import SbertForSentenceSimilarity
|
||||
from modelscope.pipelines import SentenceSimilarityPipeline, pipeline
|
||||
from modelscope.preprocessors import SequenceClassificationPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.hub import get_model_cache_dir
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
@@ -18,13 +16,6 @@ class SentenceSimilarityTest(unittest.TestCase):
|
||||
sentence1 = '今天气温比昨天高么?'
|
||||
sentence2 = '今天湿度比昨天高么?'
|
||||
|
||||
def setUp(self) -> None:
|
||||
# switch to False if downloading everytime is not desired
|
||||
purge_cache = True
|
||||
if purge_cache:
|
||||
shutil.rmtree(
|
||||
get_model_cache_dir(self.model_id), ignore_errors=True)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run(self):
|
||||
cache_path = snapshot_download(self.model_id)
|
||||
|
||||
@@ -3,9 +3,9 @@ import shutil
|
||||
import unittest
|
||||
|
||||
from modelscope.fileio import File
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.hub import get_model_cache_dir
|
||||
|
||||
NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav'
|
||||
FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav'
|
||||
@@ -30,11 +30,6 @@ class SpeechSignalProcessTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.model_id = 'damo/speech_dfsmn_aec_psm_16k'
|
||||
# switch to False if downloading everytime is not desired
|
||||
purge_cache = True
|
||||
if purge_cache:
|
||||
shutil.rmtree(
|
||||
get_model_cache_dir(self.model_id), ignore_errors=True)
|
||||
# A temporary hack to provide c++ lib. Download it first.
|
||||
download(AEC_LIB_URL, AEC_LIB_FILE)
|
||||
|
||||
@@ -48,7 +43,7 @@ class SpeechSignalProcessTest(unittest.TestCase):
|
||||
aec = pipeline(
|
||||
Tasks.speech_signal_process,
|
||||
model=self.model_id,
|
||||
pipeline_name=r'speech_dfsmn_aec_psm_16k')
|
||||
pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k)
|
||||
aec(input, output_path='output.wav')
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from modelscope.pipelines import SequenceClassificationPipeline, pipeline
|
||||
from modelscope.preprocessors import SequenceClassificationPreprocessor
|
||||
from modelscope.pydatasets import PyDataset
|
||||
from modelscope.utils.constant import Hubs, Tasks
|
||||
from modelscope.utils.hub import get_model_cache_dir
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
@@ -19,11 +18,6 @@ class SequenceClassificationTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.model_id = 'damo/bert-base-sst2'
|
||||
# switch to False if downloading everytime is not desired
|
||||
purge_cache = True
|
||||
if purge_cache:
|
||||
shutil.rmtree(
|
||||
get_model_cache_dir(self.model_id), ignore_errors=True)
|
||||
|
||||
def predict(self, pipeline_ins: SequenceClassificationPipeline):
|
||||
from easynlp.appzoo import load_dataset
|
||||
@@ -44,31 +38,6 @@ class SequenceClassificationTest(unittest.TestCase):
|
||||
break
|
||||
print(r)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run(self):
|
||||
model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
|
||||
'/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
|
||||
cache_path_str = r'.cache/easynlp/bert-base-sst2.zip'
|
||||
cache_path = Path(cache_path_str)
|
||||
|
||||
if not cache_path.exists():
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cache_path.touch(exist_ok=True)
|
||||
with cache_path.open('wb') as ofile:
|
||||
ofile.write(File.read(model_url))
|
||||
|
||||
with zipfile.ZipFile(cache_path_str, 'r') as zipf:
|
||||
zipf.extractall(cache_path.parent)
|
||||
path = r'.cache/easynlp/'
|
||||
model = BertForSequenceClassification(path)
|
||||
preprocessor = SequenceClassificationPreprocessor(
|
||||
path, first_sequence='sentence', second_sequence=None)
|
||||
pipeline1 = SequenceClassificationPipeline(model, preprocessor)
|
||||
self.predict(pipeline1)
|
||||
pipeline2 = pipeline(
|
||||
Tasks.text_classification, model=model, preprocessor=preprocessor)
|
||||
print(pipeline2('Hello world!'))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_model_from_modelhub(self):
|
||||
model = Model.from_pretrained(self.model_id)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import unittest
|
||||
|
||||
from maas_hub.snapshot_download import snapshot_download
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import PalmForTextGeneration
|
||||
from modelscope.pipelines import TextGenerationPipeline, pipeline
|
||||
|
||||
@@ -11,6 +11,7 @@ import torch
|
||||
from scipy.io.wavfile import write
|
||||
|
||||
from modelscope.fileio import File
|
||||
from modelscope.metainfo import Pipelines, Preprocessors
|
||||
from modelscope.models import Model, build_model
|
||||
from modelscope.models.audio.tts.am import SambertNetHifi16k
|
||||
from modelscope.models.audio.tts.vocoder import AttrDict, Hifigan16k
|
||||
@@ -32,7 +33,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase):
|
||||
voc_model_id = 'damo/speech_hifigan16k_tts_zhitian_emo'
|
||||
|
||||
cfg_preprocessor = dict(
|
||||
type='text_to_tacotron_symbols',
|
||||
type=Preprocessors.text_to_tacotron_symbols,
|
||||
model_name=preprocessor_model_id,
|
||||
lang_type=lang_type)
|
||||
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
|
||||
@@ -45,7 +46,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase):
|
||||
self.assertTrue(voc is not None)
|
||||
|
||||
sambert_tts = pipeline(
|
||||
pipeline_name='tts-sambert-hifigan-16k',
|
||||
pipeline_name=Pipelines.sambert_hifigan_16k_tts,
|
||||
config_file='',
|
||||
model=[am, voc],
|
||||
preprocessor=preprocessor)
|
||||
|
||||
@@ -2,14 +2,12 @@
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
from maas_hub.snapshot_download import snapshot_download
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import StructBertForTokenClassification
|
||||
from modelscope.pipelines import WordSegmentationPipeline, pipeline
|
||||
from modelscope.preprocessors import TokenClassifcationPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.hub import get_model_cache_dir
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
@@ -17,13 +15,6 @@ class WordSegmentationTest(unittest.TestCase):
|
||||
model_id = 'damo/nlp_structbert_word-segmentation_chinese-base'
|
||||
sentence = '今天天气不错,适合出去游玩'
|
||||
|
||||
def setUp(self) -> None:
|
||||
# switch to False if downloading everytime is not desired
|
||||
purge_cache = True
|
||||
if purge_cache:
|
||||
shutil.rmtree(
|
||||
get_model_cache_dir(self.model_id), ignore_errors=True)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_by_direct_model_download(self):
|
||||
cache_path = snapshot_download(self.model_id)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.preprocessors import build_preprocessor
|
||||
from modelscope.utils.constant import Fields, InputFields
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -14,7 +15,7 @@ class TtsPreprocessorTest(unittest.TestCase):
|
||||
lang_type = 'pinyin'
|
||||
text = '今天天气不错,我们去散步吧。'
|
||||
cfg = dict(
|
||||
type='text_to_tacotron_symbols',
|
||||
type=Preprocessors.text_to_tacotron_symbols,
|
||||
model_name='damo/speech_binary_tts_frontend_resource',
|
||||
lang_type=lang_type)
|
||||
preprocessor = build_preprocessor(cfg, Fields.audio)
|
||||
|
||||
@@ -33,6 +33,8 @@ class ImgPreprocessor(Preprocessor):
|
||||
|
||||
class PyDatasetTest(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2,
|
||||
'skip test due to dataset api problem')
|
||||
def test_ds_basic(self):
|
||||
ms_ds_full = PyDataset.load('squad')
|
||||
ms_ds_full_hf = hfdata.load_dataset('squad')
|
||||
|
||||
@@ -61,7 +61,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
'--test_dir', default='tests', help='directory to be tested')
|
||||
parser.add_argument(
|
||||
'--level', default=0, help='2 -- all, 1 -- p1, 0 -- p0')
|
||||
'--level', default=0, type=int, help='2 -- all, 1 -- p1, 0 -- p0')
|
||||
args = parser.parse_args()
|
||||
set_test_level(args.level)
|
||||
logger.info(f'TEST LEVEL: {test_level()}')
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
from maas_hub.maas_api import MaasApi
|
||||
from maas_hub.repository import Repository
|
||||
|
||||
USER_NAME = 'maasadmin'
|
||||
PASSWORD = '12345678'
|
||||
|
||||
|
||||
class HubOperationTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.api = MaasApi()
|
||||
# note this is temporary before official account management is ready
|
||||
self.api.login(USER_NAME, PASSWORD)
|
||||
|
||||
@unittest.skip('to be used for local test only')
|
||||
def test_model_repo_creation(self):
|
||||
# change to proper model names before use
|
||||
model_name = 'cv_unet_person-image-cartoon_compound-models'
|
||||
model_chinese_name = '达摩卡通化模型'
|
||||
model_org = 'damo'
|
||||
try:
|
||||
self.api.create_model(
|
||||
owner=model_org,
|
||||
name=model_name,
|
||||
chinese_name=model_chinese_name,
|
||||
visibility=5, # 1-private, 5-public
|
||||
license='apache-2.0')
|
||||
# TODO: support proper name duplication checking
|
||||
except KeyError as ke:
|
||||
if ke.args[0] == 'name':
|
||||
print(f'model {self.model_name} already exists, ignore')
|
||||
else:
|
||||
raise
|
||||
|
||||
# Note that this can be done via git operation once model repo
|
||||
# has been created. Git-Op is the RECOMMENDED model upload approach
|
||||
@unittest.skip('to be used for local test only')
|
||||
def test_model_upload(self):
|
||||
local_path = '/path/to/local/model/directory'
|
||||
assert osp.exists(local_path), 'Local model directory not exist.'
|
||||
repo = Repository(local_dir=local_path)
|
||||
repo.push_to_hub(commit_message='Upload model files')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user