Merge commit '58c90dee2c8c2b065baa75923ea6f1cb3964f9cb' into feat/whoami

* commit '58c90dee2c8c2b065baa75923ea6f1cb3964f9cb': (23 commits)
  fix model directory (#1197)
  Fix typo and unused imports (#1196)
  Fix file exists of sub folder (#1192)
  Fix cookies due to firewall token expired (#1193)
  fix lint (#1188)
  Merge release 1.22 (#1187)
  add file name (#1186)
  Support upload file and folder in the hub api (#1152)
  update doc with llama_index (#1180)
  Fix/text gen (#1177)
  Add repo_id and repo_type in snapshot_download (#1172)
  Unify datasets cache dir (#1178)
  feat: all other ollama models (#1174)
  logger.warning when using remote code (#1171)
  fix path concatenation to be windows compatible (#1176)
  fix https://www.modelscope.cn/models/iic/nlp_structbert_address-parsing_chinese_base/feedback/issueDetail/20431 (#1170)
  support ms-swift 3.0.0 (#1166)
  support latest datasets version (#1163)
  fix lint (#1168)
  fix check model (#1134)
  ...
yuze.zyz committed 2025-01-20 20:33:15 +08:00
39 changed files with 2524 additions and 264 deletions

View File

@@ -32,6 +32,7 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
pip install faiss-gpu
pip install healpy
pip install huggingface-hub==0.25.2
pip install ms-swift>=3.0.1
# test with install
pip install .
else

View File

@@ -36,6 +36,8 @@ class Builder:
args.lmdeploy_version = '0.6.2'
if not args.autogptq_version:
args.autogptq_version = '0.7.1'
if not args.flashattn_version:
args.flashattn_version = '2.7.1.post4'
return args
def _generate_cudatoolkit_version(self, cuda_version: str) -> str:
@@ -209,8 +211,8 @@ RUN pip install tf-keras==2.16.0 --no-dependencies && \
version_args = (
f'{self.args.torch_version} {self.args.torchvision_version} {self.args.torchaudio_version} '
f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version}'
)
f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version} '
f'{self.args.flashattn_version}')
base_image = (
f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-{self.args.python_tag}-'
f'torch{self.args.torch_version}-tf{self.args.tf_version}-base')
@@ -274,6 +276,8 @@ class LLMImageBuilder(Builder):
args.lmdeploy_version = '0.6.2'
if not args.autogptq_version:
args.autogptq_version = '0.7.1'
if not args.flashattn_version:
args.flashattn_version = '2.7.1.post4'
return args
def generate_dockerfile(self) -> str:
@@ -284,8 +288,8 @@ class LLMImageBuilder(Builder):
self.args.python_version)
version_args = (
f'{self.args.torch_version} {self.args.torchvision_version} {self.args.torchaudio_version} '
f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version}'
)
f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version} '
f'{self.args.flashattn_version}')
with open('docker/Dockerfile.ubuntu', 'r') as f:
content = f.read()
content = content.replace('{base_image}', self.args.base_image)
@@ -341,12 +345,12 @@ parser.add_argument('--torchaudio_version', type=str, default=None)
parser.add_argument('--tf_version', type=str, default=None)
parser.add_argument('--vllm_version', type=str, default=None)
parser.add_argument('--lmdeploy_version', type=str, default=None)
parser.add_argument('--flashattn_version', type=str, default=None)
parser.add_argument('--autogptq_version', type=str, default=None)
parser.add_argument('--modelscope_branch', type=str, default='master')
parser.add_argument('--modelscope_version', type=str, default='9.99.0')
parser.add_argument('--swift_branch', type=str, default='main')
parser.add_argument('--dry_run', type=int, default=0)
args = parser.parse_args()
if args.image_type.lower() == 'base_cpu':

View File

@@ -6,6 +6,7 @@ torchaudio_version=${3:-2.4.0}
vllm_version=${4:-0.6.0}
lmdeploy_version=${5:-0.6.1}
autogptq_version=${6:-0.7.1}
flashattn_version=${7:-2.7.1.post4}
pip install --no-cache-dir -U autoawq lmdeploy==$lmdeploy_version
@@ -17,7 +18,8 @@ pip install --no-cache-dir tiktoken transformers_stream_generator bitsandbytes d
# pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
# find on: https://github.com/Dao-AILab/flash-attention/releases
cd /tmp && git clone https://github.com/Dao-AILab/flash-attention.git && cd flash-attention && python setup.py install && cd / && rm -fr /tmp/flash-attention && pip cache purge;
# cd /tmp && git clone https://github.com/Dao-AILab/flash-attention.git && cd flash-attention && python setup.py install && cd / && rm -fr /tmp/flash-attention && pip cache purge;
pip install --no-cache-dir flash_attn==$flashattn_version
pip install --no-cache-dir triton auto-gptq==$autogptq_version vllm==$vllm_version -U && pip cache purge

View File

@@ -24,7 +24,7 @@ options:
Get access token: obtain your **SDK token** from [My Page](https://modelscope.cn/my/myaccesstoken)
## download model
## download
```bash
modelscope download --help
@@ -36,6 +36,7 @@ modelscope download --help
options:
-h, --help show this help message and exit
--model MODEL The model id to be downloaded.
--dataset DATASET The dataset id to be downloaded.
--revision REVISION Revision of the model.
--cache_dir CACHE_DIR
Cache directory to save model.

File diff suppressed because one or more lines are too long

View File

@@ -2,6 +2,10 @@
"cells": [
{
"cell_type": "markdown",
"id": "f4abc589d9bfffca",
"metadata": {
"collapsed": false
},
"source": [
"# Usage\n",
"\n",
@@ -34,27 +38,29 @@
"```\n",
"\n",
"## 3. Go!"
],
"metadata": {
"collapsed": false
},
"id": "f4abc589d9bfffca"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c32122833dd7b8c8",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"!pip install modelscope\n",
"!pip install transformers -U\n",
"!pip install llama-index llama-index-llms-huggingface ipywidgets "
],
"metadata": {
"collapsed": false
},
"id": "c32122833dd7b8c8"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63704e2b21a9ba52",
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"!wget https://modelscope.oss-cn-beijing.aliyuncs.com/resource/rag/punkt.zip\n",
@@ -74,15 +80,90 @@
"!mv /mnt/workspace/xianjiaoda.md /mnt/workspace/custom_data\n",
"\n",
"!cd /mnt/workspace"
],
"metadata": {
"collapsed": false
},
"id": "63704e2b21a9ba52"
]
},
{
"cell_type": "code",
"outputs": [],
"execution_count": 2,
"id": "eef67659e94045c5",
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading Model to directory: /mnt/workspace/.cache/modelscope/qwen/Qwen1.5-4B-Chat\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-01-13 15:52:53,260 - modelscope - INFO - Model revision not specified, using default: [master] version.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-01-13 15:52:53,637 - modelscope - INFO - Creating symbolic link [/mnt/workspace/.cache/modelscope/qwen/Qwen1.5-4B-Chat].\n",
"2025-01-13 15:52:53,638 - modelscope - WARNING - Failed to create symbolic link /mnt/workspace/.cache/modelscope/qwen/Qwen1.5-4B-Chat for /mnt/workspace/.cache/modelscope/qwen/Qwen1___5-4B-Chat.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4523c5dd31ba411d95cc0ce9e5da8ded",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"llm created\n",
"Downloading Model to directory: /mnt/workspace/.cache/modelscope/damo/nlp_gte_sentence-embedding_chinese-base\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-01-13 15:53:01,651 - modelscope - INFO - Model revision not specified, using default: [master] version.\n",
"2025-01-13 15:53:01,894 - modelscope - INFO - initiate model from /mnt/workspace/.cache/modelscope/hub/damo/nlp_gte_sentence-embedding_chinese-base\n",
"2025-01-13 15:53:01,895 - modelscope - INFO - initiate model from location /mnt/workspace/.cache/modelscope/hub/damo/nlp_gte_sentence-embedding_chinese-base.\n",
"2025-01-13 15:53:01,898 - modelscope - INFO - initialize model from /mnt/workspace/.cache/modelscope/hub/damo/nlp_gte_sentence-embedding_chinese-base\n",
"2025-01-13 15:53:02,532 - modelscope - WARNING - No preprocessor field found in cfg.\n",
"2025-01-13 15:53:02,533 - modelscope - WARNING - No val key and type key found in preprocessor domain of configuration.json file.\n",
"2025-01-13 15:53:02,533 - modelscope - WARNING - Cannot find available config to build preprocessor at mode inference, current config: {'model_dir': '/mnt/workspace/.cache/modelscope/hub/damo/nlp_gte_sentence-embedding_chinese-base'}. trying to build by task and model information.\n",
"2025-01-13 15:53:02,588 - modelscope - WARNING - No preprocessor field found in cfg.\n",
"2025-01-13 15:53:02,588 - modelscope - WARNING - No val key and type key found in preprocessor domain of configuration.json file.\n",
"2025-01-13 15:53:02,589 - modelscope - WARNING - Cannot find available config to build preprocessor at mode inference, current config: {'model_dir': '/mnt/workspace/.cache/modelscope/hub/damo/nlp_gte_sentence-embedding_chinese-base', 'sequence_length': 128}. trying to build by task and model information.\n",
"/root/miniconda3/envs/modelscope_1.21/lib/python3.9/site-packages/transformers/modeling_utils.py:1044: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n",
" warnings.warn(\n",
"/root/miniconda3/envs/modelscope_1.21/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:628: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
" warnings.warn(\n",
"/root/miniconda3/envs/modelscope_1.21/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:633: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.8` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"作为一所历史悠久的综合性研究型大学,西安交通大学有着丰富的校训文化。其中,\"厚德博学,求是创新\"是其最为人所知的校训之一。这句校训不仅体现了学校的教育理念,也反映了学校对学生的期望和要求。此外,西安交通大学还有一句著名的校训:\"明德尚志,自强不息\",这也是学校对学生的一种激励和引导。这两句校训都强调了教育的重要性,以及学生应该具备的道德品质和自我提升的精神。\n"
]
}
],
"source": [
"import logging\n",
"import sys\n",
@@ -93,9 +174,7 @@
"from llama_index.core import (\n",
" SimpleDirectoryReader,\n",
" VectorStoreIndex,\n",
" Settings,\n",
" ServiceContext,\n",
" set_global_service_context,\n",
" Settings\n",
")\n",
"from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding\n",
"from llama_index.core.prompts import PromptTemplate\n",
@@ -176,9 +255,8 @@
"\n",
"embedding_model = \"damo/nlp_gte_sentence-embedding_chinese-base\"\n",
"embeddings = ModelScopeEmbeddings4LlamaIndex(model_id=embedding_model)\n",
"service_context = ServiceContext.from_defaults(embed_model=embeddings, llm=llm)\n",
"set_global_service_context(service_context)\n",
"Settings.embed_model = embeddings\n",
"Settings.llm = llm\n",
"\n",
"# load example documents\n",
"documents = SimpleDirectoryReader(\"/mnt/workspace/custom_data/\").load_data()\n",
@@ -192,11 +270,7 @@
"# do query\n",
"response = query_engine.query(\"西安较大的校训是什么\")\n",
"print(response)\n"
],
"metadata": {
"collapsed": false
},
"id": "eef67659e94045c5"
]
}
],
"metadata": {

View File

@@ -11,6 +11,7 @@ from modelscope.cli.modelcard import ModelCardCMD
from modelscope.cli.pipeline import PipelineCMD
from modelscope.cli.plugins import PluginsCMD
from modelscope.cli.server import ServerCMD
from modelscope.cli.upload import UploadCMD
from modelscope.hub.api import HubApi
from modelscope.utils.logger import get_logger
@@ -25,6 +26,7 @@ def run_cmd():
subparsers = parser.add_subparsers(help='modelscope commands helpers')
DownloadCMD.define_args(subparsers)
UploadCMD.define_args(subparsers)
ClearCacheCMD.define_args(subparsers)
PluginsCMD.define_args(subparsers)
PipelineCMD.define_args(subparsers)

View File

@@ -1,12 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from argparse import ArgumentParser
from modelscope.cli.base import CLICommand
from modelscope.hub.constants import DEFAULT_MAX_WORKERS
from modelscope.hub.file_download import (dataset_file_download,
model_file_download)
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
snapshot_download)
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
def subparser_func(args):
@@ -88,6 +90,12 @@ class DownloadCMD(CLICommand):
default=None,
help='Glob patterns to exclude from files to download.'
'Ignored if file is specified')
parser.add_argument(
'--max-workers',
type=int,
default=DEFAULT_MAX_WORKERS,
help='The maximum number of workers to download files.')
parser.set_defaults(func=subparser_func)
def execute(self):
@@ -125,6 +133,7 @@ class DownloadCMD(CLICommand):
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.files,
max_workers=self.args.max_workers,
)
else: # download repo
snapshot_download(
@@ -134,32 +143,36 @@ class DownloadCMD(CLICommand):
local_dir=self.args.local_dir,
allow_file_pattern=self.args.include,
ignore_file_pattern=self.args.exclude,
max_workers=self.args.max_workers,
)
elif self.args.dataset:
dataset_revision: str = self.args.revision if self.args.revision else DEFAULT_DATASET_REVISION
if len(self.args.files) == 1: # download single file
dataset_file_download(
self.args.dataset,
self.args.files[0],
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
revision=self.args.revision)
revision=dataset_revision)
elif len(
self.args.files) > 1: # download specified multiple files.
dataset_snapshot_download(
self.args.dataset,
revision=self.args.revision,
revision=dataset_revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.files,
max_workers=self.args.max_workers,
)
else: # download repo
dataset_snapshot_download(
self.args.dataset,
revision=self.args.revision,
revision=dataset_revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.include,
ignore_file_pattern=self.args.exclude,
max_workers=self.args.max_workers,
)
else:
pass # noop
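With the additions above, the download command now handles dataset repos and a `--max-workers` option. A minimal sketch of the equivalent Python call, using a hypothetical dataset id:

from modelscope.hub.constants import DEFAULT_MAX_WORKERS
from modelscope.hub.snapshot_download import dataset_snapshot_download

local_dir = dataset_snapshot_download(
    'namespace/my_dataset',            # hypothetical dataset id (maps to --dataset)
    revision='master',                 # DEFAULT_DATASET_REVISION is used when --revision is omitted
    allow_file_pattern=['*.json'],     # maps to --include
    ignore_file_pattern=['*.bin'],     # maps to --exclude
    max_workers=DEFAULT_MAX_WORKERS,   # maps to the new --max-workers option
)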

modelscope/cli/upload.py (new file, 179 lines)
View File

@@ -0,0 +1,179 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from argparse import ArgumentParser, _SubParsersAction
from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
from modelscope.utils.logger import get_logger
logger = get_logger()
def subparser_func(args):
""" Function which will be called for a specific sub parser.
"""
return UploadCMD(args)
class UploadCMD(CLICommand):
name = 'upload'
def __init__(self, args: _SubParsersAction):
self.args = args
@staticmethod
def define_args(parsers: _SubParsersAction):
parser: ArgumentParser = parsers.add_parser(UploadCMD.name)
parser.add_argument(
'repo_id',
type=str,
help='The ID of the repo to upload to (e.g. `username/repo-name`)')
parser.add_argument(
'local_path',
type=str,
nargs='?',
default=None,
help='Optional, '
'Local path to the file or folder to upload. Defaults to current directory.'
)
parser.add_argument(
'path_in_repo',
type=str,
nargs='?',
default=None,
help='Optional, '
'Path of the file or folder in the repo. Defaults to the relative path of the file or folder.'
)
parser.add_argument(
'--repo-type',
choices=REPO_TYPE_SUPPORT,
default=REPO_TYPE_MODEL,
help=
'Type of the repo to upload to (e.g. `dataset`, `model`). Defaults to be `model`.',
)
parser.add_argument(
'--include',
nargs='*',
type=str,
help='Glob patterns to match files to upload.')
parser.add_argument(
'--exclude',
nargs='*',
type=str,
help='Glob patterns to exclude from files to upload.')
parser.add_argument(
'--commit-message',
type=str,
default=None,
help='The message of commit. Default to be `None`.')
parser.add_argument(
'--commit-description',
type=str,
default=None,
help=
'The description of the generated commit. Default to be `None`.')
parser.add_argument(
'--token',
type=str,
default=None,
help=
'A User Access Token generated from https://modelscope.cn/my/myaccesstoken'
)
parser.add_argument(
'--max-workers',
type=int,
default=min(8,
os.cpu_count() + 4),
help='The number of workers to use for uploading files.')
parser.add_argument(
'--endpoint',
type=str,
default='https://www.modelscope.cn',
help='Endpoint for Modelscope service.')
parser.set_defaults(func=subparser_func)
def execute(self):
assert self.args.repo_id, '`repo_id` is required'
assert self.args.repo_id.count(
'/') == 1, 'repo_id should be in format of username/repo-name'
repo_name: str = self.args.repo_id.split('/')[-1]
self.repo_id = self.args.repo_id
# Check path_in_repo
if self.args.local_path is None and os.path.isfile(repo_name):
# Case 1: modelscope upload owner_name/test_repo
self.local_path = repo_name
self.path_in_repo = repo_name
elif self.args.local_path is None and os.path.isdir(repo_name):
# Case 2: modelscope upload owner_name/test_repo (run command line in the `repo_name` dir)
# => upload all files in current directory to remote root path
self.local_path = repo_name
self.path_in_repo = '.'
elif self.args.local_path is None:
# Case 3: user provided only a repo_id that does not match a local file or folder
# => the user must explicitly provide a local_path => raise exception
raise ValueError(
f"'{repo_name}' is not a local file or folder. Please set `local_path` explicitly."
)
elif self.args.path_in_repo is None and os.path.isfile(
self.args.local_path):
# Case 4: modelscope upload owner_name/test_repo /path/to/your_file.csv
# => upload it to remote root path with same name
self.local_path = self.args.local_path
self.path_in_repo = os.path.basename(self.args.local_path)
elif self.args.path_in_repo is None:
# Case 5: modelscope upload owner_name/test_repo /path/to/your_folder
# => upload all files in current directory to remote root path
self.local_path = self.args.local_path
self.path_in_repo = ''
else:
# Finally, if both paths are explicit
self.local_path = self.args.local_path
self.path_in_repo = self.args.path_in_repo
# Check token and login
# The cookies will be reused if the user has logged in before.
api = HubApi(endpoint=self.args.endpoint)
if self.args.token:
api.login(access_token=self.args.token)
cookies = ModelScopeConfig.get_cookies()
if cookies is None:
raise ValueError(
'The `token` is not provided! '
'You can pass the `--token` argument, '
'or use api.login(access_token=`your_sdk_token`). '
'Your token is available at https://modelscope.cn/my/myaccesstoken'
)
if os.path.isfile(self.local_path):
commit_info = api.upload_file(
path_or_fileobj=self.local_path,
path_in_repo=self.path_in_repo,
repo_id=self.repo_id,
repo_type=self.args.repo_type,
commit_message=self.args.commit_message,
commit_description=self.args.commit_description,
)
elif os.path.isdir(self.local_path):
commit_info = api.upload_folder(
repo_id=self.repo_id,
folder_path=self.local_path,
path_in_repo=self.path_in_repo,
commit_message=self.args.commit_message,
commit_description=self.args.commit_description,
repo_type=self.args.repo_type,
allow_patterns=self.args.include,
ignore_patterns=self.args.exclude,
max_workers=self.args.max_workers,
)
else:
raise ValueError(f'{self.local_path} is not a valid local path')
logger.info(f'Upload finished, commit info: {commit_info}')
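A minimal programmatic equivalent of the CLI flow above, with a hypothetical repo id, token, and local folder; the command ultimately delegates to these HubApi calls:

from modelscope.hub.api import HubApi

api = HubApi(endpoint='https://www.modelscope.cn')
api.login(access_token='YOUR_SDK_TOKEN')         # or reuse previously cached cookies
api.upload_folder(
    repo_id='owner_name/test_repo',
    folder_path='/path/to/your_folder',
    path_in_repo='',                             # upload to the repo root
    repo_type='model',
    commit_message='Upload folder via SDK',
    allow_patterns=['*.safetensors', '*.json'],  # same semantics as --include
    ignore_patterns=['*.tmp'],                   # same semantics as --exclude
    max_workers=8,
)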

View File

@@ -3,6 +3,7 @@
import datetime
import functools
import io
import os
import pickle
import platform
@@ -13,13 +14,15 @@ from collections import defaultdict
from http import HTTPStatus
from http.cookiejar import CookieJar
from os.path import expanduser
from typing import Dict, List, Optional, Tuple, Union
from pathlib import Path
from typing import Any, BinaryIO, Dict, Iterable, List, Optional, Tuple, Union
from urllib.parse import urlencode
import json
import requests
from requests import Session
from requests.adapters import HTTPAdapter, Retry
from tqdm import tqdm
from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
API_HTTP_CLIENT_TIMEOUT,
@@ -29,6 +32,7 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
API_RESPONSE_FIELD_MESSAGE,
API_RESPONSE_FIELD_USERNAME,
DEFAULT_CREDENTIALS_PATH,
DEFAULT_MAX_WORKERS,
MODELSCOPE_CLOUD_ENVIRONMENT,
MODELSCOPE_CLOUD_USERNAME,
MODELSCOPE_REQUEST_ID, ONE_YEAR_SECONDS,
@@ -43,18 +47,27 @@ from modelscope.hub.errors import (InvalidParameter, NotExistError,
raise_for_http_status, raise_on_error)
from modelscope.hub.git import GitCommandWrapper
from modelscope.hub.repository import Repository
from modelscope.hub.utils.utils import (get_endpoint, get_readable_folder_size,
get_release_datetime,
model_id_to_group_owner_name)
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION,
DEFAULT_REPOSITORY_REVISION,
MASTER_MODEL_BRANCH, META_FILES_FORMAT,
REPO_TYPE_MODEL, ConfigFields,
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
REPO_TYPE_SUPPORT, ConfigFields,
DatasetFormations, DatasetMetaFormats,
DatasetVisibilityMap, DownloadChannel,
DownloadMode, Frameworks, ModelFile,
Tasks, VirgoDatasetConfig)
from modelscope.utils.file_utils import get_file_hash, get_file_size
from modelscope.utils.logger import get_logger
from .utils.utils import (get_endpoint, get_readable_folder_size,
get_release_datetime, model_id_to_group_owner_name)
from modelscope.utils.repo_utils import (DATASET_LFS_SUFFIX,
DEFAULT_IGNORE_PATTERNS,
MODEL_LFS_SUFFIX, CommitInfo,
CommitOperation, CommitOperationAdd,
RepoUtils)
from modelscope.utils.thread_utils import thread_executor
logger = get_logger()
@@ -94,6 +107,8 @@ class HubApi:
getattr(self.session, method),
timeout=timeout))
self.upload_checker = UploadingCheck()
def login(
self,
access_token: Optional[str] = None,
@@ -194,7 +209,7 @@ class HubApi:
headers=self.builder_headers(self.headers))
handle_http_post_error(r, path, body)
raise_on_error(r.json())
model_repo_url = f'{get_endpoint()}/{model_id}'
model_repo_url = f'{self.endpoint}/{model_id}'
return model_repo_url
def delete_model(self, model_id: str):
@@ -739,13 +754,14 @@ class HubApi:
Args:
repo_id (`str`): The repo id to use
filename (`str`): The queried filename
filename (`str`): The queried filename, if the file exists in a sub folder,
please pass <sub-folder-name>/<file-name>
revision (`Optional[str]`): The repo revision
Returns:
The query result in bool value
"""
files = self.get_model_files(repo_id, revision=revision)
files = [file['Name'] for file in files]
files = self.get_model_files(repo_id, recursive=True, revision=revision)
files = [file['Path'] for file in files]
return filename in files
def create_dataset(self,
@@ -1180,10 +1196,577 @@ class HubApi:
return {MODELSCOPE_REQUEST_ID: str(uuid.uuid4().hex),
**headers}
def get_file_base_path(self, namespace: str, dataset_name: str) -> str:
return f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?'
def get_file_base_path(self, repo_id: str) -> str:
_namespace, _dataset_name = repo_id.split('/')
return f'{self.endpoint}/api/v1/datasets/{_namespace}/{_dataset_name}/repo?'
# return f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?Revision={revision}&FilePath='
def create_repo(
self,
repo_id: str,
*,
token: Union[str, bool, None] = None,
visibility: Optional[str] = 'public',
repo_type: Optional[str] = REPO_TYPE_MODEL,
chinese_name: Optional[str] = '',
license: Optional[str] = Licenses.APACHE_V2,
) -> str:
# TODO: exist_ok
if not repo_id:
raise ValueError('Repo id cannot be empty!')
if token:
self.login(access_token=token)
else:
logger.warning('No token provided, will use the cached token.')
repo_id_list = repo_id.split('/')
if len(repo_id_list) != 2:
raise ValueError('Invalid repo id, should be in the format of `owner_name/repo_name`')
namespace, repo_name = repo_id_list
if repo_type == REPO_TYPE_MODEL:
visibilities = {k: v for k, v in ModelVisibility.__dict__.items() if not k.startswith('__')}
visibility: int = visibilities.get(visibility.upper())
if visibility is None:
raise ValueError(f'Invalid visibility: {visibility}, '
f'supported visibilities: `public`, `private`, `internal`')
repo_url: str = self.create_model(
model_id=repo_id,
visibility=visibility,
license=license,
chinese_name=chinese_name,
)
elif repo_type == REPO_TYPE_DATASET:
visibilities = {k: v for k, v in DatasetVisibility.__dict__.items() if not k.startswith('__')}
visibility: int = visibilities.get(visibility.upper())
if visibility is None:
raise ValueError(f'Invalid visibility: {visibility}, '
f'supported visibilities: `public`, `private`, `internal`')
repo_url: str = self.create_dataset(
dataset_name=repo_name,
namespace=namespace,
chinese_name=chinese_name,
license=license,
visibility=visibility,
)
else:
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
return repo_url
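# Usage sketch (not part of this file): create a repo before uploading.
# The repo id, token and visibility below are hypothetical placeholders.
_demo_api = HubApi()
_demo_repo_url = _demo_api.create_repo(
    'owner_name/my_new_model', token='YOUR_SDK_TOKEN',
    visibility='public', repo_type=REPO_TYPE_MODEL)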
def create_commit(
self,
repo_id: str,
operations: Iterable[CommitOperation],
*,
commit_message: str,
commit_description: Optional[str] = None,
token: str = None,
repo_type: Optional[str] = None,
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
) -> CommitInfo:
url = f'{self.endpoint}/api/v1/repos/{repo_type}s/{repo_id}/commit/{revision}'
commit_message = commit_message or f'Commit to {repo_id}'
commit_description = commit_description or ''
if token:
self.login(access_token=token)
# Construct payload
payload = self._prepare_commit_payload(
operations=operations,
commit_message=commit_message,
)
# POST
cookies = ModelScopeConfig.get_cookies()
if cookies is None:
raise ValueError('Token does not exist, please login first.')
response = requests.post(
url,
headers=self.builder_headers(self.headers),
data=json.dumps(payload),
cookies=cookies
)
resp = response.json()
if not resp['Success']:
commit_message = resp['Message']
logger.warning(f'{commit_message}')
return CommitInfo(
commit_url=url,
commit_message=commit_message,
commit_description=commit_description,
oid='',
)
def upload_file(
self,
*,
path_or_fileobj: Union[str, Path, bytes, BinaryIO],
path_in_repo: str,
repo_id: str,
token: Union[str, None] = None,
repo_type: Optional[str] = REPO_TYPE_MODEL,
commit_message: Optional[str] = None,
commit_description: Optional[str] = None,
buffer_size_mb: Optional[int] = 1,
tqdm_desc: Optional[str] = '[Uploading]',
disable_tqdm: Optional[bool] = False,
) -> CommitInfo:
if repo_type not in REPO_TYPE_SUPPORT:
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
if not path_or_fileobj:
raise ValueError('Path or file object cannot be empty!')
if isinstance(path_or_fileobj, (str, Path)):
path_or_fileobj = os.path.abspath(os.path.expanduser(path_or_fileobj))
path_in_repo = path_in_repo or os.path.basename(path_or_fileobj)
else:
# If path_or_fileobj is bytes or BinaryIO, then path_in_repo must be provided
if not path_in_repo:
raise ValueError('Arg `path_in_repo` cannot be empty!')
# Read file content if path_or_fileobj is a file-like object (BinaryIO)
# TODO: to be refined
if isinstance(path_or_fileobj, io.BufferedIOBase):
path_or_fileobj = path_or_fileobj.read()
self.upload_checker.check_file(path_or_fileobj)
self.upload_checker.check_normal_files(
file_path_list=[path_or_fileobj],
repo_type=repo_type,
)
if token:
self.login(access_token=token)
commit_message = (
commit_message if commit_message is not None else f'Upload {path_in_repo} to ModelScope hub'
)
if buffer_size_mb <= 0:
raise ValueError('Buffer size: `buffer_size_mb` must be greater than 0')
hash_info_d: dict = get_file_hash(
file_path_or_obj=path_or_fileobj,
buffer_size_mb=buffer_size_mb,
)
file_size: int = hash_info_d['file_size']
file_hash: str = hash_info_d['file_hash']
upload_res: dict = self._upload_blob(
repo_id=repo_id,
repo_type=repo_type,
sha256=file_hash,
size=file_size,
data=path_or_fileobj,
disable_tqdm=disable_tqdm,
tqdm_desc=tqdm_desc,
)
# Construct commit info and create commit
add_operation: CommitOperationAdd = CommitOperationAdd(
path_in_repo=path_in_repo,
path_or_fileobj=path_or_fileobj,
)
add_operation._upload_mode = 'lfs' if self.upload_checker.is_lfs(path_or_fileobj, repo_type) else 'normal'
add_operation._is_uploaded = upload_res['is_uploaded']
operations = [add_operation]
commit_info: CommitInfo = self.create_commit(
repo_id=repo_id,
operations=operations,
commit_message=commit_message,
commit_description=commit_description,
token=token,
repo_type=repo_type,
)
return commit_info
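# Usage sketch (not part of this file): upload a single in-memory file.
# When bytes are passed, `path_in_repo` must be set explicitly; the repo id is hypothetical.
_demo_api = HubApi()
_demo_api.login(access_token='YOUR_SDK_TOKEN')
_demo_commit = _demo_api.upload_file(
    path_or_fileobj=b'{"hello": "modelscope"}',
    path_in_repo='config/demo.json',
    repo_id='owner_name/test_repo',
    repo_type=REPO_TYPE_MODEL,
    commit_message='Add demo config')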
def upload_folder(
self,
*,
repo_id: str,
folder_path: Union[str, Path],
path_in_repo: Optional[str] = '',
commit_message: Optional[str] = None,
commit_description: Optional[str] = None,
token: Union[str, None] = None,
repo_type: Optional[str] = REPO_TYPE_MODEL,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
max_workers: int = DEFAULT_MAX_WORKERS,
) -> CommitInfo:
if repo_type not in REPO_TYPE_SUPPORT:
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
allow_patterns = allow_patterns if allow_patterns else None
ignore_patterns = ignore_patterns if ignore_patterns else None
self.upload_checker.check_folder(folder_path)
# Ignore .git folder
if ignore_patterns is None:
ignore_patterns = []
elif isinstance(ignore_patterns, str):
ignore_patterns = [ignore_patterns]
ignore_patterns += DEFAULT_IGNORE_PATTERNS
if token:
self.login(access_token=token)
commit_message = (
commit_message if commit_message is not None else f'Upload folder to {repo_id} on ModelScope hub'
)
commit_description = commit_description or 'Uploading folder'
# Get the list of files to upload, e.g. [('data/abc.png', '/path/to/abc.png'), ...]
prepared_repo_objects = HubApi._prepare_upload_folder(
folder_path=folder_path,
path_in_repo=path_in_repo,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
self.upload_checker.check_normal_files(
file_path_list = [item for _, item in prepared_repo_objects],
repo_type=repo_type,
)
@thread_executor(max_workers=max_workers, disable_tqdm=False)
def _upload_items(item_pair, **kwargs):
file_path_in_repo, file_path = item_pair
hash_info_d: dict = get_file_hash(
file_path_or_obj=file_path,
)
file_size: int = hash_info_d['file_size']
file_hash: str = hash_info_d['file_hash']
upload_res: dict = self._upload_blob(
repo_id=repo_id,
repo_type=repo_type,
sha256=file_hash,
size=file_size,
data=file_path,
disable_tqdm=False if file_size > 5 * 1024 * 1024 else True,
tqdm_desc='[Uploading ' + file_path_in_repo + ']',
)
return {
'file_path_in_repo': file_path_in_repo,
'file_path': file_path,
'is_uploaded': upload_res['is_uploaded'],
}
uploaded_items_list = _upload_items(
prepared_repo_objects,
repo_id=repo_id,
token=token,
repo_type=repo_type,
commit_message=commit_message,
commit_description=commit_description,
buffer_size_mb=1,
disable_tqdm=False,
)
logger.info(f'Uploading folder to {repo_id} finished')
# Construct commit info and create commit
operations = []
for item_d in uploaded_items_list:
prepared_path_in_repo: str = item_d['file_path_in_repo']
prepared_file_path: str = item_d['file_path']
is_uploaded: bool = item_d['is_uploaded']
opt = CommitOperationAdd(
path_in_repo=prepared_path_in_repo,
path_or_fileobj=prepared_file_path,
)
# check normal or lfs
opt._upload_mode = 'lfs' if self.upload_checker.is_lfs(prepared_file_path, repo_type) else 'normal'
opt._is_uploaded = is_uploaded
operations.append(opt)
self.create_commit(
repo_id=repo_id,
operations=operations,
commit_message=commit_message,
commit_description=commit_description,
token=token,
repo_type=repo_type,
)
# Construct commit info
commit_url = f'{self.endpoint}/api/v1/{repo_type}s/{repo_id}/commit/{DEFAULT_REPOSITORY_REVISION}'
return CommitInfo(
commit_url=commit_url,
commit_message=commit_message,
commit_description=commit_description,
oid='')
def _upload_blob(
self,
*,
repo_id: str,
repo_type: str,
sha256: str,
size: int,
data: Union[str, Path, bytes, BinaryIO],
disable_tqdm: Optional[bool] = False,
tqdm_desc: Optional[str] = '[Uploading]',
buffer_size_mb: Optional[int] = 1,
) -> dict:
res_d: dict = dict(
url=None,
is_uploaded=False,
status_code=None,
status_msg=None,
)
objects = [{'oid': sha256, 'size': size}]
upload_objects = self._validate_blob(
repo_id=repo_id,
repo_type=repo_type,
objects=objects,
)
# upload_object: {'url': 'xxx', 'oid': 'xxx'}
upload_object = upload_objects[0] if len(upload_objects) == 1 else None
if upload_object is None:
logger.info(f'Blob {sha256} has already uploaded, reuse it.')
res_d['is_uploaded'] = True
return res_d
cookies = ModelScopeConfig.get_cookies()
cookies = dict(cookies) if cookies else None
if cookies is None:
raise ValueError('Token does not exist, please login first.')
self.headers.update({'Cookie': f"m_session_id={cookies['m_session_id']}"})
headers = self.builder_headers(self.headers)
def read_in_chunks(file_object, pbar, chunk_size=buffer_size_mb * 1024 * 1024):
"""Lazy function (generator) to read a file piece by piece."""
while True:
ck = file_object.read(chunk_size)
if not ck:
break
pbar.update(len(ck))
yield ck
with tqdm(
total=size,
unit='B',
unit_scale=True,
desc=tqdm_desc,
disable=disable_tqdm
) as pbar:
if isinstance(data, (str, Path)):
with open(data, 'rb') as f:
response = requests.put(
upload_object['url'],
headers=headers,
data=read_in_chunks(f, pbar)
)
elif isinstance(data, bytes):
response = requests.put(
upload_object['url'],
headers=headers,
data=read_in_chunks(io.BytesIO(data), pbar)
)
elif isinstance(data, io.BufferedIOBase):
response = requests.put(
upload_object['url'],
headers=headers,
data=read_in_chunks(data, pbar)
)
else:
raise ValueError('Invalid data type to upload')
resp = response.json()
raise_on_error(resp)
res_d['url'] = upload_object['url']
res_d['status_code'] = resp['Code']
res_d['status_msg'] = resp['Message']
return res_d
def _validate_blob(
self,
*,
repo_id: str,
repo_type: str,
objects: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""
Check the blob has already uploaded.
True -- uploaded; False -- not uploaded.
Args:
repo_id (str): The repo id ModelScope.
repo_type (str): The repo type. `dataset`, `model`, etc.
objects (List[Dict[str, Any]]): The objects to check.
oid (str): The sha256 hash value.
size (int): The size of the blob.
Returns:
List[Dict[str, Any]]: The result of the check.
"""
# construct URL
url = f'{self.endpoint}/api/v1/repos/{repo_type}s/{repo_id}/info/lfs/objects/batch'
# build payload
payload = {
'operation': 'upload',
'objects': objects,
}
cookies = ModelScopeConfig.get_cookies()
if cookies is None:
raise ValueError('Token does not exist, please login first.')
response = requests.post(
url,
headers=self.builder_headers(self.headers),
data=json.dumps(payload),
cookies=cookies
)
resp = response.json()
raise_on_error(resp)
upload_objects = [] # list of objects to upload, [{'url': 'xxx', 'oid': 'xxx'}, ...]
resp_objects = resp['Data']['objects']
for obj in resp_objects:
upload_objects.append(
{'url': obj['actions']['upload']['href'],
'oid': obj['oid']}
)
return upload_objects
@staticmethod
def _prepare_upload_folder(
folder_path: Union[str, Path],
path_in_repo: str,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
) -> List[Union[tuple, list]]:
folder_path = Path(folder_path).expanduser().resolve()
if not folder_path.is_dir():
raise ValueError(f"Provided path: '{folder_path}' is not a directory")
# List files from folder
relpath_to_abspath = {
path.relative_to(folder_path).as_posix(): path
for path in sorted(folder_path.glob('**/*')) # sorted to be deterministic
if path.is_file()
}
# Filter files
filtered_repo_objects = list(
RepoUtils.filter_repo_objects(
relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
)
)
prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else ''
prepared_repo_objects = [
(prefix + relpath, str(relpath_to_abspath[relpath]))
for relpath in filtered_repo_objects
]
return prepared_repo_objects
@staticmethod
def _prepare_commit_payload(
operations: Iterable[CommitOperation],
commit_message: str,
) -> Dict[str, Any]:
"""
Prepare the commit payload to be sent to the ModelScope hub.
"""
payload = {
'commit_message': commit_message,
'actions': []
}
nb_ignored_files = 0
# 2. Send operations, one per line
for operation in operations:
# Skip ignored files
if isinstance(operation, CommitOperationAdd) and operation._should_ignore:
logger.debug(f"Skipping file '{operation.path_in_repo}' in commit (ignored by gitignore file).")
nb_ignored_files += 1
continue
# 2.a. Case adding a normal file
if isinstance(operation, CommitOperationAdd) and operation._upload_mode == 'normal':
commit_action = {
'action': 'update' if operation._is_uploaded else 'create',
'path': operation.path_in_repo,
'type': 'normal',
'size': operation.upload_info.size,
'sha256': '',
'content': operation.b64content().decode(),
'encoding': 'base64',
}
payload['actions'].append(commit_action)
# 2.b. Case adding an LFS file
elif isinstance(operation, CommitOperationAdd) and operation._upload_mode == 'lfs':
commit_action = {
'action': 'update' if operation._is_uploaded else 'create',
'path': operation.path_in_repo,
'type': 'lfs',
'size': operation.upload_info.size,
'sha256': operation.upload_info.sha256,
'content': '',
'encoding': '',
}
payload['actions'].append(commit_action)
else:
raise ValueError(
f'Unknown operation to commit. Operation: {operation}. Upload mode:'
f" {getattr(operation, '_upload_mode', None)}"
)
if nb_ignored_files > 0:
logger.info(f'Skipped {nb_ignored_files} file(s) in commit (ignored by gitignore file).')
return payload
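# For reference (not part of this file): the commit payload produced above has this shape;
# sizes, hashes and base64 content below are illustrative placeholders.
_example_payload = {
    'commit_message': 'Upload files to ModelScope hub',
    'actions': [
        {'action': 'create', 'path': 'model.safetensors', 'type': 'lfs',
         'size': 123456789, 'sha256': '<sha256-of-uploaded-blob>', 'content': '', 'encoding': ''},
        {'action': 'update', 'path': 'README.md', 'type': 'normal',
         'size': 512, 'sha256': '', 'content': '<base64-encoded-bytes>', 'encoding': 'base64'},
    ],
}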
class ModelScopeConfig:
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
@@ -1213,12 +1796,11 @@ class ModelScopeConfig:
with open(cookies_path, 'rb') as f:
cookies = pickle.load(f)
for cookie in cookies:
if cookie.is_expired() and not ModelScopeConfig.cookie_expired_warning:
if cookie.name == 'm_session_id' and cookie.is_expired() and \
not ModelScopeConfig.cookie_expired_warning:
ModelScopeConfig.cookie_expired_warning = True
logger.debug(
'Authentication has expired, '
'please re-login with modelscope login --token "YOUR_SDK_TOKEN" '
'if you need to access private models or datasets.')
logger.warning('Authentication has expired, '
'please re-login for uploading or accessing controlled entities.')
return None
return cookies
return None
@@ -1327,3 +1909,85 @@ class ModelScopeConfig:
elif isinstance(user_agent, str):
ua += '; ' + user_agent
return ua
class UploadingCheck:
def __init__(
self,
max_file_count: int = 100_000,
max_file_count_in_dir: int = 10_000,
max_file_size: int = 50 * 1024 ** 3,
lfs_size_limit: int = 5 * 1024 * 1024,
normal_file_size_total_limit: int = 500 * 1024 * 1024,
):
self.max_file_count = max_file_count
self.max_file_count_in_dir = max_file_count_in_dir
self.max_file_size = max_file_size
self.lfs_size_limit = lfs_size_limit
self.normal_file_size_total_limit = normal_file_size_total_limit
def check_file(self, file_path_or_obj):
if isinstance(file_path_or_obj, (str, Path)):
if not os.path.exists(file_path_or_obj):
raise ValueError(f'File {file_path_or_obj} does not exist')
file_size: int = get_file_size(file_path_or_obj)
if file_size > self.max_file_size:
raise ValueError(f'File exceeds size limit: {self.max_file_size / (1024 ** 3)} GB')
def check_folder(self, folder_path: Union[str, Path]):
file_count = 0
dir_count = 0
if isinstance(folder_path, str):
folder_path = Path(folder_path)
for item in folder_path.iterdir():
if item.is_file():
file_count += 1
elif item.is_dir():
dir_count += 1
# Count items in subdirectories recursively
sub_file_count, sub_dir_count = self.check_folder(item)
if (sub_file_count + sub_dir_count) > self.max_file_count_in_dir:
raise ValueError(f'Directory {item} contains {sub_file_count + sub_dir_count} items '
f'and exceeds limit: {self.max_file_count_in_dir}')
file_count += sub_file_count
dir_count += sub_dir_count
if file_count > self.max_file_count:
raise ValueError(f'Total file count {file_count} and exceeds limit: {self.max_file_count}')
return file_count, dir_count
def is_lfs(self, file_path_or_obj: Union[str, Path, bytes, BinaryIO], repo_type: str) -> bool:
hit_lfs_suffix = True
if isinstance(file_path_or_obj, (str, Path)):
file_path_or_obj = Path(file_path_or_obj)
if not file_path_or_obj.exists():
raise ValueError(f'File {file_path_or_obj} does not exist')
if repo_type == REPO_TYPE_MODEL:
if file_path_or_obj.suffix not in MODEL_LFS_SUFFIX:
hit_lfs_suffix = False
elif repo_type == REPO_TYPE_DATASET:
if file_path_or_obj.suffix not in DATASET_LFS_SUFFIX:
hit_lfs_suffix = False
else:
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
file_size: int = get_file_size(file_path_or_obj)
return file_size > self.lfs_size_limit or hit_lfs_suffix
def check_normal_files(self, file_path_list: List[Union[str, Path]], repo_type: str) -> None:
normal_file_list = [item for item in file_path_list if not self.is_lfs(item, repo_type)]
total_size = sum([get_file_size(item) for item in normal_file_list])
if total_size > self.normal_file_size_total_limit:
raise ValueError(f'Total size of non-lfs files {total_size/(1024 * 1024)}MB '
f'and exceeds limit: {self.normal_file_size_total_limit/(1024 * 1024)}MB')
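The limits above decide whether a file is committed as a normal (base64) file or through LFS. A minimal sketch, assuming `UploadingCheck` is importable from `modelscope.hub.api` as added in this commit:

import os
import tempfile

from modelscope.hub.api import UploadingCheck
from modelscope.utils.constant import REPO_TYPE_MODEL

checker = UploadingCheck()  # defaults: 50 GB per file, 5 MB LFS threshold, 500 MB total for normal files
with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as f:
    f.write(b'hello modelscope')
    tmp_path = f.name
checker.check_file(tmp_path)  # raises ValueError only if the file exceeds the 50 GB limit
# A small file without a weight-file suffix is expected to be committed as a normal (non-LFS) file.
print(checker.is_lfs(tmp_path, REPO_TYPE_MODEL))
os.remove(tmp_path)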

View File

@@ -39,6 +39,7 @@ def check_local_model_is_latest(
"""
try:
model_id = get_model_id_from_cache(model_root_path)
model_id = model_id.replace('___', '.')
# make headers
headers = {
'user-agent':

View File

@@ -1,5 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from pathlib import Path
@@ -33,6 +32,9 @@ MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION = 'MODELSCOPE_ENABLE_DEFAULT_HASH_VALI
ONE_YEAR_SECONDS = 24 * 365 * 60 * 60
MODELSCOPE_REQUEST_ID = 'X-Request-ID'
TEMPORARY_FOLDER_NAME = '._____temp'
DEFAULT_MAX_WORKERS = min(8, os.cpu_count() + 4)
MODELSCOPE_SHOW_INDIVIDUAL_PROGRESS_THRESHOLD = int(
os.environ.get('MODELSCOPE_SHOW_INDIVIDUAL_PROGRESS_THRESHOLD', 50))
class Licenses(object):

View File

@@ -164,6 +164,7 @@ def _repo_file_download(
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
local_dir: Optional[str] = None,
disable_tqdm: bool = False,
) -> Optional[str]: # pragma: no cover
if not repo_type:
@@ -278,6 +279,9 @@ def _repo_file_download(
dataset_name=name,
namespace=group_or_owner,
revision=revision)
else:
raise ValueError(f'Invalid repo type {repo_type}')
return download_file(url_to_download, file_to_download_meta,
temporary_cache_dir, cache, headers, cookies)
@@ -382,6 +386,7 @@ def parallel_download(
cookies: CookieJar,
headers: Optional[Dict[str, str]] = None,
file_size: int = None,
disable_tqdm: bool = False,
):
# create temp file
with tqdm(
@@ -392,6 +397,7 @@ def parallel_download(
initial=0,
desc='Downloading [' + file_name + ']',
leave=True,
disable=disable_tqdm,
) as progress:
PART_SIZE = 160 * 1024 * 1024 # every part is 160M
tasks = []
@@ -435,6 +441,7 @@ def http_get_model_file(
file_size: int,
cookies: CookieJar,
headers: Optional[Dict[str, str]] = None,
disable_tqdm: bool = False,
):
"""Download remote file, will retry 5 times before giving up on errors.
@@ -451,6 +458,7 @@ def http_get_model_file(
cookies used to authentication the user, which is used for downloading private repos
headers(Dict[str, str], optional):
http headers to carry necessary info when requesting the remote file
disable_tqdm(bool, optional): Disable the progress bar with tqdm.
Raises:
FileDownloadError: File download failed.
@@ -478,6 +486,7 @@ def http_get_model_file(
initial=0,
desc='Downloading [' + file_name + ']',
leave=True,
disable=disable_tqdm,
) as progress:
if file_size == 0:
# Avoid empty file server request
@@ -488,6 +497,8 @@ def http_get_model_file(
partial_length = 0
# download partial, continue download
if os.path.exists(temp_file_path):
# resuming from interrupted download is also considered as retry
has_retry = True
with open(temp_file_path, 'rb') as f:
partial_length = f.seek(0, io.SEEK_END)
progress.update(partial_length)
@@ -511,7 +522,9 @@ def http_get_model_file(
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
f.write(chunk)
hash_sha256.update(chunk)
# hash would be discarded in retry case anyway
if not has_retry:
hash_sha256.update(chunk)
break
except Exception as e: # no matter what happen, we will retry.
has_retry = True
@@ -519,7 +532,6 @@ def http_get_model_file(
retry.sleep()
# if anything went wrong, we would discard the real-time computed hash and return None
return None if has_retry else hash_sha256.hexdigest()
logger.debug('storing %s in cache at %s', url, local_dir)
def http_get_file(
@@ -604,9 +616,15 @@ def http_get_file(
os.replace(temp_file.name, os.path.join(local_dir, file_name))
def download_file(url, file_meta, temporary_cache_dir, cache, headers,
cookies):
file_digest = None
def download_file(
url,
file_meta,
temporary_cache_dir,
cache,
headers,
cookies,
disable_tqdm=False,
):
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
file_digest = parallel_download(
@@ -615,7 +633,9 @@ def download_file(url, file_meta, temporary_cache_dir, cache, headers,
file_meta['Path'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=file_meta['Size'])
file_size=file_meta['Size'],
disable_tqdm=disable_tqdm,
)
else:
file_digest = http_get_model_file(
url,
@@ -623,7 +643,9 @@ def download_file(url, file_meta, temporary_cache_dir, cache, headers,
file_meta['Path'],
file_size=file_meta['Size'],
headers=headers,
cookies=cookies)
cookies=cookies,
disable_tqdm=disable_tqdm,
)
# check file integrity
temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])

View File

@@ -4,32 +4,33 @@ import fnmatch
import os
import re
import uuid
from concurrent.futures import ThreadPoolExecutor
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, List, Optional, Union
from tqdm.auto import tqdm
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import \
MODELSCOPE_SHOW_INDIVIDUAL_PROGRESS_THRESHOLD
from modelscope.hub.errors import InvalidParameter
from modelscope.hub.file_download import (create_temporary_directory_and_cache,
download_file, get_file_download_url)
from modelscope.hub.utils.caching import ModelFileSystemCache
from modelscope.hub.utils.utils import (get_model_masked_directory,
model_id_to_group_owner_name)
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION,
DEFAULT_REPOSITORY_REVISION,
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
REPO_TYPE_SUPPORT)
from modelscope.utils.logger import get_logger
from .file_download import (create_temporary_directory_and_cache,
download_file, get_file_download_url)
from modelscope.utils.thread_utils import thread_executor
logger = get_logger()
def snapshot_download(
model_id: str,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
model_id: str = None,
revision: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
@@ -40,6 +41,8 @@ def snapshot_download(
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
max_workers: int = 8,
repo_id: str = None,
repo_type: Optional[str] = REPO_TYPE_MODEL,
) -> str:
"""Download all files of a repo.
Downloads a whole snapshot of a repo's files at the specified revision. This
@@ -51,7 +54,10 @@ def snapshot_download(
user always has git and git-lfs installed, and properly configured.
Args:
model_id (str): A user or an organization name and a repo name separated by a `/`.
repo_id (str): A user or an organization name and a repo name separated by a `/`.
model_id (str): A user or an organization name and a model name separated by a `/`.
if `repo_id` is provided, `model_id` will be ignored.
repo_type (str, optional): The type of the repo, either 'model' or 'dataset'.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. NOTE: currently only branch and tag name is supported
cache_dir (str, Path, optional): Path to the folder where cached files are stored, model will
@@ -87,9 +93,22 @@ def snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
repo_id = repo_id or model_id
if not repo_id:
raise ValueError('Please provide a valid model_id or repo_id')
if repo_type not in REPO_TYPE_SUPPORT:
raise ValueError(
f'Invalid repo type: {repo_type}, only support: {REPO_TYPE_SUPPORT}'
)
if revision is None:
revision = DEFAULT_DATASET_REVISION if repo_type == REPO_TYPE_DATASET else DEFAULT_MODEL_REVISION
return _snapshot_download(
model_id,
repo_type=REPO_TYPE_MODEL,
repo_id,
repo_type=repo_type,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
@@ -233,7 +252,7 @@ def _snapshot_download(
if repo_type == REPO_TYPE_MODEL:
directory = os.path.abspath(
local_dir) if local_dir is not None else os.path.join(
system_cache, repo_id)
system_cache, 'models', *repo_id.split('/'))
print(f'Downloading Model to directory: {directory}')
revision_detail = _api.get_valid_revision_detail(
repo_id, revision=revision, cookies=cookies)
@@ -294,7 +313,7 @@ def _snapshot_download(
elif repo_type == REPO_TYPE_DATASET:
directory = os.path.abspath(
local_dir) if local_dir else os.path.join(
system_cache, 'datasets', repo_id)
system_cache, 'datasets', *repo_id.split('/'))
print(f'Downloading Dataset to directory: {directory}')
group_or_owner, name = model_id_to_group_owner_name(repo_id)
@@ -393,21 +412,6 @@ def _get_valid_regex_pattern(patterns: List[str]):
return None
def thread_download(func, iterable, max_workers, **kwargs):
# Create a tqdm progress bar with the total number of files to fetch
with tqdm(
total=len(iterable),
desc=f'Fetching {len(iterable)} files') as pbar:
# Define a wrapper function to update the progress bar
def progress_wrapper(*args, **kwargs):
result = func(*args, **kwargs)
pbar.update(1)
return result
with ThreadPoolExecutor(max_workers=max_workers) as executor:
executor.map(progress_wrapper, iterable)
def _download_file_lists(
repo_files: List[str],
cache: ModelFileSystemCache,
@@ -479,6 +483,7 @@ def _download_file_lists(
else:
filtered_repo_files.append(repo_file)
@thread_executor(max_workers=max_workers, disable_tqdm=False)
def _download_single_file(repo_file):
if repo_type == REPO_TYPE_MODEL:
url = get_file_download_url(
@@ -495,10 +500,21 @@ def _download_file_lists(
raise InvalidParameter(
f'Invalid repo type: {repo_type}, supported types: {REPO_TYPE_SUPPORT}'
)
download_file(url, repo_file, temporary_cache_dir, cache, headers,
cookies)
disable_tqdm = len(
filtered_repo_files
) > MODELSCOPE_SHOW_INDIVIDUAL_PROGRESS_THRESHOLD # noqa
download_file(
url,
repo_file,
temporary_cache_dir,
cache,
headers,
cookies,
disable_tqdm=disable_tqdm,
)
if len(filtered_repo_files) > 0:
thread_download(_download_single_file, filtered_repo_files,
max_workers)
logger.info(
f'Got {len(filtered_repo_files)} files, start to download ...')
_download_single_file(filtered_repo_files)
logger.info(f"Download {repo_type} '{repo_id}' successfully.")

View File

@@ -10,8 +10,11 @@ from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import read_config
from modelscope.utils.logger import get_logger
from modelscope.utils.streaming_output import StreamingOutputMixin
logger = get_logger()
__all__ = ['PolyLMForTextGeneration']
@@ -27,6 +30,9 @@ class PolyLMForTextGeneration(TorchModel, StreamingOutputMixin):
super().__init__(model_dir, *args, **kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(
model_dir, legacy=False, use_fast=False)
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {model_dir}. Please make sure '
'that you can trust the external codes.')
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, device_map='auto', trust_remote_code=True)
self.model.eval()

View File

@@ -133,6 +133,11 @@ class OssDownloader(BaseDownloader):
raise f'meta-file: {dataset_name}.py not found on the modelscope hub.'
if dataset_py_script and dataset_formation == DatasetFormations.hf_compatible:
if trust_remote_code:
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {dataset_name}. Please make '
'sure that you can trust the external codes.')
self.dataset = hf_load_dataset(
dataset_py_script,
name=subset_name,

View File

@@ -71,6 +71,11 @@ class LocalDataLoaderManager(DataLoaderManager):
# Select local data loader
# TODO: more loaders to be supported.
if data_loader_type == LocalDataLoaderType.HF_DATA_LOADER:
if trust_remote_code:
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {dataset_name}. Please make '
'sure that you can trust the external codes.')
# Build huggingface data loader and return dataset.
return hf_data_loader(
dataset_name,
@@ -110,6 +115,10 @@ class RemoteDataLoaderManager(DataLoaderManager):
# To use the huggingface data loader
if data_loader_type == RemoteDataLoaderType.HF_DATA_LOADER:
if trust_remote_code:
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {dataset_name}. Please make '
'sure that you can trust the external codes.')
dataset_ret = hf_data_loader(
dataset_name,
name=subset_name,

View File

@@ -6,7 +6,8 @@ from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
Sequence, Union)
import numpy as np
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets import (Dataset, DatasetDict, Features, IterableDataset,
IterableDatasetDict)
from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
from datasets.utils.file_utils import is_relative_path
@@ -163,6 +164,7 @@ class MsDataset:
download_mode: Optional[DownloadMode] = DownloadMode.
REUSE_DATASET_IF_EXISTS,
cache_dir: Optional[str] = MS_DATASETS_CACHE,
features: Optional[Features] = None,
use_streaming: Optional[bool] = False,
stream_batch_size: Optional[int] = 1,
custom_cfg: Optional[Config] = Config(),
@@ -237,6 +239,11 @@ class MsDataset:
if not namespace or not dataset_name:
raise 'The dataset_name should be in the form of `namespace/dataset_name` or `dataset_name`.'
if trust_remote_code:
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {dataset_name}. Please make sure that '
'you can trust the external codes.')
# Init context config
dataset_context_config = DatasetContextConfig(
dataset_name=dataset_name,
@@ -300,7 +307,7 @@ class MsDataset:
data_files=data_files,
split=split,
cache_dir=cache_dir,
features=None,
features=features,
download_config=None,
download_mode=download_mode.value,
revision=version,
@@ -329,6 +336,9 @@ class MsDataset:
return dataset_inst
elif hub == Hubs.virgo:
warnings.warn(
'The option `Hubs.virgo` is deprecated, '
'will be removed in the future version.', DeprecationWarning)
from modelscope.msdatasets.data_loader.data_loader import VirgoDownloader
from modelscope.utils.constant import VirgoDatasetConfig
# Rewrite the namespace, version and cache_dir for virgo dataset.
@@ -390,8 +400,10 @@ class MsDataset:
"""
warnings.warn(
'upload is deprecated, please use git command line to upload the dataset.',
DeprecationWarning)
'The function `upload` is deprecated, '
'please use git command '
'or modelscope.hub.api.HubApi.upload_folder '
'or modelscope.hub.api.HubApi.upload_file.', DeprecationWarning)
if not object_name:
raise ValueError('object_name cannot be empty!')
@@ -441,7 +453,7 @@ class MsDataset:
"""
warnings.warn(
'upload is deprecated, please use git command line to upload the dataset.',
'The function `clone_meta` is deprecated, please use git command line to clone the repo.',
DeprecationWarning)
_repo = DatasetRepository(
@@ -482,6 +494,12 @@ class MsDataset:
None
"""
warnings.warn(
'The function `upload_meta` is deprecated, '
'please use git command '
'or CLI `modelscope upload owner_name/repo_name ...`.',
DeprecationWarning)
_repo = DatasetRepository(
repo_work_dir=dataset_work_dir,
dataset_id='',

View File

@@ -41,19 +41,19 @@ from datasets.packaged_modules import (_EXTENSION_TO_MODULE,
_MODULE_TO_EXTENSIONS,
_PACKAGED_DATASETS_MODULES)
from datasets.utils import file_utils
from datasets.utils.file_utils import (OfflineModeIsEnabled,
_raise_if_offline_mode_is_enabled,
from datasets.utils.file_utils import (_raise_if_offline_mode_is_enabled,
cached_path, is_local_path,
is_relative_path,
relative_to_absolute_path)
from datasets.utils.info_utils import is_small_dataset
from datasets.utils.metadata import MetadataConfigs
from datasets.utils.py_utils import get_imports, map_nested
from datasets.utils.py_utils import get_imports
from datasets.utils.track import tracked_str
from fsspec import filesystem
from fsspec.core import _un_chain
from fsspec.utils import stringify_path
from huggingface_hub import (DatasetCard, DatasetCardData)
from huggingface_hub.errors import OfflineModeIsEnabled
from huggingface_hub.hf_api import DatasetInfo as HfDatasetInfo
from huggingface_hub.hf_api import HfApi, RepoFile, RepoFolder
from packaging import version
@@ -62,7 +62,8 @@ from modelscope import HubApi
from modelscope.hub.utils.utils import get_endpoint
from modelscope.msdatasets.utils.hf_file_utils import get_from_cache_ms
from modelscope.utils.config_ds import MS_DATASETS_CACHE
from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE
from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE, DEFAULT_DATASET_REVISION
from modelscope.utils.import_utils import has_attr_in_class
from modelscope.utils.logger import get_logger
logger = get_logger()
@@ -97,7 +98,7 @@ def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) ->
if is_relative_path(url_or_filename):
# append the relative path to the base_path
# url_or_filename = url_or_path_join(self._base_path, url_or_filename)
revision = revision or 'master'
revision = revision or DEFAULT_DATASET_REVISION
# Note: make sure the FilePath is the last param
params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': url_or_filename}
params: str = urlencode(params)
@@ -162,7 +163,7 @@ def _dataset_info(
dataset_hub_id, dataset_type = _api.get_dataset_id_and_type(
dataset_name=_dataset_name, namespace=_namespace)
revision: str = revision or 'master'
revision: str = revision or DEFAULT_DATASET_REVISION
data = _api.get_dataset_infos(dataset_hub_id=dataset_hub_id,
revision=revision,
files_metadata=files_metadata,
@@ -234,7 +235,7 @@ def _list_repo_tree(
while True:
data: dict = _api.list_repo_tree(dataset_name=_dataset_name,
namespace=_namespace,
revision=revision or 'master',
revision=revision or DEFAULT_DATASET_REVISION,
root_path=path_in_repo or None,
recursive=True,
page_number=page_number,
@@ -277,7 +278,7 @@ def _get_paths_info(
dataset_hub_id, dataset_type = _api.get_dataset_id_and_type(
dataset_name=_dataset_name, namespace=_namespace)
revision: str = revision or 'master'
revision: str = revision or DEFAULT_DATASET_REVISION
data = _api.get_dataset_infos(dataset_hub_id=dataset_hub_id,
revision=revision,
files_metadata=False,
@@ -296,6 +297,29 @@ def _get_paths_info(
]
def _download_repo_file(repo_id: str, path_in_repo: str, download_config: DownloadConfig, revision: str):
_api = HubApi()
_namespace, _dataset_name = repo_id.split('/')
if download_config and download_config.download_desc is None:
download_config.download_desc = f'Downloading [{path_in_repo}]'
try:
url_or_filename = _api.get_dataset_file_url(
file_name=path_in_repo,
dataset_name=_dataset_name,
namespace=_namespace,
revision=revision,
extension_filter=False,
)
repo_file_path = cached_path(
url_or_filename=url_or_filename, download_config=download_config)
except FileNotFoundError as e:
repo_file_path = ''
logger.error(e)
return repo_file_path
def get_fs_token_paths(
urlpath,
storage_options=None,
@@ -536,9 +560,6 @@ def _get_data_patterns(
def get_module_without_script(self) -> DatasetModule:
_ms_api = HubApi()
_repo_id: str = self.name
_namespace, _dataset_name = _repo_id.split('/')
# hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info(
# self.name,
@@ -549,28 +570,20 @@ def get_module_without_script(self) -> DatasetModule:
# even if metadata_configs is not None (which means that we will resolve files for each config later)
# we cannot skip resolving all files because we need to infer module name by files extensions
# revision = hfh_dataset_info.sha # fix the revision in case there are new commits in the meantime
revision = self.revision or 'master'
revision = self.download_config.storage_options.get('revision', None) or DEFAULT_DATASET_REVISION
base_path = f"hf://datasets/{self.name}@{revision}/{self.data_dir or ''}".rstrip(
'/')
repo_id: str = self.name
download_config = self.download_config.copy()
if download_config.download_desc is None:
download_config.download_desc = 'Downloading [README.md]'
try:
url_or_filename = _ms_api.get_dataset_file_url(
file_name='README.md',
dataset_name=_dataset_name,
namespace=_namespace,
revision=revision,
extension_filter=False,
)
dataset_readme_path = cached_path(
url_or_filename=url_or_filename, download_config=download_config)
dataset_card_data = DatasetCard.load(Path(dataset_readme_path)).data
except FileNotFoundError:
dataset_card_data = DatasetCardData()
dataset_readme_path = _download_repo_file(
repo_id=repo_id,
path_in_repo='README.md',
download_config=download_config,
revision=revision)
dataset_card_data = DatasetCard.load(Path(dataset_readme_path)).data if dataset_readme_path else DatasetCardData()
subset_name: str = download_config.storage_options.get('name', None)
metadata_configs = MetadataConfigs.from_dataset_card_data(
@@ -646,10 +659,7 @@ def get_module_without_script(self) -> DatasetModule:
builder_kwargs = {
# "base_path": hf_hub_url(self.name, "", revision=revision).rstrip("/"),
'base_path':
_ms_api.get_file_base_path(
namespace=_namespace,
dataset_name=_dataset_name,
),
HubApi().get_file_base_path(repo_id=repo_id),
'repo_id':
self.name,
'dataset_name':
@@ -760,20 +770,22 @@ def _download_additional_modules(
def get_module_with_script(self) -> DatasetModule:
_api = HubApi()
_dataset_name: str = self.name.split('/')[-1]
_namespace: str = self.name.split('/')[0]
repo_id: str = self.name
_namespace, _dataset_name = repo_id.split('/')
revision = self.download_config.storage_options.get('revision', None) or DEFAULT_DATASET_REVISION
script_file_name = f'{_dataset_name}.py'
script_url: str = _api.get_dataset_file_url(
file_name=script_file_name,
dataset_name=_dataset_name,
namespace=_namespace,
revision=self.revision,
extension_filter=False,
local_script_path = _download_repo_file(
repo_id=repo_id,
path_in_repo=script_file_name,
download_config=self.download_config,
revision=revision,
)
local_script_path = cached_path(
url_or_filename=script_url, download_config=self.download_config)
if not local_script_path:
raise FileNotFoundError(
f'Cannot find {script_file_name} in {repo_id} at revision {revision}. '
f'Please create {script_file_name} in the repo.'
)
dataset_infos_path = None
# try:
@@ -790,22 +802,19 @@ def get_module_with_script(self) -> DatasetModule:
# logger.info(f'Cannot find dataset_infos.json: {e}')
# dataset_infos_path = None
dataset_readme_url: str = _api.get_dataset_file_url(
file_name='README.md',
dataset_name=_dataset_name,
namespace=_namespace,
revision=self.revision,
extension_filter=False,
dataset_readme_path = _download_repo_file(
repo_id=repo_id,
path_in_repo='README.md',
download_config=self.download_config,
revision=revision
)
dataset_readme_path = cached_path(
url_or_filename=dataset_readme_url, download_config=self.download_config)
imports = get_imports(local_script_path)
local_imports = _download_additional_modules(
name=self.name,
name=repo_id,
dataset_name=_dataset_name,
namespace=_namespace,
revision=self.revision,
revision=revision,
imports=imports,
download_config=self.download_config,
)
@@ -821,11 +830,13 @@ def get_module_with_script(self) -> DatasetModule:
dynamic_modules_path=dynamic_modules_path,
module_namespace='datasets',
subdirectory_name=hash,
name=self.name,
name=repo_id,
)
if not os.path.exists(importable_file_path):
trust_remote_code = resolve_trust_remote_code(trust_remote_code=self.trust_remote_code, repo_id=self.name)
if trust_remote_code:
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {repo_id}. Please make sure that '
'you can trust the external codes.')
_create_importable_file(
local_path=local_script_path,
local_imports=local_imports,
@@ -833,12 +844,12 @@ def get_module_with_script(self) -> DatasetModule:
dynamic_modules_path=dynamic_modules_path,
module_namespace='datasets',
subdirectory_name=hash,
name=self.name,
name=repo_id,
download_mode=self.download_mode,
)
else:
raise ValueError(
f'Loading {self.name} requires you to execute the dataset script in that'
f'Loading {repo_id} requires you to execute the dataset script in that'
' repo on your local machine. Make sure you have read the code there to avoid malicious use, then'
' set the option `trust_remote_code=True` to remove this error.'
)
@@ -846,14 +857,14 @@ def get_module_with_script(self) -> DatasetModule:
dynamic_modules_path=dynamic_modules_path,
module_namespace='datasets',
subdirectory_name=hash,
name=self.name,
name=repo_id,
)
# make the new module to be noticed by the import system
importlib.invalidate_caches()
builder_kwargs = {
# "base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"),
'base_path': _api.get_file_base_path(namespace=_namespace, dataset_name=_dataset_name),
'repo_id': self.name,
'base_path': HubApi().get_file_base_path(repo_id=repo_id),
'repo_id': repo_id,
}
return DatasetModule(module_path, hash, builder_kwargs)
@@ -925,6 +936,11 @@ class DatasetsWrapperHF:
verification_mode or VerificationMode.BASIC_CHECKS
) if not save_infos else VerificationMode.ALL_CHECKS)
if trust_remote_code:
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
'that you can trust the external codes.'
)
# Create a dataset builder
builder_instance = DatasetsWrapperHF.load_dataset_builder(
path=path,
@@ -1052,6 +1068,11 @@ class DatasetsWrapperHF:
) if download_config else DownloadConfig()
download_config.storage_options.update(storage_options)
if trust_remote_code:
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
'that you can trust the external codes.'
)
dataset_module = DatasetsWrapperHF.dataset_module_factory(
path,
revision=revision,
@@ -1126,9 +1147,11 @@ class DatasetsWrapperHF:
) -> DatasetModule:
subset_name: str = download_kwargs.pop('name', None)
revision = revision or DEFAULT_DATASET_REVISION
if download_config is None:
download_config = DownloadConfig(**download_kwargs)
download_config.storage_options.update({'name': subset_name})
download_config.storage_options.update({'revision': revision})
if download_config and download_config.cache_dir is None:
download_config.cache_dir = MS_DATASETS_CACHE
@@ -1160,6 +1183,10 @@ class DatasetsWrapperHF:
# -> the module from the python file in the dataset repository
# - if path has one "/" and is dataset repository on the HF hub without a python file
# -> use a packaged module (csv, text etc.) based on content of the repository
if trust_remote_code:
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
'that you can trust the external codes.'
)
# Try packaged
if path in _PACKAGED_DATASETS_MODULES:
@@ -1197,7 +1224,7 @@ class DatasetsWrapperHF:
data_files=data_files,
download_mode=download_mode).get_module()
# Try remotely
elif is_relative_path(path) and path.count('/') <= 1:
elif is_relative_path(path) and path.count('/') == 1:
try:
_raise_if_offline_mode_is_enabled()
@@ -1236,6 +1263,15 @@ class DatasetsWrapperHF:
)
else:
raise e
dataset_readme_path = _download_repo_file(
repo_id=path,
path_in_repo='README.md',
download_config=download_config,
revision=revision,
)
commit_hash = os.path.basename(os.path.dirname(dataset_readme_path))
if filename in [
sibling.rfilename for sibling in dataset_info.siblings
]: # contains a dataset script
@@ -1264,26 +1300,54 @@ class DatasetsWrapperHF:
# This fails when the dataset has multiple configs and a default config and
# the user didn't specify a configuration name (_require_default_config_name=True).
try:
if has_attr_in_class(HubDatasetModuleFactoryWithParquetExport, 'revision'):
return HubDatasetModuleFactoryWithParquetExport(
path,
revision=revision,
download_config=download_config).get_module()
return HubDatasetModuleFactoryWithParquetExport(
path,
download_config=download_config,
revision=dataset_info.sha).get_module()
commit_hash=commit_hash,
download_config=download_config).get_module()
except Exception as e:
logger.error(e)
# Otherwise we must use the dataset script if the user trusts it
# To be adapted to the old version of datasets
if has_attr_in_class(HubDatasetModuleFactoryWithScript, 'revision'):
return HubDatasetModuleFactoryWithScript(
path,
revision=revision,
download_config=download_config,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
trust_remote_code=trust_remote_code,
).get_module()
return HubDatasetModuleFactoryWithScript(
path,
revision=revision,
commit_hash=commit_hash,
download_config=download_config,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
trust_remote_code=trust_remote_code,
).get_module()
else:
# To be adapted to the old version of datasets
if has_attr_in_class(HubDatasetModuleFactoryWithoutScript, 'revision'):
return HubDatasetModuleFactoryWithoutScript(
path,
revision=revision,
data_dir=data_dir,
data_files=data_files,
download_config=download_config,
download_mode=download_mode,
).get_module()
return HubDatasetModuleFactoryWithoutScript(
path,
revision=revision,
commit_hash=commit_hash,
data_dir=data_dir,
data_files=data_files,
download_config=download_config,
@@ -1292,6 +1356,7 @@ class DatasetsWrapperHF:
except Exception as e1:
# All the attempts failed, before raising the error we should check if the module is already cached
logger.error(f'>> Error loading {path}: {e1}')
try:
return CachedDatasetModuleFactory(
path,

View File

@@ -4,7 +4,6 @@ from __future__ import print_function
import multiprocessing
import os
import oss2
from datasets.utils.file_utils import hash_url_to_filename
from modelscope.hub.api import HubApi
@@ -40,6 +39,7 @@ class OssUtilities:
self.multipart_threshold = 50 * 1024 * 1024
self.max_retries = 3
import oss2
self.resumable_store_download = oss2.ResumableDownloadStore(
root=self.resumable_store_root_path)
self.resumable_store_upload = oss2.ResumableStore(
@@ -47,6 +47,8 @@ class OssUtilities:
self.api = HubApi()
def _do_init(self, oss_config):
import oss2
self.key = oss_config[ACCESS_ID]
self.secret = oss_config[ACCESS_SECRET]
self.token = oss_config[SECURITY_TOKEN]
@@ -78,6 +80,7 @@ class OssUtilities:
def download(self, oss_file_name: str,
download_config: DataDownloadConfig):
import oss2
cache_dir = download_config.cache_dir
candidate_key = os.path.join(self.oss_dir, oss_file_name)
candidate_key_backup = os.path.join(self.oss_backup_dir, oss_file_name)
@@ -126,6 +129,7 @@ class OssUtilities:
def upload(self, oss_object_name: str, local_file_path: str,
indicate_individual_progress: bool,
upload_mode: UploadMode) -> str:
import oss2
retry_count = 0
object_key = os.path.join(self.oss_dir, oss_object_name)

View File

@@ -1,8 +1,11 @@
from typing import List, Union
from modelscope import get_logger
from modelscope.pipelines.accelerate.base import InferFramework
from modelscope.utils.import_utils import is_vllm_available
logger = get_logger()
class Vllm(InferFramework):
@@ -27,6 +30,9 @@ class Vllm(InferFramework):
if not Vllm.check_gpu_compatibility(8) and (dtype
in ('bfloat16', 'auto')):
dtype = 'float16'
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {self.model_dir}. Please make '
'sure that you can trust the external codes.')
self.model = LLM(
self.model_dir,
dtype=dtype,

View File

@@ -108,30 +108,7 @@ def pipeline(task: str = None,
"""
if task is None and pipeline_name is None:
raise ValueError('task or pipeline_name is required')
prefer_llm_pipeline = kwargs.get('external_engine_for_llm')
if task is not None and task.lower() in [
Tasks.text_generation, Tasks.chat
]:
# if not specified, prefer llm pipeline for aforementioned tasks
if prefer_llm_pipeline is None:
prefer_llm_pipeline = True
# for llm pipeline, if llm_framework is not specified, default to swift instead
# TODO: port the swift infer based on transformer into ModelScope
if prefer_llm_pipeline and kwargs.get('llm_framework') is None:
kwargs['llm_framework'] = 'swift'
third_party = kwargs.get(ThirdParty.KEY)
if third_party is not None:
kwargs.pop(ThirdParty.KEY)
if pipeline_name is None and prefer_llm_pipeline:
pipeline_name = external_engine_for_llm_checker(
model, model_revision, kwargs)
if pipeline_name is None:
model = normalize_model_input(
model,
model_revision,
third_party=third_party,
ignore_file_pattern=ignore_file_pattern)
pipeline_props = {'type': pipeline_name}
if pipeline_name is None:
# get default pipeline for this task
if isinstance(model, str) \
@@ -142,16 +119,47 @@ def pipeline(task: str = None,
model, revision=model_revision) if isinstance(
model, str) else read_config(
model[0], revision=model_revision)
register_plugins_repo(cfg.safe_get('plugins'))
register_modelhub_repo(model, cfg.get('allow_remote', False))
pipeline_name = external_engine_for_llm_checker(
model, model_revision,
kwargs) if prefer_llm_pipeline else None
if pipeline_name is not None:
if cfg:
pipeline_name = cfg.safe_get('pipeline',
{}).get('type', None)
if pipeline_name is None:
prefer_llm_pipeline = kwargs.get('external_engine_for_llm')
# if not specified in either the args or configuration.json, prefer the llm pipeline for the aforementioned tasks
if task is not None and task.lower() in [
Tasks.text_generation, Tasks.chat
]:
if prefer_llm_pipeline is None:
prefer_llm_pipeline = True
# for llm pipeline, if llm_framework is not specified, default to swift instead
# TODO: port the swift infer based on transformer into ModelScope
if prefer_llm_pipeline:
if kwargs.get('llm_framework') is None:
kwargs['llm_framework'] = 'swift'
pipeline_name = external_engine_for_llm_checker(
model, model_revision, kwargs)
if pipeline_name is None or pipeline_name != 'llm':
third_party = kwargs.get(ThirdParty.KEY)
if third_party is not None:
kwargs.pop(ThirdParty.KEY)
model = normalize_model_input(
model,
model_revision,
third_party=third_party,
ignore_file_pattern=ignore_file_pattern)
register_plugins_repo(cfg.safe_get('plugins'))
register_modelhub_repo(model,
cfg.get('allow_remote', False))
if pipeline_name:
pipeline_props = {'type': pipeline_name}
else:
check_config(cfg)
pipeline_props = cfg.pipeline
elif model is not None:
# get pipeline info from Model object
first_model = model[0] if isinstance(model, list) else model
@@ -165,6 +173,8 @@ def pipeline(task: str = None,
pipeline_name, default_model_repo = get_default_pipeline_info(task)
model = normalize_model_input(default_model_repo, model_revision)
pipeline_props = {'type': pipeline_name}
else:
pipeline_props = {'type': pipeline_name}
pipeline_props['model'] = model
pipeline_props['device'] = device
@@ -223,8 +233,9 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model,
List[Model]],
revision: Optional[str],
kwargs: Dict[str, Any]) -> Optional[str]:
from .nlp.llm_pipeline import SWIFT_MODEL_ID_MAPPING, init_swift_model_mapping, ModelTypeHelper, LLMAdapterRegistry
from .nlp.llm_pipeline import ModelTypeHelper, LLMAdapterRegistry
from ..hub.check_model import get_model_id_from_cache
from swift.llm import get_model_info_meta
if isinstance(model, list):
model = model[0]
if not isinstance(model, str):
@@ -237,9 +248,17 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model,
else:
model_id = model
init_swift_model_mapping()
if model_id.lower() in SWIFT_MODEL_ID_MAPPING:
try:
info = get_model_info_meta(model_id)
model_type = info[0].model_type
except Exception as e:
logger.warning(
f'Cannot use llm_framework with {model_id}, '
f"ignoring llm_framework={kwargs.get('llm_framework')}: {e}")
model_type = None
if model_type:
return 'llm'
model_type = ModelTypeHelper.get(
model, revision, with_adapter=True, split='-', use_cache=True)
if LLMAdapterRegistry.contains(model_type):
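With the reworked selection above, the LLM pipeline is only preferred when nothing is specified in the arguments or in configuration.json; a rough usage sketch (the model id is hypothetical):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# For text-generation/chat tasks the llm pipeline and llm_framework='swift'
# are assumed as defaults unless overridden.
pipe = pipeline(Tasks.text_generation, model='namespace/some-chat-model')

# Opting out explicitly falls back to the pipeline type from configuration.json.
pipe = pipeline(
    Tasks.text_generation,
    model='namespace/some-chat-model',
    external_engine_for_llm=False)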

View File

@@ -2,7 +2,7 @@ from typing import Any, Dict, Union
import torch
from modelscope import AutoModelForCausalLM
from modelscope import AutoModelForCausalLM, get_logger
from modelscope.metainfo import Pipelines, Preprocessors
from modelscope.models.base import Model
from modelscope.outputs import OutputKeys
@@ -13,6 +13,8 @@ from modelscope.pipelines.multi_modal.visual_question_answering_pipeline import
from modelscope.preprocessors import Preprocessor, load_image
from modelscope.utils.constant import Fields, Frameworks, Tasks
logger = get_logger()
@PIPELINES.register_module(
Tasks.visual_question_answering, module_name='ovis-vl')
@@ -35,6 +37,9 @@ class VisionChatPipeline(VisualQuestionAnsweringPipeline):
torch_dtype = kwargs.get('torch_dtype', torch.float16)
multimodal_max_length = kwargs.get('multimodal_max_length', 8192)
self.device = 'cuda' if device == 'gpu' else device
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {model}. Please make '
'sure that you can trust the external codes.')
self.model = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype=torch_dtype,

View File

@@ -29,21 +29,9 @@ from modelscope.utils.streaming_output import (PipelineStreamingOutputMixin,
logger = get_logger()
SWIFT_MODEL_ID_MAPPING = {}
SWIFT_FRAMEWORK = 'swift'
def init_swift_model_mapping():
from swift.llm.utils import MODEL_MAPPING
global SWIFT_MODEL_ID_MAPPING
if not SWIFT_MODEL_ID_MAPPING:
SWIFT_MODEL_ID_MAPPING = {
v['model_id_or_path'].lower(): k
for k, v in MODEL_MAPPING.items()
}
class LLMAdapterRegistry:
llm_format_map = {'qwen': [None, None, None]}
@@ -109,6 +97,9 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
assert base_model is not None, 'Cannot get adapter_cfg.model_id_or_path from configuration.json file.'
revision = self.cfg.safe_get('adapter_cfg.model_revision',
'master')
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {base_model}. Please make sure that you can '
'trust the external codes.')
base_model = Model.from_pretrained(
base_model,
revision,
@@ -146,6 +137,9 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
model) else snapshot_download(model)
# TODO: Temporary use of AutoModelForCausalLM
# Need to be updated into a universal solution
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {model_dir}. Please make sure '
'that you can trust the external codes.')
model = AutoModelForCausalLM.from_pretrained(
model_dir,
device_map=self.device_map,
@@ -185,6 +179,9 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
self.llm_framework = llm_framework
if os.path.exists(kwargs['model']):
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {kwargs["model"]}. Please make sure '
'that you can trust the external codes.')
config = AutoConfig.from_pretrained(
kwargs['model'], trust_remote_code=True)
q_config = config.__dict__.get('quantization_config', None)
@@ -227,12 +224,12 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
def _init_swift(self, model_id, device) -> None:
from swift.llm import prepare_model_template
from swift.llm.utils import InferArguments
from swift.llm import InferArguments, get_model_info_meta
def format_messages(messages: Dict[str, List[Dict[str, str]]],
tokenizer: PreTrainedTokenizer,
**kwargs) -> Dict[str, torch.Tensor]:
inputs, _ = self.template.encode(get_example(messages))
inputs = self.template.encode(messages)
inputs.pop('labels', None)
if 'input_ids' in inputs:
input_ids = torch.tensor(inputs['input_ids'])[None]
@@ -265,12 +262,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
else:
return dict(system=system, prompt=prompt, history=history)
init_swift_model_mapping()
assert model_id.lower() in SWIFT_MODEL_ID_MAPPING,\
f'Invalid model id {model_id} or Swift framework does not support this model.'
args = InferArguments(
model_type=SWIFT_MODEL_ID_MAPPING[model_id.lower()])
args = InferArguments(model=model_id)
model, template = prepare_model_template(
args, device_map=self.device_map)
self.model = add_stream_generate(model)
@@ -440,6 +432,9 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
model_dir = self.model.model_dir
if tokenizer_class is None:
tokenizer_class = AutoTokenizer
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {model_dir}. Please make sure '
'that you can trust the external codes.')
return tokenizer_class.from_pretrained(
model_dir, trust_remote_code=True)

View File

@@ -269,6 +269,9 @@ class ChatGLM6bV2TextGenerationPipeline(Pipeline):
if use_bf16:
default_torch_dtype = torch.bfloat16
torch_dtype = kwargs.get('torch_dtype', default_torch_dtype)
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {model_dir}. Please make sure '
'that you can trust the external codes.')
model = Model.from_pretrained(
model_dir,
trust_remote_code=True,
@@ -285,6 +288,9 @@ class ChatGLM6bV2TextGenerationPipeline(Pipeline):
self.model = model
self.model.eval()
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {self.model.model_dir}. Please '
'make sure that you can trust the external codes.')
self.tokenizer = AutoTokenizer.from_pretrained(
self.model.model_dir, trust_remote_code=True)
@@ -328,6 +334,9 @@ class QWenChatPipeline(Pipeline):
bf16 = False
if isinstance(model, str):
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {model}. Please make sure '
'that you can trust the external codes.')
self.tokenizer = AutoTokenizer.from_pretrained(
model, revision=revision, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(
@@ -392,6 +401,9 @@ class QWenTextGenerationPipeline(Pipeline):
bf16 = False
if isinstance(model, str):
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {model}. Please make sure '
'that you can trust the external codes.')
self.model = AutoModelForCausalLM.from_pretrained(
model,
device_map=device_map,

View File

@@ -416,6 +416,8 @@ class TokenClassificationTransformersPreprocessor(
offset_mapping = []
tokens = self.nlp_tokenizer.tokenizer.tokenize(text)
offset = 0
if getattr(self.nlp_tokenizer.tokenizer, 'do_lower_case', False):
text = text.lower()
for token in tokens:
is_start = (token[:2] != '##')
if is_start:

View File

@@ -230,6 +230,10 @@ template_info = [
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/dolphin-mistral',
),
TemplateInfo(
template_regex=f'.*{cases("dolphin3", "dolphin-3")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/dolphin3'),
# "phi"
TemplateInfo(
@@ -251,6 +255,12 @@ template_info = [
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/phi3',
),
TemplateInfo(
template_regex=
f'.*{cases("phi4", "phi-4")}{no_multi_modal()}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/phi4',
),
TemplateInfo(
template_regex=
f'.*{cases("phi")}{no_multi_modal()}.*',
@@ -591,7 +601,7 @@ template_info = [
template_regex=
f'.*{cases("deepseek")}.*{cases("v2")}{no("v2.5")}{no_multi_modal()}.*{chat_suffix}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/deepseek_v2',
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/deepseek-v2',
),
# deepseek_coder
@@ -623,6 +633,94 @@ template_info = [
template=TemplateType.telechat_v2,
template_regex=f'.*{cases("TeleChat")}.*{cases("v2")}.*'),
# tulu3
TemplateInfo(
template_regex=f'.*{cases("tulu3", "tulu-3")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/tulu3'),
# athene-v2
TemplateInfo(
template_regex=f'.*{cases("athene-v2")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/athene-v2'),
# granite
TemplateInfo(
template_regex=f'.*{cases("granite-guardian-3")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3-guardian'),
TemplateInfo(
template_regex=f'.*{cases("granite")}.*{cases("code")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite-code'),
TemplateInfo(
template_regex=f'.*{cases("granite-3.1")}.*{cases("2b", "8b")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3.1-dense'),
TemplateInfo(
template_regex=f'.*{cases("granite-3.1")}.*{cases("1b", "3b")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3.1-moe'),
TemplateInfo(
template_regex=f'.*{cases("granite-embedding")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite-embedding'),
TemplateInfo(
template_regex=f'.*{cases("granite-3")}.*{cases("2b", "8b")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3-dense'),
TemplateInfo(
template_regex=f'.*{cases("granite-3")}.*{cases("1b", "3b")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3-moe'),
# opencoder
TemplateInfo(
template_regex=f'.*{cases("opencoder")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/opencoder'),
# smollm
TemplateInfo(
template_regex=f'.*{cases("smollm2")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/smollm2'),
TemplateInfo(
template_regex=f'.*{cases("smollm")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/smollm'),
# 'aya'
TemplateInfo(
template_regex=f'.*{cases("aya-expanse")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/aya-expanse'),
TemplateInfo(
template_regex=f'.*{cases("aya")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/aya'),
# falcon
TemplateInfo(
template_regex=f'.*{cases("falcon3")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/falcon3'),
TemplateInfo(
template_regex=f'.*{cases("falcon")}.*{cases("-2")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/falcon2'),
TemplateInfo(
template_regex=f'.*{cases("falcon")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/falcon'),
# smallthinker
TemplateInfo(
template_regex=f'.*{cases("smallthinker")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/smallthinker'),
TemplateInfo(
template_regex=f'.*{cases("nomic-embed-text")}.*',
modelfile_prefix=
@@ -651,10 +749,6 @@ template_info = [
template_regex=f'.*{cases("starcoder")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/starcoder'),
TemplateInfo(
template_regex=f'.*{cases("granite")}.*{cases("code")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite-code'),
TemplateInfo(
template_regex=f'.*{cases("all-minilm")}.*',
modelfile_prefix=
@@ -663,10 +757,6 @@ template_info = [
template_regex=f'.*{cases("openchat")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/openchat'),
TemplateInfo(
template_regex=f'.*{cases("aya")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/aya'),
TemplateInfo(
template_regex=f'.*{cases("openhermes")}.*',
modelfile_prefix=
@@ -687,10 +777,6 @@ template_info = [
template_regex=f'.*{cases("xwin")}.*{cases("lm")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/xwinlm'),
TemplateInfo(
template_regex=f'.*{cases("smollm")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/smollm'),
TemplateInfo(
template_regex=f'.*{cases("sqlcoder")}.*',
modelfile_prefix=
@@ -699,14 +785,6 @@ template_info = [
template_regex=f'.*{cases("starling-lm")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/starling-lm'),
TemplateInfo(
template_regex=f'.*{cases("falcon")}.*{cases("-2")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/falcon2'),
TemplateInfo(
template_regex=f'.*{cases("falcon")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/falcon'),
TemplateInfo(
template_regex=f'.*{cases("solar-pro")}.*',
modelfile_prefix=
@@ -820,6 +898,9 @@ class TemplateLoader:
model_id,
revision=kwargs.pop('revision', 'master'),
ignore_file_pattern=ignore_file_pattern)
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {model_dir}.'
' Please make sure that you can trust the external codes.'
)
tokenizer = AutoTokenizer.from_pretrained(
model_dir, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)

View File

@@ -3,12 +3,15 @@ import os
from types import MethodType
from typing import Any, Optional
from modelscope import get_logger
from modelscope.metainfo import Tasks
from modelscope.utils.ast_utils import INDEX_KEY
from modelscope.utils.import_utils import (LazyImportModule,
is_torch_available,
is_transformers_available)
logger = get_logger()
def can_load_by_ms(model_dir: str, task_name: Optional[str],
model_type: Optional[str]) -> bool:
@@ -91,6 +94,9 @@ def get_hf_automodel_class(model_dir: str,
if not os.path.exists(config_path):
return None
try:
logger.warning(
f'Use trust_remote_code=True. Will invoke codes from {model_dir}. Please make sure '
'that you can trust the external codes.')
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
if task_name is None:
automodel_class = get_default_automodel(config)

View File

@@ -5,13 +5,14 @@ from pathlib import Path
# Cache location
from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT
from modelscope.utils.file_utils import get_modelscope_cache_dir
from modelscope.utils.file_utils import (get_dataset_cache_root,
get_modelscope_cache_dir)
MS_CACHE_HOME = get_modelscope_cache_dir()
DEFAULT_MS_DATASETS_CACHE = os.path.join(MS_CACHE_HOME, 'hub', 'datasets')
MS_DATASETS_CACHE = Path(
os.getenv('MS_DATASETS_CACHE', DEFAULT_MS_DATASETS_CACHE))
# NOTE: removed `MS_DATASETS_CACHE` env,
# default is `~/.cache/modelscope/hub/datasets`
MS_DATASETS_CACHE = get_dataset_cache_root()
DOWNLOADED_DATASETS_DIR = 'downloads'
DEFAULT_DOWNLOADED_DATASETS_PATH = os.path.join(MS_DATASETS_CACHE,

View File

@@ -1,9 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import hashlib
import inspect
import io
import os
from pathlib import Path
from shutil import Error, copy2, copystat
from typing import BinaryIO, Optional, Union
# TODO: remove this api, unify to flattened args
@@ -60,11 +62,16 @@ def get_model_cache_root() -> str:
def get_dataset_cache_root() -> str:
"""Get dataset raw file cache root path.
if `MODELSCOPE_CACHE` is set, return `MODELSCOPE_CACHE/datasets`,
else return `~/.cache/modelscope/hub/datasets`
Returns:
str: the modelscope dataset raw file cache root.
"""
return os.path.join(get_modelscope_cache_dir(), 'datasets')
if os.getenv('MODELSCOPE_CACHE'):
return os.path.join(get_modelscope_cache_dir(), 'datasets')
else:
return os.path.join(get_modelscope_cache_dir(), 'hub', 'datasets')
def get_dataset_cache_dir(dataset_id: str) -> str:
@@ -175,3 +182,85 @@ def copytree_py37(src,
if errors:
raise Error(errors)
return dst
def get_file_size(file_path_or_obj: Union[str, Path, bytes, BinaryIO]) -> int:
if isinstance(file_path_or_obj, (str, Path)):
file_path = Path(file_path_or_obj)
return file_path.stat().st_size
elif isinstance(file_path_or_obj, bytes):
return len(file_path_or_obj)
elif isinstance(file_path_or_obj, io.BufferedIOBase):
current_position = file_path_or_obj.tell()
file_path_or_obj.seek(0, os.SEEK_END)
size = file_path_or_obj.tell()
file_path_or_obj.seek(current_position)
return size
else:
raise TypeError(
'Unsupported type: must be string, Path, bytes, or io.BufferedIOBase'
)
def get_file_hash(
file_path_or_obj: Union[str, Path, bytes, BinaryIO],
buffer_size_mb: Optional[int] = 1,
tqdm_desc: Optional[str] = '[Calculating]',
disable_tqdm: Optional[bool] = True,
) -> dict:
from tqdm import tqdm
file_size = get_file_size(file_path_or_obj)
buffer_size = buffer_size_mb * 1024 * 1024
file_hash = hashlib.sha256()
chunk_hash_list = []
progress = tqdm(
total=file_size,
initial=0,
unit_scale=True,
dynamic_ncols=True,
unit='B',
desc=tqdm_desc,
disable=disable_tqdm,
)
if isinstance(file_path_or_obj, (str, Path)):
with open(file_path_or_obj, 'rb') as f:
while byte_chunk := f.read(buffer_size):
chunk_hash_list.append(hashlib.sha256(byte_chunk).hexdigest())
file_hash.update(byte_chunk)
progress.update(len(byte_chunk))
file_hash = file_hash.hexdigest()
final_chunk_size = buffer_size
elif isinstance(file_path_or_obj, bytes):
file_hash.update(file_path_or_obj)
file_hash = file_hash.hexdigest()
chunk_hash_list.append(file_hash)
final_chunk_size = len(file_path_or_obj)
progress.update(final_chunk_size)
elif isinstance(file_path_or_obj, io.BufferedIOBase):
    # Remember the current position and restore it afterwards, so callers
    # (e.g. UploadInfo.from_fileobj) can still read the stream content.
    current_position = file_path_or_obj.tell()
    while byte_chunk := file_path_or_obj.read(buffer_size):
        chunk_hash_list.append(hashlib.sha256(byte_chunk).hexdigest())
        file_hash.update(byte_chunk)
        progress.update(len(byte_chunk))
    file_path_or_obj.seek(current_position)
    file_hash = file_hash.hexdigest()
    final_chunk_size = buffer_size
else:
progress.close()
raise ValueError(
'Input must be str, Path, bytes or a io.BufferedIOBase')
progress.close()
return {
'file_path_or_obj': file_path_or_obj,
'file_hash': file_hash,
'file_size': file_size,
'chunk_size': final_chunk_size,
'chunk_nums': len(chunk_hash_list),
'chunk_hash_list': chunk_hash_list,
}
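A short usage sketch of the new hashing helper (the local file name is hypothetical):

from modelscope.utils.file_utils import get_file_hash

info = get_file_hash('model.bin', buffer_size_mb=8, disable_tqdm=False)
print(info['file_hash'])                       # sha256 of the whole file
print(info['file_size'], info['chunk_size'])   # total bytes and chunk size in bytes
assert info['chunk_nums'] == len(info['chunk_hash_list'])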

View File

@@ -54,6 +54,8 @@ def read_config(model_id_or_path: str,
local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION)
elif os.path.isfile(model_id_or_path):
local_path = model_id_or_path
else:
return None
return Config.from_file(local_path)

View File

@@ -3,6 +3,7 @@
import ast
import functools
import importlib
import inspect
import logging
import os
import os.path as osp
@@ -480,3 +481,23 @@ class LazyImportModule(ModuleType):
importlib.import_module(module_name)
else:
logger.warning(f'{signature} not found in ast index file')
def has_attr_in_class(cls, attribute_name) -> bool:
"""
Determine whether the `__init__` method of a class accepts a parameter with the given name.
Args:
    cls: target class.
    attribute_name: the parameter name to look for in `cls.__init__`.
Returns:
    True if `cls.__init__` has a parameter with the given name, otherwise False.
"""
init_method = cls.__init__
signature = inspect.signature(init_method)
parameters = signature.parameters
param_names = list(parameters.keys())
return attribute_name in param_names
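A quick sketch of how the helper supports version-compatibility checks (the toy class is illustrative only):

from modelscope.utils.import_utils import has_attr_in_class

class _Factory:

    def __init__(self, path, revision=None):
        self.path = path
        self.revision = revision

assert has_attr_in_class(_Factory, 'revision')
assert not has_attr_in_class(_Factory, 'commit_hash')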

View File

@@ -451,6 +451,9 @@ def register_plugins_repo(plugins: List[str]) -> None:
def register_modelhub_repo(model_dir, allow_remote=False) -> None:
""" Try to install and import remote model from modelhub"""
if allow_remote:
logger.warning(
f'Use allow_remote=True. Will invoke codes from {model_dir}. Please make sure '
'that you can trust the external codes.')
try:
import_module_from_model_dir(model_dir)
except KeyError:

View File

@@ -0,0 +1,479 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022-present, the HuggingFace Inc. team.
import base64
import functools
import hashlib
import io
import os
import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
from fnmatch import fnmatch
from pathlib import Path
from typing import (BinaryIO, Callable, Generator, Iterable, Iterator, List,
Literal, Optional, TypeVar, Union)
from modelscope.utils.file_utils import get_file_hash
T = TypeVar('T')
# Always ignore `.git` and `.cache/modelscope` folders in commits
DEFAULT_IGNORE_PATTERNS = [
'.git',
'.git/*',
'*/.git',
'**/.git/**',
'.cache/modelscope',
'.cache/modelscope/*',
'*/.cache/modelscope',
'**/.cache/modelscope/**',
]
# Forbidden to commit these folders
FORBIDDEN_FOLDERS = ['.git', '.cache']
UploadMode = Literal['lfs', 'normal']
DATASET_LFS_SUFFIX = [
'.7z',
'.aac',
'.arrow',
'.audio',
'.bmp',
'.bin',
'.bz2',
'.flac',
'.ftz',
'.gif',
'.gz',
'.h5',
'.jack',
'.jpeg',
'.jpg',
'.jsonl',
'.joblib',
'.lz4',
'.msgpack',
'.npy',
'.npz',
'.ot',
'.parquet',
'.pb',
'.pickle',
'.pcm',
'.pkl',
'.raw',
'.rar',
'.sam',
'.tar',
'.tgz',
'.wasm',
'.wav',
'.webm',
'.webp',
'.zip',
'.zst',
'.tiff',
'.mp3',
'.mp4',
'.ogg',
]
MODEL_LFS_SUFFIX = [
'.7z',
'.arrow',
'.bin',
'.bz2',
'.ckpt',
'.ftz',
'.gz',
'.h5',
'.joblib',
'.mlmodel',
'.model',
'.msgpack',
'.npy',
'.npz',
'.onnx',
'.ot',
'.parquet',
'.pb',
'.pickle',
'.pkl',
'.pt',
'.pth',
'.rar',
'.safetensors',
'.tar',
'.tflite',
'.tgz',
'.wasm',
'.xz',
'.zip',
'.zst',
]
class RepoUtils:
@staticmethod
def filter_repo_objects(
items: Iterable[T],
*,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
key: Optional[Callable[[T], str]] = None,
) -> Generator[T, None, None]:
"""Filter repo objects based on an allowlist and a denylist.
Input must be a list of paths (`str` or `Path`) or a list of arbitrary objects.
In the latter case, `key` must be provided and specifies a function of one argument
that is used to extract a path from each element in iterable.
Patterns are Unix shell-style wildcards which are NOT regular expressions. See
https://docs.python.org/3/library/fnmatch.html for more details.
Args:
items (`Iterable`):
List of items to filter.
allow_patterns (`str` or `List[str]`, *optional*):
Patterns constituting the allowlist. If provided, item paths must match at
least one pattern from the allowlist.
ignore_patterns (`str` or `List[str]`, *optional*):
Patterns constituting the denylist. If provided, item paths must not match
any patterns from the denylist.
key (`Callable[[T], str]`, *optional*):
Single-argument function to extract a path from each item. If not provided,
the `items` must already be `str` or `Path`.
Returns:
Filtered list of objects, as a generator.
Raises:
:class:`ValueError`:
If `key` is not provided and items are not `str` or `Path`.
Example usage with paths:
```python
>>> # Filter only PDFs that are not hidden.
>>> list(RepoUtils.filter_repo_objects(
... ["aaa.PDF", "bbb.jpg", ".ccc.pdf", ".ddd.png"],
... allow_patterns=["*.pdf"],
... ignore_patterns=[".*"],
... ))
["aaa.pdf"]
```
"""
allow_patterns = allow_patterns if allow_patterns else None
ignore_patterns = ignore_patterns if ignore_patterns else None
if isinstance(allow_patterns, str):
allow_patterns = [allow_patterns]
if isinstance(ignore_patterns, str):
ignore_patterns = [ignore_patterns]
if allow_patterns is not None:
allow_patterns = [
RepoUtils._add_wildcard_to_directories(p)
for p in allow_patterns
]
if ignore_patterns is not None:
ignore_patterns = [
RepoUtils._add_wildcard_to_directories(p)
for p in ignore_patterns
]
if key is None:
def _identity(item: T) -> str:
if isinstance(item, str):
return item
if isinstance(item, Path):
return str(item)
raise ValueError(
f'Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.'
)
key = _identity # Items must be `str` or `Path`, otherwise raise ValueError
for item in items:
path = key(item)
# Skip if there's an allowlist and path doesn't match any
if allow_patterns is not None and not any(
fnmatch(path, r) for r in allow_patterns):
continue
# Skip if there's a denylist and path matches any
if ignore_patterns is not None and any(
fnmatch(path, r) for r in ignore_patterns):
continue
yield item
@staticmethod
def _add_wildcard_to_directories(pattern: str) -> str:
if pattern[-1] == '/':
return pattern + '*'
return pattern
@dataclass
class CommitInfo(str):
"""Data structure containing information about a newly created commit.
Returned by any method that creates a commit on the Hub: [`create_commit`], [`upload_file`], [`upload_folder`],
[`delete_file`], [`delete_folder`]. It inherits from `str` for backward compatibility but using methods specific
to `str` is deprecated.
Attributes:
commit_url (`str`):
Url where to find the commit.
commit_message (`str`):
The summary (first line) of the commit that has been created.
commit_description (`str`):
Description of the commit that has been created. Can be empty.
oid (`str`):
Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`.
pr_url (`str`, *optional*):
Url to the PR that has been created, if any. Populated when `create_pr=True`
is passed.
pr_revision (`str`, *optional*):
Revision of the PR that has been created, if any. Populated when
`create_pr=True` is passed. Example: `"refs/pr/1"`.
pr_num (`int`, *optional*):
Number of the PR discussion that has been created, if any. Populated when
`create_pr=True` is passed. Can be passed as `discussion_num` in
[`get_discussion_details`]. Example: `1`.
_url (`str`, *optional*):
Legacy url for `str` compatibility. Can be the url to the uploaded file on the Hub (if returned by
[`upload_file`]), to the uploaded folder on the Hub (if returned by [`upload_folder`]) or to the commit on
the Hub (if returned by [`create_commit`]). Defaults to `commit_url`. It is deprecated to use this
attribute. Please use `commit_url` instead.
"""
commit_url: str
commit_message: str
commit_description: str
oid: str
pr_url: Optional[str] = None
# Computed from `pr_url` in `__post_init__`
pr_revision: Optional[str] = field(init=False)
pr_num: Optional[str] = field(init=False)
# legacy url for `str` compatibility (ex: url to uploaded file, url to uploaded folder, url to PR, etc.)
_url: str = field(
repr=False, default=None) # type: ignore # defaults to `commit_url`
def __new__(cls,
*args,
commit_url: str,
_url: Optional[str] = None,
**kwargs):
return str.__new__(cls, _url or commit_url)
def to_dict(self):
    return {
        'commit_url': self.commit_url,
        'commit_message': self.commit_message,
        'commit_description': self.commit_description,
        'oid': self.oid,
        'pr_url': self.pr_url,
    }
def git_hash(data: bytes) -> str:
"""
Computes the git-sha1 hash of the given bytes, using the same algorithm as git.
This is equivalent to running `git hash-object`. See https://git-scm.com/docs/git-hash-object
for more details.
Note: this method is valid for regular files. For LFS files, the proper git hash is supposed to be computed on the
pointer file content, not the actual file content. However, for simplicity, we directly compare the sha256 of
the LFS file content when we want to compare LFS files.
Args:
data (`bytes`):
The data to compute the git-hash for.
Returns:
`str`: the git-hash of `data` as an hexadecimal string.
"""
_kwargs = {'usedforsecurity': False} if sys.version_info >= (3, 9) else {}
sha1 = functools.partial(hashlib.sha1, **_kwargs)
sha = sha1()
sha.update(b'blob ')
sha.update(str(len(data)).encode())
sha.update(b'\0')
sha.update(data)
return sha.hexdigest()
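# Illustrative sketch (not part of the diff): git_hash reproduces `git hash-object`
# for regular blobs, i.e. sha1 over the header b'blob <len>\0' plus the content.
_demo = b'hello'
assert git_hash(_demo) == hashlib.sha1(b'blob %d\x00%s' % (len(_demo), _demo)).hexdigest()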
@dataclass
class UploadInfo:
"""
Dataclass holding required information to determine whether a blob
should be uploaded to the hub using the LFS protocol or the regular protocol
Args:
sha256 (`str`):
SHA256 hash of the blob
size (`int`):
Size in bytes of the blob
sample (`bytes`):
First 512 bytes of the blob
"""
sha256: str
size: int
sample: bytes
@classmethod
def from_path(cls, path: str):
file_hash_info: dict = get_file_hash(path)
size = file_hash_info['file_size']
sha = file_hash_info['file_hash']
with open(path, 'rb') as f:
    sample = f.read(512)
return cls(sha256=sha, size=size, sample=sample)
@classmethod
def from_bytes(cls, data: bytes):
sha = get_file_hash(data)['file_hash']
return cls(size=len(data), sample=data[:512], sha256=sha)
@classmethod
def from_fileobj(cls, fileobj: BinaryIO):
    fileobj_info: dict = get_file_hash(fileobj)
    # Sample the first bytes from the current position, then restore it so
    # the stream can still be read from the start for the actual upload.
    start_position = fileobj.tell()
    sample = fileobj.read(512)
    fileobj.seek(start_position)
    return cls(
        sha256=fileobj_info['file_hash'],
        size=fileobj_info['file_size'],
        sample=sample)
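# Illustrative sketch (not part of the diff): UploadInfo from an in-memory payload;
# sha256/size/sample are what the hub needs to pick the LFS or regular protocol.
_payload = b'example payload ' * 64
_info = UploadInfo.from_bytes(_payload)
assert _info.size == len(_payload)
assert len(_info.sample) == 512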
@dataclass
class CommitOperationAdd:
"""Data structure containing information about a file to be added to a commit."""
path_in_repo: str
path_or_fileobj: Union[str, Path, bytes, BinaryIO]
upload_info: UploadInfo = field(init=False, repr=False)
# Internal attributes
# set to "lfs" or "regular" once known
_upload_mode: Optional[UploadMode] = field(
init=False, repr=False, default=None)
# set to True if .gitignore rules prevent the file from being uploaded as LFS
# (server-side check)
_should_ignore: Optional[bool] = field(
init=False, repr=False, default=None)
# set to the remote OID of the file if it has already been uploaded
# useful to determine if a commit will be empty or not
_remote_oid: Optional[str] = field(init=False, repr=False, default=None)
# set to True once the file has been uploaded as LFS
_is_uploaded: bool = field(init=False, repr=False, default=False)
# set to True once the file has been committed
_is_committed: bool = field(init=False, repr=False, default=False)
def __post_init__(self) -> None:
"""Validates `path_or_fileobj` and compute `upload_info`."""
# Validate `path_or_fileobj` value
if isinstance(self.path_or_fileobj, Path):
self.path_or_fileobj = str(self.path_or_fileobj)
if isinstance(self.path_or_fileobj, str):
path_or_fileobj = os.path.normpath(
os.path.expanduser(self.path_or_fileobj))
if not os.path.isfile(path_or_fileobj):
raise ValueError(
f"Provided path: '{path_or_fileobj}' is not a file on the local file system"
)
elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)):
raise ValueError(
'path_or_fileobj must be either an instance of str, bytes or'
' io.BufferedIOBase. If you passed a file-like object, make sure it is'
' in binary mode.')
if isinstance(self.path_or_fileobj, io.BufferedIOBase):
try:
self.path_or_fileobj.tell()
self.path_or_fileobj.seek(0, os.SEEK_CUR)
except (OSError, AttributeError) as exc:
raise ValueError(
'path_or_fileobj is a file-like object but does not implement seek() and tell()'
) from exc
# Compute "upload_info" attribute
if isinstance(self.path_or_fileobj, str):
self.upload_info = UploadInfo.from_path(self.path_or_fileobj)
elif isinstance(self.path_or_fileobj, bytes):
self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj)
else:
self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj)
@contextmanager
def as_file(self) -> Iterator[BinaryIO]:
"""
A context manager that yields a file-like object for reading the underlying
data behind `path_or_fileobj`.
"""
if isinstance(self.path_or_fileobj, str) or isinstance(
self.path_or_fileobj, Path):
with open(self.path_or_fileobj, 'rb') as file:
yield file
elif isinstance(self.path_or_fileobj, bytes):
yield io.BytesIO(self.path_or_fileobj)
elif isinstance(self.path_or_fileobj, io.BufferedIOBase):
prev_pos = self.path_or_fileobj.tell()
yield self.path_or_fileobj
self.path_or_fileobj.seek(prev_pos, 0)
def b64content(self) -> bytes:
"""
The base64-encoded content of `path_or_fileobj`
Returns: `bytes`
"""
with self.as_file() as file:
return base64.b64encode(file.read())
@property
def _local_oid(self) -> Optional[str]:
"""Return the OID of the local file.
This OID is then compared to `self._remote_oid` to check if the file has changed compared to the remote one.
If the file did not change, we won't upload it again to prevent empty commits.
For LFS files, the OID corresponds to the SHA256 of the file content (used as the LFS ref).
For regular files, the OID corresponds to the SHA1 of the file content.
Note: this is slightly different to git OID computation since the oid of an LFS file is usually the git-SHA1
of the pointer file content (not the actual file content). However, using the SHA256 is enough to detect
changes and more convenient client-side.
"""
if self._upload_mode is None:
return None
elif self._upload_mode == 'lfs':
return self.upload_info.sha256
else:
# Regular file => compute sha1
# => no need to read by chunk since the file is guaranteed to be <=5MB.
with self.as_file() as file:
return git_hash(file.read())
CommitOperation = Union[CommitOperationAdd, ]
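A minimal sketch of staging an in-memory file as a commit operation (assuming the new module is importable as `modelscope.utils.repo_utils`; nothing is uploaded here):

from modelscope.utils.repo_utils import CommitOperationAdd

op = CommitOperationAdd(path_in_repo='README.md', path_or_fileobj=b'# demo\n')
print(op.upload_info.size, op.upload_info.sha256[:12])
print(op.b64content())  # base64 body used by the regular (non-LFS) upload path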

View File

@@ -0,0 +1,71 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import wraps
from tqdm import tqdm
from modelscope.hub.constants import DEFAULT_MAX_WORKERS
from modelscope.utils.logger import get_logger
logger = get_logger()
def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS,
disable_tqdm: bool = False,
tqdm_desc: str = None):
"""
A decorator to execute a function in a threaded manner using ThreadPoolExecutor.
Args:
max_workers (int): The maximum number of threads to use.
disable_tqdm (bool): disable progress bar.
tqdm_desc (str): Desc of tqdm.
Returns:
function: A wrapped function that executes with threading and a progress bar.
Examples:
>>> from modelscope.utils.thread_utils import thread_executor
>>> import time
>>> @thread_executor(max_workers=8)
... def process_item(item, x, y):
... # do something to single item
... time.sleep(1)
... return str(item) + str(x) + str(y)
>>> items = [1, 2, 3]
>>> process_item(items, x='abc', y='xyz')
"""
def decorator(func):
@wraps(func)
def wrapper(iterable, *args, **kwargs):
results = []
# Create a tqdm progress bar with the total number of items to process
with tqdm(
unit_scale=True,
unit_divisor=1024,
initial=0,
total=len(iterable),
desc=tqdm_desc or f'Processing {len(iterable)} items',
disable=disable_tqdm,
) as pbar:
# Define a wrapper function to update the progress bar
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all tasks
futures = {
executor.submit(func, item, *args, **kwargs): item
for item in iterable
}
# Update the progress bar as tasks complete
for future in as_completed(futures):
pbar.update(1)
results.append(future.result())
return results
return wrapper
return decorator

View File

@@ -1,6 +1,6 @@
addict
attrs
datasets>=3.0.0,<=3.0.1
datasets>=3.0.0,<=3.2.0
einops
oss2
Pillow

View File

@@ -1,8 +1,7 @@
addict
attrs
datasets>=3.0.0,<=3.0.1
datasets>=3.0.0,<=3.2.0
einops
oss2
Pillow
python-dateutil>=2.1
scipy

View File

@@ -0,0 +1,25 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from modelscope.hub.api import HubApi
from modelscope.utils.logger import get_logger
logger = get_logger()
logger.setLevel('DEBUG')
DEFAULT_GIT_PATH = 'git'
download_model_file_name = 'test.bin'
class FileExistsTest(unittest.TestCase):
def test_file_exists(self):
api = HubApi()
self.assertTrue(
api.file_exists('iic/gte_Qwen2-7B-instruct', 'added_tokens.json'))
self.assertTrue(
api.file_exists('iic/gte_Qwen2-7B-instruct',
'1_Pooling/config.json'))
if __name__ == '__main__':
unittest.main()

View File

@@ -311,6 +311,60 @@ class TestToOllama(unittest.TestCase):
'llama3.3')
_test_check_tmpl_type('bartowski/EXAONE-3.5-7.8B-Instruct-GGUF',
'exaone3.5')
_test_check_tmpl_type(
'QuantFactory/Tulu-3.1-8B-SuperNova-Smart-GGUF',
'tulu3',
gguf_meta={'general.name': 'Tulu 3.1 8B SuperNova'})
_test_check_tmpl_type(
'bartowski/Athene-V2-Chat-GGUF',
'athene-v2',
gguf_meta={'general.name': 'Athene V2 Chat'})
_test_check_tmpl_type(
'QuantFactory/granite-guardian-3.0-2b-GGUF',
'granite3-guardian',
gguf_meta={'general.name': 'Models'})
_test_check_tmpl_type('lmstudio-community/OpenCoder-8B-Instruct-GGUF',
'opencoder')
_test_check_tmpl_type(
'QuantFactory/SmolLM2-1.7B-Instruct-GGUF',
'smollm2',
gguf_meta={'general.name': 'Smollm2 1.7B 8k Mix7 Ep2 v2'})
_test_check_tmpl_type(
'prithivMLmods/Aya-Expanse-8B-GGUF',
'aya-expanse',
gguf_meta={'general.name': 'Aya Expanse 8b'})
_test_check_tmpl_type('lmstudio-community/Falcon3-7B-Instruct-GGUF',
'falcon3')
_test_check_tmpl_type(
'lmstudio-community/granite-3.1-8b-instruct-GGUF',
'granite3.1-dense',
gguf_meta={'general.name': 'Granite 3.1 8b Instruct'})
_test_check_tmpl_type(
'lmstudio-community/granite-3.1-2b-instruct-GGUF',
'granite3.1-dense',
gguf_meta={'general.name': 'Granite 3.1 2b Instruct'})
_test_check_tmpl_type(
'lmstudio-community/granite-embedding-278m-multilingual-GGUF',
'granite-embedding',
gguf_meta={'general.name': 'Granite Embedding 278m Multilingual'})
_test_check_tmpl_type(
'QuantFactory/granite-3.1-3b-a800m-instruct-GGUF',
'granite3.1-moe',
gguf_meta={'general.name': 'Granite 3.1 3b A800M Base'})
_test_check_tmpl_type(
'bartowski/granite-3.1-1b-a400m-instruct-GGUF',
'granite3.1-moe',
gguf_meta={'general.name': 'Granite 3.1 1b A400M Instruct'})
_test_check_tmpl_type(
'bartowski/SmallThinker-3B-Preview-GGUF',
'smallthinker',
gguf_meta={'general.name': 'SmallThinker 3B Preview'})
_test_check_tmpl_type(
'bartowski/Dolphin3.0-Llama3.1-8B-GGUF',
'dolphin3',
gguf_meta={'general.name': 'Dolphin 3.0 Llama 3.1 8B'})
_test_check_tmpl_type(
'AI-ModelScope/phi-4', 'phi4', gguf_meta={'general.name': 'Phi 4'})
if __name__ == '__main__':