merge with zero_shot_classfication
0
modelscope/hub/__init__.py
Normal file
265
modelscope/hub/api.py
Normal file
@@ -0,0 +1,265 @@
import os
import pickle
import subprocess
from http.cookiejar import CookieJar
from os.path import expanduser
from typing import List, Optional, Tuple, Union

import requests

from modelscope.utils.logger import get_logger
from .constants import LOGGER_NAME
from .errors import NotExistError, is_ok, raise_on_error
from .utils.utils import get_endpoint, model_id_to_group_owner_name

logger = get_logger()


class HubApi:

    def __init__(self, endpoint=None):
        self.endpoint = endpoint if endpoint is not None else get_endpoint()

    def login(
        self,
        user_name: str,
        password: str,
    ) -> Tuple[str, CookieJar]:
        """
        Login with username and password.

        Args:
            user_name(`str`): user name on ModelScope
            password(`str`): password

        Returns:
            cookies: to authenticate yourself to the ModelScope open-api
            gitlab token: to access private repos

        <Tip>
        You only have to login once within 30 days.
        </Tip>

        TODO: handle cookie expiry
        """
        path = f'{self.endpoint}/api/v1/login'
        r = requests.post(
            path, json={
                'username': user_name,
                'password': password
            })
        r.raise_for_status()
        d = r.json()
        raise_on_error(d)

        token = d['Data']['AccessToken']
        cookies = r.cookies

        # save token and cookies
        ModelScopeConfig.save_token(token)
        ModelScopeConfig.save_cookies(cookies)
        ModelScopeConfig.write_to_git_credential(user_name, password)

        return token, cookies

    def create_model(self, model_id: str, chinese_name: str, visibility: int,
                     license: str) -> str:
        """
        Create a model repo on the ModelScope Hub.

        Args:
            model_id(`str`): The model id
            chinese_name(`str`): chinese name of the model
            visibility(`int`): visibility of the model (1-private, 3-internal, 5-public)
            license(`str`): license of the model, candidates can be found at: TBA

        Returns:
            name of the model created

        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')

        path = f'{self.endpoint}/api/v1/models'
        owner_or_group, name = model_id_to_group_owner_name(model_id)
        r = requests.post(
            path,
            json={
                'Path': owner_or_group,
                'Name': name,
                'ChineseName': chinese_name,
                'Visibility': visibility,
                'License': license
            },
            cookies=cookies)
        r.raise_for_status()
        raise_on_error(r.json())
        d = r.json()
        return d['Data']['Name']

    def delete_model(self, model_id):
        """Delete a model from the ModelScope Hub.

        Args:
            model_id (str): The model id.
        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        cookies = ModelScopeConfig.get_cookies()
        path = f'{self.endpoint}/api/v1/models/{model_id}'

        r = requests.delete(path, cookies=cookies)
        r.raise_for_status()
        raise_on_error(r.json())

    def get_model_url(self, model_id):
        return f'{self.endpoint}/api/v1/models/{model_id}.git'

    def get_model(
        self,
        model_id: str,
        revision: str = 'master',
    ) -> str:
        """
        Get model information from the ModelScope Hub.

        Args:
            model_id(`str`): The model id.
            revision(`str`): revision of the model
        Returns:
            The model detail information.
        Raises:
            NotExistError: raised if the model does not exist
        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        cookies = ModelScopeConfig.get_cookies()
        owner_or_group, name = model_id_to_group_owner_name(model_id)
        path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}'

        r = requests.get(path, cookies=cookies)
        if r.status_code == 200:
            if is_ok(r.json()):
                return r.json()['Data']
            else:
                raise NotExistError(r.json()['Message'])
        else:
            r.raise_for_status()

    def get_model_branches_and_tags(
        self,
        model_id: str,
    ) -> Tuple[List[str], List[str]]:
        cookies = ModelScopeConfig.get_cookies()

        path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
        r = requests.get(path, cookies=cookies)
        r.raise_for_status()
        d = r.json()
        raise_on_error(d)
        info = d['Data']
        branches = [x['Revision'] for x in info['RevisionMap']['Branches']
                    ] if info['RevisionMap']['Branches'] else []
        tags = [x['Revision'] for x in info['RevisionMap']['Tags']
                ] if info['RevisionMap']['Tags'] else []
        return branches, tags

    def get_model_files(
            self,
            model_id: str,
            revision: Optional[str] = 'master',
            root: Optional[str] = None,
            recursive: Optional[bool] = False,
            use_cookies: Union[bool, CookieJar] = False) -> List[dict]:

        cookies = None
        if isinstance(use_cookies, CookieJar):
            cookies = use_cookies
        elif use_cookies:
            cookies = ModelScopeConfig.get_cookies()
            if cookies is None:
                raise ValueError('Token does not exist, please login first.')

        path = f'{self.endpoint}/api/v1/models/{model_id}/repo/files?Revision={revision}&Recursive={recursive}'
        if root is not None:
            path = path + f'&Root={root}'

        r = requests.get(path, cookies=cookies)

        r.raise_for_status()
        d = r.json()
        raise_on_error(d)

        files = []
        for file in d['Data']['Files']:
            if file['Name'] == '.gitignore' or file['Name'] == '.gitattributes':
                continue

            files.append(file)
        return files


class ModelScopeConfig:
    path_credential = expanduser('~/.modelscope/credentials')
    os.makedirs(path_credential, exist_ok=True)

    @classmethod
    def save_cookies(cls, cookies: CookieJar):
        with open(os.path.join(cls.path_credential, 'cookies'), 'wb+') as f:
            pickle.dump(cookies, f)

    @classmethod
    def get_cookies(cls):
        try:
            with open(os.path.join(cls.path_credential, 'cookies'), 'rb') as f:
                return pickle.load(f)
        except FileNotFoundError:
            logger.warning(
                "Auth token does not exist, you'll get an authentication"
                ' error when downloading private model files. Please login first.'
            )

    @classmethod
    def save_token(cls, token: str):
        with open(os.path.join(cls.path_credential, 'token'), 'w+') as f:
            f.write(token)

    @classmethod
    def get_token(cls) -> Optional[str]:
        """
        Get the saved token, or None if it does not exist.

        Returns:
            `str` or `None`: The token, `None` if it doesn't exist.
        """
        token = None
        try:
            with open(os.path.join(cls.path_credential, 'token'), 'r') as f:
                token = f.read()
        except FileNotFoundError:
            pass
        return token

    @staticmethod
    def write_to_git_credential(username: str, password: str):
        with subprocess.Popen(
                'git credential-store store'.split(),
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
        ) as process:
            input_username = f'username={username.lower()}'
            input_password = f'password={password}'

            process.stdin.write(
                f'url={get_endpoint()}\n{input_username}\n{input_password}\n\n'
                .encode('utf-8'))
            process.stdin.flush()
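A minimal usage sketch of HubApi (not part of the commit; the user name, password and model id below are hypothetical placeholders, and the endpoint must be reachable):

    from modelscope.hub.api import HubApi

    api = HubApi()  # endpoint defaults to get_endpoint()
    # login once; token and cookies are persisted under ~/.modelscope/credentials
    token, cookies = api.login('my_user', 'my_password')
    # create a private repo; model_id = {owner}/{name}
    api.create_model(
        model_id='my_user/my-test-model',
        chinese_name='测试模型',
        visibility=1,            # 1-private, 3-internal, 5-public
        license='Apache-2.0')    # assumed license string; candidates are TBA per the docstring
    print(api.get_model('my_user/my-test-model'))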
8
modelscope/hub/constants.py
Normal file
@@ -0,0 +1,8 @@
MODELSCOPE_URL_SCHEME = 'http://'
DEFAULT_MODELSCOPE_DOMAIN = '101.201.119.157:32330'
DEFAULT_MODELSCOPE_GITLAB_DOMAIN = '101.201.119.157:31102'

DEFAULT_MODELSCOPE_GROUP = 'damo'
MODEL_ID_SEPARATOR = '/'

LOGGER_NAME = 'ModelScopeHub'
30
modelscope/hub/errors.py
Normal file
@@ -0,0 +1,30 @@
class NotExistError(Exception):
    pass


class RequestError(Exception):
    pass


def is_ok(rsp):
    """Check whether the request succeeded.

    Args:
        rsp (dict): The response body, e.g.
            Failed: {'Code': 10010101004, 'Message': 'get model info failed, err: unauthorized permission',
                     'RequestId': '', 'Success': False}
            Success: {'Code': 200, 'Data': {}, 'Message': 'success', 'RequestId': '', 'Success': True}
    """
    return rsp['Code'] == 200 and rsp['Success']


def raise_on_error(rsp):
    """Raise an exception if the response indicates an error.

    Args:
        rsp (dict): The server response.
    """
    if rsp['Code'] == 200 and rsp['Success']:
        return True
    else:
        raise RequestError(rsp['Message'])
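A short sketch of how these helpers behave, using the exact sample payloads from the is_ok docstring:

    from modelscope.hub.errors import RequestError, is_ok, raise_on_error

    ok = {'Code': 200, 'Data': {}, 'Message': 'success', 'RequestId': '', 'Success': True}
    bad = {'Code': 10010101004, 'Message': 'get model info failed, err: unauthorized permission',
           'RequestId': '', 'Success': False}
    assert is_ok(ok) and not is_ok(bad)
    try:
        raise_on_error(bad)
    except RequestError as e:
        print(e)  # get model info failed, err: unauthorized permission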
254
modelscope/hub/file_download.py
Normal file
@@ -0,0 +1,254 @@
import copy
import fnmatch
import logging
import os
import sys
import tempfile
import time
from functools import partial
from hashlib import sha256
from pathlib import Path
from typing import BinaryIO, Dict, Optional, Union
from uuid import uuid4

import json
import requests
from filelock import FileLock
from requests.exceptions import HTTPError
from tqdm import tqdm

from modelscope import __version__
from modelscope.utils.logger import get_logger
from .api import HubApi, ModelScopeConfig
from .constants import (DEFAULT_MODELSCOPE_GROUP, LOGGER_NAME,
                        MODEL_ID_SEPARATOR)
from .errors import NotExistError, RequestError, raise_on_error
from .utils.caching import ModelFileSystemCache
from .utils.utils import (get_cache_dir, get_endpoint,
                          model_id_to_group_owner_name)

SESSION_ID = uuid4().hex
logger = get_logger()


def model_file_download(
    model_id: str,
    file_path: str,
    revision: Optional[str] = 'master',
    cache_dir: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    local_files_only: Optional[bool] = False,
) -> Optional[str]:  # pragma: no cover
    """
    Download from a given URL and cache it if it's not already present in the
    local cache.

    Given a URL, this function looks for the corresponding file in the local
    cache. If it's not there, download it. Then return the path to the cached
    file.

    Args:
        model_id (`str`):
            The model to which the file to be downloaded belongs.
        file_path(`str`):
            Path of the file to be downloaded, relative to the root of the model repo.
        revision(`str`, *optional*):
            Revision of the model file to be downloaded.
            Can be any of a branch, tag or commit hash, defaults to `master`.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        user_agent (`dict`, `str`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
            If `False`, download the file even if it exists locally.

    Returns:
        Local path (string) of the file, or, if networking is off, the last
        version of the file cached on disk.

    <Tip>

    Raises the following errors:

    - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
      if `use_auth_token=True` and the token cannot be found.
    - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
      if ETag cannot be determined.
    - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
      if some parameter value is invalid.

    </Tip>
    """
    if cache_dir is None:
        cache_dir = get_cache_dir()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    group_or_owner, name = model_id_to_group_owner_name(model_id)

    cache = ModelFileSystemCache(cache_dir, group_or_owner, name)

    # if local_files_only is `True` and the file already exists in cached_path,
    # return the cached path
    if local_files_only:
        cached_file_path = cache.get_file_by_path(file_path)
        if cached_file_path is not None:
            logger.warning(
                "File exists in local cache, but we're not sure it's up to date"
            )
            return cached_file_path
        else:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable model look-ups and downloads'
                " online, set 'local_files_only' to False.")

    _api = HubApi()
    headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
    branches, tags = _api.get_model_branches_and_tags(model_id)
    file_to_download_info = None
    is_commit_id = False
    if revision in branches or revision in tags:
        # The revision is a branch or tag; we need to confirm the version is
        # up to date, so get the file list to check whether the latest version
        # is cached. If so return it, otherwise download.
        model_files = _api.get_model_files(
            model_id=model_id,
            revision=revision,
            recursive=True,
        )

        for model_file in model_files:
            if model_file['Type'] == 'tree':
                continue

            if model_file['Path'] == file_path:
                model_file['Branch'] = revision
                if cache.exists(model_file):
                    return cache.get_file_by_info(model_file)
                else:
                    file_to_download_info = model_file

        if file_to_download_info is None:
            raise NotExistError('The file path: %s does not exist in: %s' %
                                (file_path, model_id))
    else:  # the revision is a commit id.
        cached_file_path = cache.get_file_by_path_and_commit_id(
            file_path, revision)
        if cached_file_path is not None:
            logger.info('The specified file is in cache, skip downloading!')
            return cached_file_path  # the file is in cache.
        is_commit_id = True
    # we need to download again
    # TODO: skip using JWT for authorization, use cookie instead
    cookies = ModelScopeConfig.get_cookies()
    url_to_download = get_file_download_url(model_id, file_path, revision)
    file_to_download_info = {
        'Path': file_path,
        'Revision':
        revision if is_commit_id else file_to_download_info['Revision']
    }
    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache.get_root_location() + '.lock'

    with FileLock(lock_path):
        temp_file_name = next(tempfile._get_candidate_names())
        http_get_file(
            url_to_download,
            cache_dir,
            temp_file_name,
            headers=headers,
            cookies=None if cookies is None else cookies.get_dict())
        return cache.put_file(file_to_download_info,
                              os.path.join(cache_dir, temp_file_name))


def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
    """Formats a user-agent string with basic info about a request.

    Args:
        user_agent (`str`, `dict`, *optional*):
            The user agent info in the form of a dictionary or a single string.

    Returns:
        The formatted user-agent string.
    """
    ua = f'modelscope/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}'

    if isinstance(user_agent, dict):
        ua = '; '.join(f'{k}/{v}' for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua = user_agent
    return ua


def get_file_download_url(model_id: str, file_path: str, revision: str):
    """
    Format the file download url according to `model_id`, `revision` and `file_path`.
    e.g., given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
    the resulting download url is: https://maas.co/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
    """
    download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
    return download_url_template.format(
        endpoint=get_endpoint(),
        model_id=model_id,
        revision=revision,
        file_path=file_path,
    )


def http_get_file(
    url: str,
    local_dir: str,
    file_name: str,
    cookies: Dict[str, str],
    headers: Optional[Dict[str, str]] = None,
):
    """
    Download a remote file. Do not gobble up errors.
    This method is also used by snapshot_download, since the behavior is quite
    different from single file download.
    TODO: consolidate the duplicate download code.

    Args:
        url(`str`):
            actual download url of the file
        local_dir(`str`):
            local directory where the downloaded file is stored
        file_name(`str`):
            name of the file stored in `local_dir`
        cookies(`Dict[str, str]`):
            cookies used to authenticate the user, needed for downloading private repos
        headers(`Optional[Dict[str, str]] = None`):
            http headers carrying necessary info when requesting the remote file
    """
    temp_file_manager = partial(
        tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)

    with temp_file_manager() as temp_file:
        logger.info('downloading %s to %s', url, temp_file.name)
        headers = copy.deepcopy(headers)

        r = requests.get(url, stream=True, headers=headers, cookies=cookies)
        r.raise_for_status()

        content_length = r.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None

        progress = tqdm(
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
            total=total,
            initial=0,
            desc='Downloading',
        )
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                progress.update(len(chunk))
                temp_file.write(chunk)
        progress.close()

    logger.info('storing %s in cache at %s', url, local_dir)
    os.replace(temp_file.name, os.path.join(local_dir, file_name))
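A minimal sketch of downloading a single file with the function above (the file name is a hypothetical example; the model id is one that appears elsewhere in this commit):

    from modelscope.hub.file_download import model_file_download

    local_path = model_file_download(
        model_id='damo/bert-base-sst2',
        file_path='config.json',   # assumed to exist in the repo
        revision='master')
    print(local_path)  # a path under <cache_dir>/damo/bert-base-sst2/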
82
modelscope/hub/git.py
Normal file
@@ -0,0 +1,82 @@
from typing import Union

from modelscope.utils.logger import get_logger
from .constants import LOGGER_NAME
from .utils._subprocess import run_subprocess

logger = get_logger()


def git_clone(
    local_dir: str,
    repo_url: str,
):
    # TODO: use "git clone" or "git lfs clone" according to the git version
    # TODO: print stderr when the subprocess fails
    run_subprocess(
        f'git clone {repo_url}'.split(),
        local_dir,
        True,
    )


def git_checkout(
    local_dir: str,
    revision: str,
):
    run_subprocess(f'git checkout {revision}'.split(), local_dir)


def git_add(local_dir: str):
    run_subprocess(
        'git add .'.split(),
        local_dir,
        True,
    )


def git_commit(local_dir: str, commit_message: str):
    run_subprocess(
        'git commit -v -m'.split() + [commit_message],
        local_dir,
        True,
    )


def git_push(local_dir: str, branch: str):
    # check the current branch
    cur_branch = git_current_branch(local_dir)
    if cur_branch != branch:
        logger.error(
            "You're trying to push to a different branch, please double check")
        return

    run_subprocess(
        f'git push origin {branch}'.split(),
        local_dir,
        True,
    )


def git_current_branch(local_dir: str) -> Union[str, None]:
    """
    Get the current branch name.

    Args:
        local_dir(`str`): local model repo directory

    Returns:
        The name of the branch you're currently on.
    """
    process = run_subprocess(
        'git rev-parse --abbrev-ref HEAD'.split(),
        local_dir,
        True,
    )

    return str(process.stdout).strip()
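A sketch of the helpers above on a hypothetical local clone (the directory is a placeholder):

    from modelscope.hub.git import (git_add, git_commit, git_current_branch,
                                    git_push)

    repo_dir = '/tmp/my-test-model'  # hypothetical local repo
    git_add(repo_dir)
    git_commit(repo_dir, 'add model')
    print(git_current_branch(repo_dir))  # e.g. 'master'
    git_push(repo_dir, 'master')  # refuses if the current branch differs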
173
modelscope/hub/repository.py
Normal file
@@ -0,0 +1,173 @@
import os
import subprocess
from pathlib import Path
from typing import Optional, Union

from modelscope.utils.logger import get_logger
from .api import ModelScopeConfig
from .constants import MODELSCOPE_URL_SCHEME
from .git import git_add, git_checkout, git_clone, git_commit, git_push
from .utils._subprocess import run_subprocess
from .utils.utils import get_gitlab_domain

logger = get_logger()


class Repository:

    def __init__(
        self,
        local_dir: str,
        clone_from: Optional[str] = None,
        auth_token: Optional[str] = None,
        private: Optional[bool] = False,
        revision: Optional[str] = 'master',
    ):
        """
        Instantiate a Repository object by cloning the remote ModelScope Hub repo.

        Args:
            local_dir(`str`):
                local directory to store the model files
            clone_from(`Optional[str] = None`):
                model id on the ModelScope Hub to git clone from.
                You should omit this parameter when `local_dir` is already a git repo.
            auth_token(`Optional[str]`):
                token obtained when calling `HubApi.login()`. Usually you can safely
                ignore this parameter, as the token is already saved when you login
                the first time.
            private(`Optional[bool]`):
                whether the model is private, defaults to False
            revision(`Optional[str]`):
                revision of the model you want to clone from. Can be any of a branch, tag or commit hash
        """
        logger.info('Instantiating Repository object...')

        # Create the local directory if it does not exist
        os.makedirs(local_dir, exist_ok=True)
        self.local_dir = os.path.join(os.getcwd(), local_dir)

        self.private = private

        # Check git and git-lfs installation
        self.check_git_versions()

        # Retrieve auth token
        if not private and isinstance(auth_token, str):
            logger.warning(
                'cloning a public repo with a token, which will be ignored')
            self.token = None
        else:
            if isinstance(auth_token, str):
                self.token = auth_token
            else:
                self.token = ModelScopeConfig.get_token()

            if self.token is None:
                raise EnvironmentError(
                    'Token does not exist, the clone will fail for a private repo.'
                    ' Please login first.')

        # git clone
        if clone_from is not None:
            self.model_id = clone_from
            logger.info('cloning model repo to %s ...', self.local_dir)
            git_clone(self.local_dir, self.get_repo_url())
        else:
            if is_git_repo(self.local_dir):
                logger.debug('[Repository] is a valid git repo')
            else:
                raise ValueError(
                    'If not specifying `clone_from`, you need to pass Repository a'
                    ' valid git clone.')

        # git checkout
        if isinstance(revision, str) and revision != 'master':
            git_checkout(self.local_dir, revision)

    def push_to_hub(self,
                    commit_message: str,
                    revision: Optional[str] = 'master'):
        """
        Push changes to the hub.

        Args:
            commit_message(`str`):
                commit message describing the changes, it's mandatory
            revision(`Optional[str]`):
                remote branch you want to push to, defaults to `master`

        <Tip>
        The function complains when the local and remote branch are different, please be careful.
        </Tip>
        """
        git_add(self.local_dir)
        git_commit(self.local_dir, commit_message)

        logger.info('Pushing changes to repo...')
        git_push(self.local_dir, revision)

        # TODO: if git push fails, how to retry?

    def check_git_versions(self):
        """
        Checks that `git` and `git-lfs` can be run.

        Raises:
            `EnvironmentError`: if `git` or `git-lfs` are not installed.
        """
        try:
            git_version = run_subprocess('git --version'.split(),
                                         self.local_dir).stdout.strip()
        except FileNotFoundError:
            raise EnvironmentError(
                'Looks like you do not have git installed, please install.')

        try:
            lfs_version = run_subprocess('git-lfs --version'.split(),
                                         self.local_dir).stdout.strip()
        except FileNotFoundError:
            raise EnvironmentError(
                'Looks like you do not have git-lfs installed, please install.'
                ' You can install from https://git-lfs.github.com/.'
                ' Then run `git lfs install` (you only have to do this once).')
        logger.info(git_version + '\n' + lfs_version)

    def get_repo_url(self) -> str:
        """
        Get the repo url to clone from, according to whether the repo is private or not.
        """
        url = None

        if self.private:
            url = f'{MODELSCOPE_URL_SCHEME}oauth2:{self.token}@{get_gitlab_domain()}/{self.model_id}'
        else:
            url = f'{MODELSCOPE_URL_SCHEME}{get_gitlab_domain()}/{self.model_id}'

        if not url:
            raise ValueError(
                'Empty repo url, please check the clone_from parameter')

        logger.debug('url to clone: %s', str(url))

        return url


def is_git_repo(folder: Union[str, Path]) -> bool:
    """
    Check if the folder is the root of, or part of, a git repository.

    Args:
        folder (`str`):
            The folder in which to run the command.

    Returns:
        `bool`: `True` if the folder is part of a git repository, `False` otherwise.
    """
    folder_exists = os.path.exists(os.path.join(folder, '.git'))
    git_branch = subprocess.run(
        'git branch'.split(),
        cwd=folder,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    return folder_exists and git_branch.returncode == 0
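A minimal sketch of the clone-edit-push workflow with Repository (the model id is a hypothetical placeholder):

    from modelscope.hub.repository import Repository

    # clone a (hypothetical) repo, edit files, then push in one commit
    repo = Repository(local_dir='my-test-model',
                      clone_from='my_user/my-test-model')
    # ... write model files under my-test-model/ ...
    repo.push_to_hub(commit_message='add model', revision='master')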
125
modelscope/hub/snapshot_download.py
Normal file
@@ -0,0 +1,125 @@
import os
import tempfile
from glob import glob
from pathlib import Path
from typing import Dict, Optional, Union

from modelscope.utils.logger import get_logger
from .api import HubApi, ModelScopeConfig
from .constants import DEFAULT_MODELSCOPE_GROUP, MODEL_ID_SEPARATOR
from .errors import NotExistError, RequestError, raise_on_error
from .file_download import (get_file_download_url, http_get_file,
                            http_user_agent)
from .utils.caching import ModelFileSystemCache
from .utils.utils import get_cache_dir, model_id_to_group_owner_name

logger = get_logger()


def snapshot_download(model_id: str,
                      revision: Optional[str] = 'master',
                      cache_dir: Union[str, Path, None] = None,
                      user_agent: Optional[Union[Dict, str]] = None,
                      local_files_only: Optional[bool] = False,
                      private: Optional[bool] = False) -> str:
    """Download all files of a repo.

    Downloads a whole snapshot of a repo's files at the specified revision. This
    is useful when you want all files from a repo, because you don't know which
    ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder.

    An alternative would be to just clone a repo, but this would require that the
    user always has git and git-lfs installed, and properly configured.

    Args:
        model_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash. NOTE: currently only branch and tag names are supported.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        user_agent (`str`, `dict`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the files and return the path to the
            local cached files if they exist.
    Returns:
        Local folder path (string) of the repo snapshot.

    <Tip>
    Raises the following errors:
    - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
      if `use_auth_token=True` and the token cannot be found.
    - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
      ETag cannot be determined.
    - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
      if some parameter value is invalid.
    </Tip>
    """

    if cache_dir is None:
        cache_dir = get_cache_dir()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    group_or_owner, name = model_id_to_group_owner_name(model_id)

    cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
    if local_files_only:
        if len(cache.cached_files) == 0:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable model look-ups and downloads'
                " online, set 'local_files_only' to False.")
        logger.warning('We can not confirm the cached files are for revision: %s'
                       % revision)
        # we can not confirm the cached files are for snapshot 'revision'
        return cache.get_root_location()
    else:
        # make headers
        headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
        _api = HubApi()
        # get the file list from the model repo
        branches, tags = _api.get_model_branches_and_tags(model_id)
        if revision not in branches and revision not in tags:
            raise NotExistError('The specified branch or tag: %s does not exist!'
                                % revision)

        model_files = _api.get_model_files(
            model_id=model_id,
            revision=revision,
            recursive=True,
            use_cookies=private)

        cookies = None
        if private:
            cookies = ModelScopeConfig.get_cookies()

        for model_file in model_files:
            if model_file['Type'] == 'tree':
                continue
            # if model_file already exists in the cache, skip the download
            if cache.exists(model_file):
                logger.info(
                    'The specified file is in cache, skip downloading!')
                continue

            # get the download url
            url = get_file_download_url(
                model_id=model_id,
                file_path=model_file['Path'],
                revision=revision)

            # First download to /tmp
            http_get_file(
                url=url,
                local_dir=tempfile.gettempdir(),
                file_name=model_file['Name'],
                headers=headers,
                cookies=None if cookies is None else cookies.get_dict())
            # put the file into the cache
            cache.put_file(
                model_file,
                os.path.join(tempfile.gettempdir(), model_file['Name']))

        return cache.get_root_location()
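A minimal usage sketch (the model id is one registered elsewhere in this commit; the download requires a reachable hub):

    from modelscope.hub.snapshot_download import snapshot_download

    model_dir = snapshot_download(
        'damo/nlp_structbert_zero-shot-classification_chinese-base')
    print(model_dir)  # root folder of the cached snapshot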
0
modelscope/hub/utils/__init__.py
Normal file
40
modelscope/hub/utils/_subprocess.py
Normal file
@@ -0,0 +1,40 @@
import subprocess
from typing import List


def run_subprocess(command: List[str],
                   folder: str,
                   check=True,
                   **kwargs) -> subprocess.CompletedProcess:
    """
    Method to run subprocesses. Calling this will capture `stderr` and `stdout`;
    call `subprocess.run` manually if you would prefer they not be captured.

    Args:
        command (`List[str]`):
            The command to execute as a list of strings.
        folder (`str`):
            The folder in which to run the command.
        check (`bool`, *optional*, defaults to `True`):
            Setting `check` to `True` will raise a `subprocess.CalledProcessError`
            when the subprocess has a non-zero exit code.
        kwargs (`Dict[str]`):
            Keyword arguments to be passed to the underlying `subprocess.run` command.

    Returns:
        `subprocess.CompletedProcess`: The completed process.
    """
    if isinstance(command, str):
        raise ValueError(
            '`run_subprocess` should be called with a list of strings.')

    return subprocess.run(
        command,
        stderr=subprocess.PIPE,
        stdout=subprocess.PIPE,
        check=check,
        encoding='utf-8',
        cwd=folder,
        **kwargs,
    )
294
modelscope/hub/utils/caching.py
Normal file
@@ -0,0 +1,294 @@
import hashlib
import logging
import os
import pickle
import tempfile
import time
from shutil import move, rmtree

from modelscope.utils.logger import get_logger

logger = get_logger()


class FileSystemCache(object):
    """Local file cache."""

    KEY_FILE_NAME = '.msc'

    def __init__(
        self,
        cache_root_location: str,
        **kwargs,
    ):
        """
        Args:
            cache_root_location(`str`): The root location to store files.
        """
        os.makedirs(cache_root_location, exist_ok=True)
        self.cache_root_location = cache_root_location
        self.load_cache()

    def get_root_location(self):
        return self.cache_root_location

    def load_cache(self):
        """Read the set of cached file entries from the index file."""
        self.cached_files = []
        cache_keys_file_path = os.path.join(self.cache_root_location,
                                            FileSystemCache.KEY_FILE_NAME)
        if os.path.exists(cache_keys_file_path):
            with open(cache_keys_file_path, 'rb') as f:
                self.cached_files = pickle.load(f)

    def save_cached_files(self):
        """Save the cache metadata."""
        # save the new metadata to a temporary file, then move it to KEY_FILE_NAME
        cache_keys_file_path = os.path.join(self.cache_root_location,
                                            FileSystemCache.KEY_FILE_NAME)
        # TODO: synchronize file writes
        fd, fn = tempfile.mkstemp()
        with open(fd, 'wb') as f:
            pickle.dump(self.cached_files, f)
        move(fn, cache_keys_file_path)

    def get_file(self, key):
        """Check whether the key is in the cache; if it is, return the file,
        otherwise return None.

        Args:
            key(`str`): The cache key.
        Returns:
            The cached file location if the file exists, otherwise None.
        """
        pass

    def put_file(self, key, location):
        """Put a file into the cache.

        Args:
            key(`str`): The cache key.
            location(`str`): Location of the file; the file will be moved into the cache.
        Returns:
            The cached file path of the file.
        """
        pass

    def remove_key(self, key):
        """Remove a cache key from the index. The file itself is removed separately.

        Args:
            key (dict): The cache key.
        """
        self.cached_files.remove(key)
        self.save_cached_files()

    def exists(self, key):
        for cache_file in self.cached_files:
            if cache_file == key:
                return True

        return False

    def clear_cache(self):
        """Remove all files and metadata from the cache.

        In the case of multiple cache locations, this clears only the last one,
        which is assumed to be the read/write one.
        """
        rmtree(self.cache_root_location)
        self.load_cache()

    def hash_name(self, key):
        return hashlib.sha256(key.encode()).hexdigest()


class ModelFileSystemCache(FileSystemCache):
    """Local cache file layout:
        cache_root/owner/model_name/  -> individual cached files
                                      -> .msc: the cache index file
    Only one version is saved for each file.
    """

    def __init__(self, cache_root, owner, name):
        """Initialize the cache for one model.

        Args:
            cache_root(`str`): The modelscope local cache root (default: ~/.modelscope/cache/models/).
            owner(`str`): The model owner.
            name('str'): The name of the model.
        <Tip>
        model_id = {owner}/{name}
        </Tip>
        """
        super().__init__(os.path.join(cache_root, owner, name))

    def get_file_by_path(self, file_path):
        """Retrieve the cached file if there is one matching the path.

        Args:
            file_path (str): The file path in the model.
        Returns:
            path: the full path of the file.
        """
        for cached_file in self.cached_files:
            if file_path == cached_file['Path']:
                cached_file_path = os.path.join(self.cache_root_location,
                                                cached_file['Path'])
                if os.path.exists(cached_file_path):
                    return cached_file_path
                else:
                    self.remove_key(cached_file)

        return None

    def get_file_by_path_and_commit_id(self, file_path, commit_id):
        """Retrieve the cached file if there is one matching both path and commit id.

        Args:
            file_path (str): The file path in the model.
            commit_id (str): The commit id of the file.
        Returns:
            path: the full path of the file.
        """
        for cached_file in self.cached_files:
            if file_path == cached_file['Path'] and \
                    (cached_file['Revision'].startswith(commit_id) or commit_id.startswith(cached_file['Revision'])):
                cached_file_path = os.path.join(self.cache_root_location,
                                                cached_file['Path'])
                if os.path.exists(cached_file_path):
                    return cached_file_path
                else:
                    self.remove_key(cached_file)

        return None

    def get_file_by_info(self, model_file_info):
        """Return the cached file matching the given file information, if any.

        Args:
            model_file_info (ModelFileInfo): The file information of the file.

        Returns:
            The full path of the cached file, or None.
        """
        cache_key = self.__get_cache_key(model_file_info)
        for cached_file in self.cached_files:
            if cached_file == cache_key:
                orig_path = os.path.join(self.cache_root_location,
                                         cached_file['Path'])
                if os.path.exists(orig_path):
                    return orig_path
                else:
                    self.remove_key(cached_file)

        return None

    def __get_cache_key(self, model_file_info):
        cache_key = {
            'Path': model_file_info['Path'],
            'Revision': model_file_info['Revision'],  # commit id
        }
        return cache_key

    def exists(self, model_file_info):
        """Check whether the file is cached or not.

        Args:
            model_file_info (CachedFileInfo): The cached file info.

        Returns:
            bool: True if it exists, otherwise False.
        """
        key = self.__get_cache_key(model_file_info)
        is_exists = False
        for cached_key in self.cached_files:
            if cached_key['Path'] == key['Path'] and (
                    cached_key['Revision'].startswith(key['Revision'])
                    or key['Revision'].startswith(cached_key['Revision'])):
                is_exists = True
        file_path = os.path.join(self.cache_root_location,
                                 model_file_info['Path'])
        if is_exists:
            if os.path.exists(file_path):
                return True
            else:
                # someone may have manually deleted the file
                self.remove_key(model_file_info)
        return False

    def remove_if_exists(self, model_file_info):
        """Remove the file from the cache if it is present.

        Args:
            model_file_info (ModelFileInfo): The model file information from the server.
        """
        for cached_file in self.cached_files:
            if cached_file['Path'] == model_file_info['Path']:
                self.remove_key(cached_file)
                file_path = os.path.join(self.cache_root_location,
                                         cached_file['Path'])
                if os.path.exists(file_path):
                    os.remove(file_path)

    def put_file(self, model_file_info, model_file_location):
        """Put the model file at model_file_location into the cache; the file is
        first downloaded to /tmp and then moved into the cache.

        Args:
            model_file_info (str): The file description returned by get_model_files,
                sample:
                {
                    "CommitMessage": "add model\n",
                    "CommittedDate": 1654857567,
                    "CommitterName": "mulin.lyh",
                    "IsLFS": false,
                    "Mode": "100644",
                    "Name": "resnet18.pth",
                    "Path": "resnet18.pth",
                    "Revision": "09b68012b27de0048ba74003690a890af7aff192",
                    "Size": 46827520,
                    "Type": "blob"
                }
            model_file_location (str): The location of the temporary file.

        Returns:
            str: The location of the cached file.
        """
        self.remove_if_exists(model_file_info)  # remove any old revision first
        cache_key = self.__get_cache_key(model_file_info)
        cache_full_path = os.path.join(
            self.cache_root_location,
            cache_key['Path'])  # Branch and tag do not share the same name.
        cache_file_dir = os.path.dirname(cache_full_path)
        if not os.path.exists(cache_file_dir):
            os.makedirs(cache_file_dir, exist_ok=True)
        # We can't make this operation transactional
        move(model_file_location, cache_full_path)
        self.cached_files.append(cache_key)
        self.save_cached_files()
        return cache_full_path
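A sketch of the cache contract, reusing the sample file info from the put_file docstring (the payload bytes and the owner/name are stand-ins):

    import os
    import tempfile

    from modelscope.hub.utils.caching import ModelFileSystemCache

    cache = ModelFileSystemCache(tempfile.mkdtemp(), 'damo', 'resnet18-demo')
    file_info = {'Path': 'resnet18.pth',
                 'Revision': '09b68012b27de0048ba74003690a890af7aff192'}
    tmp = os.path.join(tempfile.gettempdir(), 'resnet18.pth')
    with open(tmp, 'wb') as f:
        f.write(b'fake weights')  # stand-in payload
    cached = cache.put_file(file_info, tmp)  # moves tmp into the cache, updates .msc
    assert cache.exists(file_info)
    assert cache.get_file_by_path('resnet18.pth') == cached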
39
modelscope/hub/utils/utils.py
Normal file
@@ -0,0 +1,39 @@
import os

from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
                                      DEFAULT_MODELSCOPE_GITLAB_DOMAIN,
                                      DEFAULT_MODELSCOPE_GROUP,
                                      MODEL_ID_SEPARATOR,
                                      MODELSCOPE_URL_SCHEME)


def model_id_to_group_owner_name(model_id):
    if MODEL_ID_SEPARATOR in model_id:
        group_or_owner = model_id.split(MODEL_ID_SEPARATOR)[0]
        name = model_id.split(MODEL_ID_SEPARATOR)[1]
    else:
        group_or_owner = DEFAULT_MODELSCOPE_GROUP
        name = model_id
    return group_or_owner, name


def get_cache_dir():
    """
    Cache directory precedence:
        function parameter > environment variable > ~/.cache/modelscope/hub
    """
    default_cache_dir = os.path.expanduser(
        os.path.join('~/.cache', 'modelscope'))
    return os.getenv('MODELSCOPE_CACHE', os.path.join(default_cache_dir,
                                                      'hub'))


def get_endpoint():
    modelscope_domain = os.getenv('MODELSCOPE_DOMAIN',
                                  DEFAULT_MODELSCOPE_DOMAIN)
    return MODELSCOPE_URL_SCHEME + modelscope_domain


def get_gitlab_domain():
    return os.getenv('MODELSCOPE_GITLAB_DOMAIN',
                     DEFAULT_MODELSCOPE_GITLAB_DOMAIN)
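A short sketch of these helpers (the outputs follow directly from the constants above):

    from modelscope.hub.utils.utils import (get_endpoint,
                                            model_id_to_group_owner_name)

    print(model_id_to_group_owner_name('damo/bert-base-sst2'))  # ('damo', 'bert-base-sst2')
    print(model_id_to_group_owner_name('bert-base-sst2'))       # falls back to the 'damo' group
    print(get_endpoint())  # 'http://101.201.119.157:32330' unless MODELSCOPE_DOMAIN is set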
@@ -4,12 +4,10 @@ import os.path as osp
 from abc import ABC, abstractmethod
 from typing import Dict, Union

-from maas_hub.snapshot_download import snapshot_download
-
+from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models.builder import build_model
 from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile
-from modelscope.utils.hub import get_model_cache_dir

 Tensor = Union['torch.Tensor', 'tf.Tensor']

@@ -47,9 +45,7 @@ class Model(ABC):
         if osp.exists(model_name_or_path):
             local_model_dir = model_name_or_path
         else:
-            cache_path = get_model_cache_dir(model_name_or_path)
-            local_model_dir = cache_path if osp.exists(
-                cache_path) else snapshot_download(model_name_or_path)
+            local_model_dir = snapshot_download(model_name_or_path)
         # else:
         #     raise ValueError(
         #         'Remote model repo {model_name_or_path} does not exists')

@@ -4,3 +4,4 @@ from .palm_for_text_generation import *  # noqa F403
 from .sbert_for_sentence_similarity import *  # noqa F403
 from .sbert_for_token_classification import *  # noqa F403
 from .sentiment_classification_model import *  # noqa F403
+from .zero_shot_classification_model import *  # noqa F403
47
modelscope/models/nlp/zero_shot_classification_model.py
Normal file
@@ -0,0 +1,47 @@
from typing import Any, Dict

import numpy as np
import torch

from modelscope.utils.constant import Tasks
from ..base import Model
from ..builder import MODELS

__all__ = ['BertForZeroShotClassification']


@MODELS.register_module(
    Tasks.zero_shot_classification,
    module_name=r'bert-zero-shot-classification')
class BertForZeroShotClassification(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the zero-shot classification model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """

        super().__init__(model_dir, *args, **kwargs)
        from sofa import SbertForSequenceClassification
        self.model = SbertForSequenceClassification.from_pretrained(model_dir)
        self.model.eval()

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """return the result given by the model

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, np.ndarray]: results
            Example:
                {
                    'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32)
                }
        """
        with torch.no_grad():
            outputs = self.model(**input)
            logits = outputs['logits'].numpy()
            res = {'logits': logits}
            return res
@@ -4,13 +4,11 @@ import os.path as osp
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Generator, List, Union

-from maas_hub.snapshot_download import snapshot_download
-
+from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models.base import Model
 from modelscope.preprocessors import Preprocessor
 from modelscope.pydatasets import PyDataset
 from modelscope.utils.config import Config
-from modelscope.utils.hub import get_model_cache_dir
 from modelscope.utils.logger import get_logger
 from .outputs import TASK_OUTPUTS
 from .util import is_model_name
@@ -32,9 +30,7 @@ class Pipeline(ABC):
         # TODO @wenmeng.zwm replace model.startswith('damo/') with get_model
         if isinstance(model, str) and model.startswith('damo/'):
             if not osp.exists(model):
-                cache_path = get_model_cache_dir(model)
-                model = cache_path if osp.exists(
-                    cache_path) else snapshot_download(model)
+                model = snapshot_download(model)
             return Model.from_pretrained(model) if is_model_name(
                 model) else model
         elif isinstance(model, Model):
@@ -80,7 +76,7 @@ class Pipeline(ABC):
         self.preprocessor = preprocessor

     def __call__(self, input: Union[Input, List[Input]], *args,
-                 **post_kwargs) -> Union[Dict[str, Any], Generator]:
+                 **kwargs) -> Union[Dict[str, Any], Generator]:
         # model provider should leave it as it is
         # modelscope library developer will handle this function

@@ -89,24 +85,41 @@ class Pipeline(ABC):
         if isinstance(input, list):
             output = []
             for ele in input:
-                output.append(self._process_single(ele, *args, **post_kwargs))
+                output.append(self._process_single(ele, *args, **kwargs))

         elif isinstance(input, PyDataset):
-            return self._process_iterator(input, *args, **post_kwargs)
+            return self._process_iterator(input, *args, **kwargs)

         else:
-            output = self._process_single(input, *args, **post_kwargs)
+            output = self._process_single(input, *args, **kwargs)
         return output

-    def _process_iterator(self, input: Input, *args, **post_kwargs):
+    def _process_iterator(self, input: Input, *args, **kwargs):
         for ele in input:
-            yield self._process_single(ele, *args, **post_kwargs)
+            yield self._process_single(ele, *args, **kwargs)

-    def _process_single(self, input: Input, *args,
-                        **post_kwargs) -> Dict[str, Any]:
-        out = self.preprocess(input)
-        out = self.forward(out)
-        out = self.postprocess(out, **post_kwargs)
+    def _sanitize_parameters(self, **pipeline_parameters):
+        """
+        this method should sanitize the keyword args to preprocessor params,
+        forward params and postprocess params on '__call__' or '_process_single' method
+        considering to be a normal classmethod with default implementation / output
+
+        Returns:
+            Dict[str, str]: preprocess_params = {}
+            Dict[str, str]: forward_params = {}
+            Dict[str, str]: postprocess_params = pipeline_parameters
+        """
+        # raise NotImplementedError("_sanitize_parameters not implemented")
+        return {}, {}, pipeline_parameters
+
+    def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
+
+        # sanitize the parameters
+        preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
+            **kwargs)
+        out = self.preprocess(input, **preprocess_params)
+        out = self.forward(out, **forward_params)
+        out = self.postprocess(out, **postprocess_params)
         self._check_output(out)
         return out

@@ -126,23 +139,25 @@ class Pipeline(ABC):
             raise ValueError(f'expected output keys are {output_keys}, '
                              f'those {missing_keys} are missing')

-    def preprocess(self, inputs: Input) -> Dict[str, Any]:
+    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
         """ Provide default implementation based on preprocess_cfg and user can reimplement it
         """
         assert self.preprocessor is not None, 'preprocess method should be implemented'
         assert not isinstance(self.preprocessor, List),\
             'default implementation does not support using multiple preprocessors.'
-        return self.preprocessor(inputs)
+        return self.preprocessor(inputs, **preprocess_params)

-    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
         """ Provide default implementation using self.model and user can reimplement it
         """
         assert self.model is not None, 'forward method should be implemented'
         assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
-        return self.model(inputs)
+        return self.model(inputs, **forward_params)

     @abstractmethod
-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def postprocess(self, inputs: Dict[str, Any],
+                    **postprocess_params) -> Dict[str, Any]:
         """ If current pipeline supports model reuse, common postprocess
         code should be written here.

@@ -27,6 +27,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
         'damo/nlp_structbert_sentiment-classification_chinese-base'),
     Tasks.text_classification: ('bert-sentiment-analysis',
                                 'damo/bert-base-sst2'),
+    Tasks.zero_shot_classification:
+    ('bert-zero-shot-classification',
+     'damo/nlp_structbert_zero-shot-classification_chinese-base'),
     Tasks.text_generation: ('palm2.0',
                             'damo/nlp_palm2.0_text-generation_chinese-base'),
     Tasks.image_captioning: ('ofa', 'damo/ofa_image-caption_coco_large_en'),

@@ -4,3 +4,4 @@ from .sentiment_classification_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
 from .word_segmentation_pipeline import *  # noqa F403
+from .zero_shot_classification_pipeline import *  # noqa F403

@@ -0,0 +1,90 @@
import os
import uuid
from typing import Any, Dict, Union

import json
import numpy as np
from scipy.special import softmax

from modelscope.models.nlp import BertForZeroShotClassification
from modelscope.preprocessors import ZeroShotClassificationPreprocessor
from modelscope.utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline
from ..builder import PIPELINES

__all__ = ['ZeroShotClassificationPipeline']


@PIPELINES.register_module(
    Tasks.zero_shot_classification,
    module_name=r'bert-zero-shot-classification')
class ZeroShotClassificationPipeline(Pipeline):

    def __init__(self,
                 model: Union[BertForZeroShotClassification, str],
                 preprocessor: ZeroShotClassificationPreprocessor = None,
                 **kwargs):
        """use `model` and `preprocessor` to create an nlp zero-shot classification pipeline for prediction

        Args:
            model (BertForZeroShotClassification): a model instance
            preprocessor (ZeroShotClassificationPreprocessor): a preprocessor instance
        """
        assert isinstance(model, str) or isinstance(model, BertForZeroShotClassification), \
            'model must be a single str or BertForZeroShotClassification'
        sc_model = model if isinstance(
            model,
            BertForZeroShotClassification) else Model.from_pretrained(model)

        self.entailment_id = 0
        self.contradiction_id = 2

        if preprocessor is None:
            preprocessor = ZeroShotClassificationPreprocessor(
                sc_model.model_dir)
        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        preprocess_params = {}
        postprocess_params = {}

        if 'candidate_labels' in kwargs:
            candidate_labels = kwargs.pop('candidate_labels')
            preprocess_params['candidate_labels'] = candidate_labels
            postprocess_params['candidate_labels'] = candidate_labels
        else:
            raise ValueError('You must include at least one label.')
        preprocess_params['hypothesis_template'] = kwargs.pop(
            'hypothesis_template', '{}')

        postprocess_params['multi_label'] = kwargs.pop('multi_label', False)
        return preprocess_params, {}, postprocess_params

    def postprocess(self,
                    inputs: Dict[str, Any],
                    candidate_labels,
                    multi_label=False) -> Dict[str, Any]:
        """process the prediction results

        Args:
            inputs (Dict[str, Any]): the model outputs, containing the NLI logits

        Returns:
            Dict[str, Any]: the prediction results
        """
        logits = inputs['logits']
        if multi_label or len(candidate_labels) == 1:
            logits = logits[..., [self.contradiction_id, self.entailment_id]]
            scores = softmax(logits, axis=-1)[..., 1]
        else:
            logits = logits[..., self.entailment_id]
            scores = softmax(logits, axis=-1)

        reversed_index = list(reversed(scores.argsort()))
        result = {
            'labels': [candidate_labels[i] for i in reversed_index],
            'scores': [scores[i].item() for i in reversed_index],
        }
        return result
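For context, the postprocess above scores each label by treating (sentence, hypothesis) as an NLI pair: single-label mode softmaxes the entailment logits across all candidate labels, while multi-label mode softmaxes each label's (contradiction, entailment) pair independently. A minimal usage sketch, assuming the model id registered above is reachable in your environment:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

classifier = pipeline(
    task=Tasks.zero_shot_classification,
    model='damo/nlp_structbert_zero-shot-classification_chinese-base')

# candidate_labels is required; hypothesis_template defaults to '{}'.
result = classifier(
    '全新突破 解放军运20版空中加油机曝光',
    candidate_labels=['体育', '军事', '科技'],
    hypothesis_template='这篇文章的标题是{}')
print(result['labels'])  # labels sorted by descending entailment score
print(result['scores'])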
@@ -2,8 +2,7 @@
import os.path as osp
from typing import List, Union

from maas_hub.file_download import model_file_download

from modelscope.hub.file_download import model_file_download
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
@@ -4,11 +4,10 @@ from typing import Any, Dict, Union

import numpy as np
import torch
from maas_hub.snapshot_download import snapshot_download
from PIL import Image

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import Fields, ModelFile
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS
@@ -34,9 +33,7 @@ class OfaImageCaptionPreprocessor(Preprocessor):
        if osp.exists(model_dir):
            local_model_dir = model_dir
        else:
            cache_path = get_model_cache_dir(model_dir)
            local_model_dir = cache_path if osp.exists(
                cache_path) else snapshot_download(model_dir)
            local_model_dir = snapshot_download(model_dir)
        local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE)
        bpe_dir = local_model_dir
@@ -12,8 +12,9 @@ from .builder import PREPROCESSORS

__all__ = [
    'Tokenize', 'SequenceClassificationPreprocessor',
    'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor',
    'NLIPreprocessor', 'SentimentClassificationPreprocessor'
    'TextGenerationPreprocessor', 'ZeroShotClassificationPreprocessor',
    'TokenClassifcationPreprocessor', 'NLIPreprocessor',
    'SentimentClassificationPreprocessor'
]
@@ -314,6 +315,50 @@ class TextGenerationPreprocessor(Preprocessor):
        return {k: torch.tensor(v) for k, v in rst.items()}


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=r'bert-zero-shot-classification')
class ZeroShotClassificationPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """preprocess the data via the vocab.txt from the `model_dir` path

        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)

        from sofa import SbertTokenizer
        self.model_dir: str = model_dir
        self.sequence_length = kwargs.pop('sequence_length', 512)
        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)

    @type_assert(object, str)
    def __call__(self, data: str, hypothesis_template: str,
                 candidate_labels: list) -> Dict[str, Any]:
        """process the raw input data

        Args:
            data (str): a sentence
                Example:
                    'you are so handsome.'
            hypothesis_template (str): a template that turns each label into
                an NLI hypothesis
            candidate_labels (list): the candidate labels to score

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        pairs = [[data, hypothesis_template.format(label)]
                 for label in candidate_labels]

        features = self.tokenizer(
            pairs,
            padding=True,
            truncation=True,
            max_length=self.sequence_length,
            return_tensors='pt',
            truncation_strategy='only_first')
        return features
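The preprocessor recasts zero-shot classification as NLI: every candidate label is substituted into the hypothesis template and paired with the input sentence, so the tokenizer sees one (premise, hypothesis) pair per label. A pure-Python illustration of the pair construction:

data = '全新突破 解放军运20版空中加油机曝光'
hypothesis_template = '这篇文章的标题是{}'
candidate_labels = ['体育', '军事']

pairs = [[data, hypothesis_template.format(label)]
         for label in candidate_labels]
# [['全新突破 解放军运20版空中加油机曝光', '这篇文章的标题是体育'],
#  ['全新突破 解放军运20版空中加油机曝光', '这篇文章的标题是军事']]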


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=r'bert-token-classification')
class TokenClassifcationPreprocessor(Preprocessor):
@@ -343,6 +388,7 @@ class TokenClassifcationPreprocessor(Preprocessor):
        Returns:
            Dict[str, Any]: the preprocessed data
        """
        # preprocess the data for the model input

        text = data.replace(' ', '').strip()
@@ -31,6 +31,7 @@ class Tasks(object):
    ocr_detection = 'ocr-detection'

    # nlp tasks
    zero_shot_classification = 'zero-shot-classification'
    word_segmentation = 'word-segmentation'
    nli = 'nli'
    sentiment_classification = 'sentiment-classification'
@@ -2,13 +2,10 @@

import os

from maas_hub.constants import MODEL_ID_SEPARATOR
from modelscope.hub.constants import MODEL_ID_SEPARATOR
from modelscope.hub.utils.utils import get_cache_dir


# temp solution before the hub-cache is in place
def get_model_cache_dir(model_id: str, branch: str = 'master'):
    model_id_expanded = model_id.replace('/',
                                         MODEL_ID_SEPARATOR) + '.' + branch
    default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
    return os.getenv('MAAS_CACHE',
                     os.path.join(default_cache_dir, 'hub', model_id_expanded))
def get_model_cache_dir(model_id: str):
    return os.path.join(get_cache_dir(), model_id)
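The rewritten helper drops the flattened '<owner><SEPARATOR><name>.<branch>' layout in favor of nesting the untouched model id under the shared cache root from get_cache_dir(). A sketch of the resulting path (the cache root below is a stand-in for illustration; the real value comes from get_cache_dir()):

import os

def get_model_cache_dir_sketch(model_id: str,
                               cache_root: str = '~/.cache/modelscope') -> str:
    # hypothetical stand-in for get_cache_dir() followed by os.path.join
    return os.path.join(os.path.expanduser(cache_root), model_id)

print(get_model_cache_dir_sketch('damo/bert-base-sst2'))
# e.g. /home/user/.cache/modelscope/damo/bert-base-sst2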
@@ -1,13 +1,16 @@
addict
datasets
easydict
https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.4.dev0-py3-none-any.whl
filelock>=3.3.0
numpy
opencv-python-headless
Pillow>=6.2.0
pyyaml
requests
requests==2.27.1
scipy
setuptools==58.0.4
tokenizers<=0.10.3
tqdm>=4.64.0
transformers<=4.16.2
yapf
0
tests/hub/__init__.py
Normal file
157
tests/hub/test_hub_operation.py
Normal file
@@ -0,0 +1,157 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import subprocess
import tempfile
import unittest
import uuid

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.hub.utils.utils import get_gitlab_domain

USER_NAME = 'maasadmin'
PASSWORD = '12345678'

model_chinese_name = '达摩卡通化模型'
model_org = 'unittest'
DEFAULT_GIT_PATH = 'git'


class GitError(Exception):
    pass


# TODO: move these git operations to a git library once the code is merged.
def run_git_command(git_path, *args) -> subprocess.CompletedProcess:
    response = subprocess.run([git_path, *args], capture_output=True)
    try:
        response.check_returncode()
        return response.stdout.decode('utf8')
    except subprocess.CalledProcessError as error:
        raise GitError(error.stderr.decode('utf8'))


# For a public project the token can be None; for a private repo a token is required.
def clone(local_dir: str, token: str, url: str):
    url = url.replace('//', '//oauth2:%s@' % token)
    clone_args = '-C %s clone %s' % (local_dir, url)
    clone_args = clone_args.split(' ')
    stdout = run_git_command(DEFAULT_GIT_PATH, *clone_args)
    print('stdout: %s' % stdout)


def push(local_dir: str, token: str, url: str):
    url = url.replace('//', '//oauth2:%s@' % token)
    push_args = '-C %s push %s' % (local_dir, url)
    push_args = push_args.split(' ')
    stdout = run_git_command(DEFAULT_GIT_PATH, *push_args)
    print('stdout: %s' % stdout)
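Both helpers authenticate by splicing an oauth2 token into the repo URL rather than configuring a git credential helper. The rewrite is a plain string substitution (domain and token below are fabricated for illustration):

token = 'abc123'
url = 'http://git.example.com/unittest/some-model'
print(url.replace('//', '//oauth2:%s@' % token))
# http://oauth2:abc123@git.example.com/unittest/some-model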
sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx'
download_model_file_name = 'mnist-12.onnx'


class HubOperationTest(unittest.TestCase):

    def setUp(self):
        self.old_cwd = os.getcwd()
        self.api = HubApi()
        # note this is temporary before official account management is ready
        self.api.login(USER_NAME, PASSWORD)
        self.model_name = uuid.uuid4().hex
        self.model_id = '%s/%s' % (model_org, self.model_name)
        self.api.create_model(
            model_id=self.model_id,
            chinese_name=model_chinese_name,
            visibility=5,  # 1-private, 5-public
            license='apache-2.0')

    def tearDown(self):
        os.chdir(self.old_cwd)
        self.api.delete_model(model_id=self.model_id)

    def test_model_repo_creation(self):
        # change to proper model names before use
        try:
            info = self.api.get_model(model_id=self.model_id)
            assert info['Name'] == self.model_name
        except KeyError as ke:
            if ke.args[0] == 'name':
                print(f'model {self.model_name} already exists, ignore')
            else:
                raise

    # Note that this can be done via git operation once model repo
    # has been created. Git-Op is the RECOMMENDED model upload approach
    def test_model_upload(self):
        url = f'http://{get_gitlab_domain()}/{self.model_id}'
        print(url)
        temporary_dir = tempfile.mkdtemp()
        os.chdir(temporary_dir)
        cmd_args = 'clone %s' % url
        cmd_args = cmd_args.split(' ')
        out = run_git_command('git', *cmd_args)
        print(out)
        repo_dir = os.path.join(temporary_dir, self.model_name)
        os.chdir(repo_dir)
        os.system('touch file1')
        os.system('git add file1')
        os.system("git commit -m 'Test'")
        token = ModelScopeConfig.get_token()
        push(repo_dir, token, url)

    def test_download_single_file(self):
        url = f'http://{get_gitlab_domain()}/{self.model_id}'
        print(url)
        temporary_dir = tempfile.mkdtemp()
        os.chdir(temporary_dir)
        os.system('git clone %s' % url)
        repo_dir = os.path.join(temporary_dir, self.model_name)
        os.chdir(repo_dir)
        os.system('wget %s' % sample_model_url)
        os.system('git add .')
        os.system("git commit -m 'Add file'")
        token = ModelScopeConfig.get_token()
        push(repo_dir, token, url)
        assert os.path.exists(
            os.path.join(temporary_dir, self.model_name,
                         download_model_file_name))
        downloaded_file = model_file_download(
            model_id=self.model_id, file_path=download_model_file_name)
        mdtime1 = os.path.getmtime(downloaded_file)
        # download again
        downloaded_file = model_file_download(
            model_id=self.model_id, file_path=download_model_file_name)
        mdtime2 = os.path.getmtime(downloaded_file)
        assert mdtime1 == mdtime2

    def test_snapshot_download(self):
        url = f'http://{get_gitlab_domain()}/{self.model_id}'
        print(url)
        temporary_dir = tempfile.mkdtemp()
        os.chdir(temporary_dir)
        os.system('git clone %s' % url)
        repo_dir = os.path.join(temporary_dir, self.model_name)
        os.chdir(repo_dir)
        os.system('wget %s' % sample_model_url)
        os.system('git add .')
        os.system("git commit -m 'Add file'")
        token = ModelScopeConfig.get_token()
        push(repo_dir, token, url)
        snapshot_path = snapshot_download(model_id=self.model_id)
        downloaded_file_path = os.path.join(snapshot_path,
                                            download_model_file_name)
        assert os.path.exists(downloaded_file_path)
        mdtime1 = os.path.getmtime(downloaded_file_path)
        # download again
        snapshot_path = snapshot_download(model_id=self.model_id)
        mdtime2 = os.path.getmtime(downloaded_file_path)
        assert mdtime1 == mdtime2


if __name__ == '__main__':
    unittest.main()
@@ -10,7 +10,6 @@ from modelscope.fileio import File
from modelscope.pipelines import pipeline
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


@@ -18,11 +17,6 @@ class ImageMattingTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_unet_image-matting'
        # switch to False if downloading every time is not desired
        purge_cache = True
        if purge_cache:
            shutil.rmtree(
                get_model_cache_dir(self.model_id), ignore_errors=True)

    @unittest.skip('deprecated, download model from model hub instead')
    def test_run_with_direct_file_download(self):
@@ -27,7 +27,7 @@ class OCRDetectionTest(unittest.TestCase):
        print('ocr detection results: ')
        print(result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        ocr_detection = pipeline(Tasks.ocr_detection)
        self.pipeline_inference(ocr_detection, self.test_image)
@@ -2,14 +2,12 @@
import shutil
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForSentenceSimilarity
from modelscope.pipelines import SentenceSimilarityPipeline, pipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


@@ -18,13 +16,6 @@ class SentenceSimilarityTest(unittest.TestCase):
    sentence1 = '今天气温比昨天高么?'
    sentence2 = '今天湿度比昨天高么?'

    def setUp(self) -> None:
        # switch to False if downloading every time is not desired
        purge_cache = True
        if purge_cache:
            shutil.rmtree(
                get_model_cache_dir(self.model_id), ignore_errors=True)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run(self):
        cache_path = snapshot_download(self.model_id)
@@ -5,7 +5,6 @@ import unittest
from modelscope.fileio import File
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir

NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav'
FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav'
@@ -30,11 +29,6 @@ class SpeechSignalProcessTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/speech_dfsmn_aec_psm_16k'
        # switch to False if downloading every time is not desired
        purge_cache = True
        if purge_cache:
            shutil.rmtree(
                get_model_cache_dir(self.model_id), ignore_errors=True)
        # A temporary hack to provide the c++ lib. Download it first.
        download(AEC_LIB_URL, AEC_LIB_FILE)
@@ -11,7 +11,6 @@ from modelscope.pipelines import SequenceClassificationPipeline, pipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import Hubs, Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


@@ -19,11 +18,6 @@ class SequenceClassificationTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/bert-base-sst2'
        # switch to False if downloading every time is not desired
        purge_cache = True
        if purge_cache:
            shutil.rmtree(
                get_model_cache_dir(self.model_id), ignore_errors=True)

    def predict(self, pipeline_ins: SequenceClassificationPipeline):
        from easynlp.appzoo import load_dataset
@@ -1,8 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import PalmForTextGeneration
from modelscope.pipelines import TextGenerationPipeline, pipeline
@@ -2,14 +2,12 @@
import shutil
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.pipelines import WordSegmentationPipeline, pipeline
from modelscope.preprocessors import TokenClassifcationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


@@ -17,13 +15,6 @@ class WordSegmentationTest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_word-segmentation_chinese-base'
    sentence = '今天天气不错,适合出去游玩'

    def setUp(self) -> None:
        # switch to False if downloading every time is not desired
        purge_cache = True
        if purge_cache:
            shutil.rmtree(
                get_model_cache_dir(self.model_id), ignore_errors=True)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        cache_path = snapshot_download(self.model_id)
60
tests/pipelines/test_zero_shot_classification.py
Normal file
@@ -0,0 +1,60 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.models import Model
from modelscope.models.nlp import BertForZeroShotClassification
from modelscope.pipelines import ZeroShotClassificationPipeline, pipeline
from modelscope.preprocessors import ZeroShotClassificationPreprocessor
from modelscope.utils.constant import Tasks


class ZeroShotClassificationTest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
    sentence = '全新突破 解放军运20版空中加油机曝光'
    labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
    template = '这篇文章的标题是{}'

    def test_run_from_local(self):
        cache_path = snapshot_download(self.model_id)
        tokenizer = ZeroShotClassificationPreprocessor(cache_path)
        model = BertForZeroShotClassification(cache_path, tokenizer=tokenizer)
        pipeline1 = ZeroShotClassificationPipeline(
            model, preprocessor=tokenizer)
        pipeline2 = pipeline(
            Tasks.zero_shot_classification,
            model=model,
            preprocessor=tokenizer)

        print(
            f'sentence: {self.sentence}\n'
            f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}'
        )
        print()
        print(
            f'sentence: {self.sentence}\n'
            f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}'
        )

    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        tokenizer = ZeroShotClassificationPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.zero_shot_classification,
            model=model,
            preprocessor=tokenizer)
        print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))

    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.zero_shot_classification, model=self.model_id)
        print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))

    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.zero_shot_classification)
        print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))


if __name__ == '__main__':
    unittest.main()
@@ -61,7 +61,7 @@ if __name__ == '__main__':
    parser.add_argument(
        '--test_dir', default='tests', help='directory to be tested')
    parser.add_argument(
        '--level', default=0, help='2 -- all, 1 -- p1, 0 -- p0')
        '--level', default=0, type=int, help='2 -- all, 1 -- p1, 0 -- p0')
    args = parser.parse_args()
    set_test_level(args.level)
    logger.info(f'TEST LEVEL: {test_level()}')
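The added type=int is a real fix, not cosmetics: without it argparse hands back the level as a string, and the integer comparisons in the skipUnless decorators above (test_level() >= 1, etc.) would break under Python 3. A short demonstration:

level = '2'               # what argparse yields without type=int
# level >= 1              # TypeError: '>=' not supported between 'str' and 'int'
print(int(level) >= 1)    # True; type=int performs this conversion up front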
@@ -1,50 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

from maas_hub.maas_api import MaasApi
from maas_hub.repository import Repository

USER_NAME = 'maasadmin'
PASSWORD = '12345678'


class HubOperationTest(unittest.TestCase):

    def setUp(self):
        self.api = MaasApi()
        # note this is temporary before official account management is ready
        self.api.login(USER_NAME, PASSWORD)

    @unittest.skip('to be used for local test only')
    def test_model_repo_creation(self):
        # change to proper model names before use
        model_name = 'cv_unet_person-image-cartoon_compound-models'
        model_chinese_name = '达摩卡通化模型'
        model_org = 'damo'
        try:
            self.api.create_model(
                owner=model_org,
                name=model_name,
                chinese_name=model_chinese_name,
                visibility=5,  # 1-private, 5-public
                license='apache-2.0')
        # TODO: support proper name duplication checking
        except KeyError as ke:
            if ke.args[0] == 'name':
                print(f'model {self.model_name} already exists, ignore')
            else:
                raise

    # Note that this can be done via git operation once model repo
    # has been created. Git-Op is the RECOMMENDED model upload approach
    @unittest.skip('to be used for local test only')
    def test_model_upload(self):
        local_path = '/path/to/local/model/directory'
        assert osp.exists(local_path), 'Local model directory not exist.'
        repo = Repository(local_dir=local_path)
        repo.push_to_hub(commit_message='Upload model files')


if __name__ == '__main__':
    unittest.main()