Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)

Commit: Merge remote-tracking branch 'origin' into fix/trust_remote_code_
@@ -65,6 +65,8 @@ RUN sh /tmp/install.sh {version_args} && \
    pip config set install.trusted-host mirrors.aliyun.com && \
    cp /tmp/resources/ubuntu2204.aliyun /etc/apt/sources.list

RUN pip install --no-cache-dir omegaconf==2.3.0 && pip cache purge

ENV SETUPTOOLS_USE_DISTUTILS=stdlib
ENV VLLM_USE_MODELSCOPE=True
ENV LMDEPLOY_USE_MODELSCOPE=True

@@ -56,7 +56,7 @@ if TYPE_CHECKING:
        AutoModelForPreTraining, AutoModelForTextEncoding,
        AutoImageProcessor, BatchFeature, Qwen2VLForConditionalGeneration,
        T5EncoderModel, Qwen2_5_VLForConditionalGeneration, LlamaModel,
        LlamaPreTrainedModel, LlamaForCausalLM)
        LlamaPreTrainedModel, LlamaForCausalLM, hf_pipeline)
else:
    print(
        'transformer is not installed, please install it if you want to use related modules'

@@ -40,9 +40,6 @@ def run_cmd():
    if not hasattr(args, 'func'):
        parser.print_help()
        exit(1)
    if args.token is not None:
        api = HubApi()
        api.login(args.token)
    cmd = args.func(args)
    cmd.execute()

@@ -3,6 +3,7 @@ import os
from argparse import ArgumentParser

from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi
from modelscope.hub.constants import DEFAULT_MAX_WORKERS
from modelscope.hub.file_download import (dataset_file_download,
                                          model_file_download)
@@ -54,16 +55,21 @@ class DownloadCMD(CLICommand):
            default='model',
            help="Type of repo to download from (defaults to 'model').",
        )
        parser.add_argument(
            '--token',
            type=str,
            default=None,
            help='Optional. Access token to download controlled entities.')
        parser.add_argument(
            '--revision',
            type=str,
            default=None,
            help='Revision of the model.')
            help='Revision of the entity (e.g., model).')
        parser.add_argument(
            '--cache_dir',
            type=str,
            default=None,
            help='Cache directory to save model.')
            help='Cache directory to save entity (e.g., model).')
        parser.add_argument(
            '--local_dir',
            type=str,
@@ -118,6 +124,10 @@ class DownloadCMD(CLICommand):
                % self.args.repo_type)
        if not self.args.model and not self.args.dataset:
            raise Exception('Model or dataset must be set.')
        cookies = None
        if self.args.token is not None:
            api = HubApi()
            cookies = api.get_cookies(access_token=self.args.token)
        if self.args.model:
            if len(self.args.files) == 1:  # download single file
                model_file_download(
@@ -125,7 +135,8 @@ class DownloadCMD(CLICommand):
                    self.args.files[0],
                    cache_dir=self.args.cache_dir,
                    local_dir=self.args.local_dir,
                    revision=self.args.revision)
                    revision=self.args.revision,
                    cookies=cookies)
            elif len(
                    self.args.files) > 1:  # download specified multiple files.
                snapshot_download(
@@ -135,7 +146,7 @@ class DownloadCMD(CLICommand):
                    local_dir=self.args.local_dir,
                    allow_file_pattern=self.args.files,
                    max_workers=self.args.max_workers,
                )
                    cookies=cookies)
            else:  # download repo
                snapshot_download(
                    self.args.model,
@@ -145,7 +156,7 @@ class DownloadCMD(CLICommand):
                    allow_file_pattern=convert_patterns(self.args.include),
                    ignore_file_pattern=convert_patterns(self.args.exclude),
                    max_workers=self.args.max_workers,
                )
                    cookies=cookies)
        elif self.args.dataset:
            dataset_revision: str = self.args.revision if self.args.revision else DEFAULT_DATASET_REVISION
            if len(self.args.files) == 1:  # download single file
@@ -154,7 +165,8 @@ class DownloadCMD(CLICommand):
                    self.args.files[0],
                    cache_dir=self.args.cache_dir,
                    local_dir=self.args.local_dir,
                    revision=dataset_revision)
                    revision=dataset_revision,
                    cookies=cookies)
            elif len(
                    self.args.files) > 1:  # download specified multiple files.
                dataset_snapshot_download(
@@ -164,7 +176,7 @@ class DownloadCMD(CLICommand):
                    local_dir=self.args.local_dir,
                    allow_file_pattern=self.args.files,
                    max_workers=self.args.max_workers,
                )
                    cookies=cookies)
            else:  # download repo
                dataset_snapshot_download(
                    self.args.dataset,
@@ -174,6 +186,6 @@ class DownloadCMD(CLICommand):
                    allow_file_pattern=convert_patterns(self.args.include),
                    ignore_file_pattern=convert_patterns(self.args.exclude),
                    max_workers=self.args.max_workers,
                )
                    cookies=cookies)
        else:
            pass  # noop
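
For reference, the `--token` support added above works by converting the token into cookies that are forwarded to the download helpers. A minimal programmatic sketch of the same flow; the token and repo/file names are placeholders:

from modelscope.hub.api import HubApi
from modelscope.hub.file_download import model_file_download

# Placeholders: substitute a real SDK token and a repo/file you can access.
cookies = HubApi().get_cookies(access_token='your_sdk_token')
local_path = model_file_download(
    'your-namespace/your-private-model',
    'config.json',
    revision=None,
    cookies=cookies)
print(local_path)
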
@@ -91,7 +91,7 @@ class UploadCMD(CLICommand):
            '--endpoint',
            type=str,
            default=get_endpoint(),
            help='Endpoint for Modelscope service.')
            help='Endpoint for ModelScope service.')

        parser.set_defaults(func=subparser_func)

@@ -137,14 +137,15 @@ class UploadCMD(CLICommand):

        # Check token and login
        # The cookies will be reused if the user has logged in before.
        cookies = None
        api = HubApi(endpoint=self.args.endpoint)

        if self.args.token:
            api.login(access_token=self.args.token)
            cookies = api.get_cookies(access_token=self.args.token)
        else:
            cookies = ModelScopeConfig.get_cookies()
            if cookies is None:
                raise ValueError(
                    'The `token` is not provided! '
                    'No credential found for entity upload. '
                    'You can pass the `--token` argument, '
                    'or use api.login(access_token=`your_sdk_token`). '
                    'Your token is available at https://modelscope.cn/my/myaccesstoken'
@@ -158,6 +159,7 @@ class UploadCMD(CLICommand):
                repo_type=self.args.repo_type,
                commit_message=self.args.commit_message,
                commit_description=self.args.commit_description,
                token=self.args.token,
            )
        elif os.path.isdir(self.local_path):
            api.upload_folder(
@@ -170,6 +172,7 @@ class UploadCMD(CLICommand):
                allow_patterns=convert_patterns(self.args.include),
                ignore_patterns=convert_patterns(self.args.exclude),
                max_workers=self.args.max_workers,
                token=self.args.token,
            )
        else:
            raise ValueError(f'{self.local_path} is not a valid local path')

@@ -34,6 +34,7 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
                                      API_RESPONSE_FIELD_USERNAME,
                                      DEFAULT_CREDENTIALS_PATH,
                                      DEFAULT_MAX_WORKERS,
                                      DEFAULT_MODELSCOPE_DOMAIN,
                                      MODELSCOPE_CLOUD_ENVIRONMENT,
                                      MODELSCOPE_CLOUD_USERNAME,
                                      MODELSCOPE_REQUEST_ID, ONE_YEAR_SECONDS,
@@ -112,9 +113,19 @@ class HubApi:

        self.upload_checker = UploadingCheck()

    def get_cookies(self, access_token):
        from requests.cookies import RequestsCookieJar
        jar = RequestsCookieJar()
        jar.set('m_session_id',
                access_token,
                domain=os.getenv('MODELSCOPE_DOMAIN',
                                 DEFAULT_MODELSCOPE_DOMAIN),
                path='/')
        return jar

    def login(
        self,
        access_token: Optional[str] = None,
        access_token: Optional[str] = None
    ):
        """Login with your SDK access token, which can be obtained from
        https://www.modelscope.cn user center.
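
The new HubApi.get_cookies above simply wraps an SDK token into a RequestsCookieJar keyed as m_session_id on the hub domain. A minimal sketch ('your_sdk_token' is a placeholder):

from modelscope.hub.api import HubApi

jar = HubApi().get_cookies('your_sdk_token')
for cookie in jar:
    print(cookie.name, cookie.domain, cookie.path)  # m_session_id, hub domain, '/'
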
@@ -71,6 +71,10 @@ def check_local_model_is_latest(
        headers=snapshot_header,
        use_cookies=cookies,
    )
    model_cache = None
    # download via non-git method
    if not os.path.exists(os.path.join(model_root_path, '.git')):
        model_cache = ModelFileSystemCache(model_root_path)
    for model_file in model_files:
        if model_file['Type'] == 'tree':
            continue

@@ -46,6 +46,7 @@ class GitCommandWrapper(metaclass=Singleton):
        git_env = os.environ.copy()
        git_env['GIT_TERMINAL_PROMPT'] = '0'
        command = [self.git_path, *args]
        command = [item for item in command if item]
        response = subprocess.run(
            command,
            stdout=subprocess.PIPE,

@@ -10,6 +10,8 @@ from modelscope.utils.config import ConfigDict, check_config
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, Tasks,
                                       ThirdParty)
from modelscope.utils.hub import read_config
from modelscope.utils.import_utils import is_transformers_available
from modelscope.utils.logger import get_logger
from modelscope.utils.plugins import (register_modelhub_repo,
                                      register_plugins_repo)
from modelscope.utils.registry import Registry, build_from_cfg
@@ -17,6 +19,7 @@ from .base import Pipeline
from .util import is_official_hub_path

PIPELINES = Registry('pipelines')
logger = get_logger()


def normalize_model_input(model,
@@ -72,7 +75,7 @@ def pipeline(task: str = None,
             config_file: str = None,
             pipeline_name: str = None,
             framework: str = None,
             device: str = 'gpu',
             device: str = None,
             model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
             ignore_file_pattern: List[str] = None,
             **kwargs) -> Pipeline:
@@ -109,6 +112,7 @@ def pipeline(task: str = None,
    if task is None and pipeline_name is None:
        raise ValueError('task or pipeline_name is required')

    pipeline_props = None
    if pipeline_name is None:
        # get default pipeline for this task
        if isinstance(model, str) \
@@ -157,8 +161,11 @@ def pipeline(task: str = None,
        if pipeline_name:
            pipeline_props = {'type': pipeline_name}
        else:
            try:
                check_config(cfg)
                pipeline_props = cfg.pipeline
            except AssertionError as e:
                logger.info(str(e))

    elif model is not None:
        # get pipeline info from Model object
@@ -166,8 +173,12 @@ def pipeline(task: str = None,
        if not hasattr(first_model, 'pipeline'):
            # model is instantiated by user, we should parse config again
            cfg = read_config(first_model.model_dir)
            try:
                check_config(cfg)
                first_model.pipeline = cfg.pipeline
            except AssertionError as e:
                logger.info(str(e))
        if first_model.__dict__.get('pipeline'):
            pipeline_props = first_model.pipeline
    else:
        pipeline_name, default_model_repo = get_default_pipeline_info(task)
@@ -176,6 +187,23 @@ def pipeline(task: str = None,
    else:
        pipeline_props = {'type': pipeline_name}

    if not pipeline_props and is_transformers_available():
        try:
            from modelscope.utils.hf_util import hf_pipeline
            return hf_pipeline(
                task=task,
                model=model,
                framework=framework,
                device=device,
                **kwargs)
        except Exception as e:
            logger.error(
                'We couldn\'t find a suitable pipeline from ms, so we tried to load it using the transformers pipeline,'
                ' but that also failed.')
            raise e

    if not device:
        device = 'gpu'
    pipeline_props['model'] = model
    pipeline_props['device'] = device
    cfg = ConfigDict(pipeline_props)
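
With the fallback above, a task that has no registered ModelScope pipeline is handed to the transformers pipeline through hf_pipeline (assuming transformers is installed). The new tests later in this commit exercise exactly this path; a minimal sketch using the same repo id as those tests:

from modelscope import pipeline

qa = pipeline('question-answering',
              model='damotestx/distilbert-base-cased-distilled-squad')
print(qa(question='What is extractive QA?',
         context='Extractive Question Answering extracts an answer span from a given text.'))
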
@@ -22,7 +22,7 @@ logger = get_logger()

class FaceProcessingBasePipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
    def __init__(self, model: str, use_det=True, **kwargs):
        """
        use `model` to create a face processing pipeline and output cropped img, scores, bbox and lmks.

@@ -30,8 +30,10 @@ class FaceProcessingBasePipeline(Pipeline):
            model: model id on modelscope hub.

        """
        self.use_det = use_det
        super().__init__(model=model, **kwargs)
        # face detect pipeline
        if use_det:
            det_model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd'
            self.face_detection = pipeline(
                Tasks.face_detection, model=det_model_id)
@@ -94,6 +96,7 @@ class FaceProcessingBasePipeline(Pipeline):
    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img[:, :, ::-1]
        if self.use_det:
            det_result = self.face_detection(img.copy())
            rtn = self._choose_face(det_result, img_shape=img.shape)
            if rtn is not None:
@@ -109,6 +112,11 @@ class FaceProcessingBasePipeline(Pipeline):
                return result
            else:
                return None
        else:
            result = {}
            resized_img = cv2.resize(img, (112, 112))
            result['img'] = np.ascontiguousarray(resized_img)
            return result

    def align_face_padding(self, img, rect, padding_size=16, pad_pixel=127):
        rect = np.reshape(rect, (-1, 4))

@@ -14,10 +14,11 @@ from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.face_processing_base_pipeline import \
    FaceProcessingBasePipeline
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from . import FaceProcessingBasePipeline

logger = get_logger()

@@ -26,15 +27,14 @@ logger = get_logger()
    Tasks.face_recognition, module_name=Pipelines.face_recognition)
class FaceRecognitionPipeline(FaceProcessingBasePipeline):

    def __init__(self, model: str, **kwargs):
    def __init__(self, model: str, use_det=True, **kwargs):
        """
        use `model` to create a face recognition pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """

        # face recong model
        super().__init__(model=model, **kwargs)
        super().__init__(model=model, use_det=use_det, **kwargs)
        device = torch.device(
            f'cuda:{0}' if torch.cuda.is_available() else 'cpu')
        self.device = device
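
With the use_det flag threaded through here, use_det=False skips the extra face-detection pipeline and the preprocess step simply resizes the input to 112x112, which suits inputs that are already cropped faces. A sketch (the model id is a placeholder; the commit's tests use their own self.model_id):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# '<face-recognition-model-id>' is a placeholder for a hub model id.
fr = pipeline(Tasks.face_recognition,
              model='<face-recognition-model-id>',
              use_det=False)  # input assumed to be an already-cropped face
emb = fr('cropped_face.png')[OutputKeys.IMG_EMBEDDING]
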
@@ -18,6 +18,7 @@ class TemplateInfo:
    template: str = None
    template_regex: str = None
    modelfile_prefix: str = None
    allow_general_name: bool = True


def cases(*names):
@@ -255,6 +256,12 @@ template_info = [
        modelfile_prefix=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/phi3',
    ),
    TemplateInfo(
        template_regex=
        f'.*{cases("phi4-mini", "phi-4-mini")}.*',
        modelfile_prefix=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/phi4-mini',
    ),
    TemplateInfo(
        template_regex=
        f'.*{cases("phi4", "phi-4")}{no_multi_modal()}.*',
@@ -470,6 +477,12 @@ template_info = [
        modelfile_prefix=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/command-r-plus',
    ),
    TemplateInfo(
        template_regex=
        f'.*{cases("command-r7b-arabic")}.*',
        modelfile_prefix=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/command-r7b-arabic',
    ),
    TemplateInfo(
        template_regex=
        f'.*{cases("command-r7b")}.*',
@@ -666,6 +679,14 @@ template_info = [
        template_regex=f'.*{cases("granite")}.*{cases("code")}.*',
        modelfile_prefix=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite-code'),
    TemplateInfo(
        template_regex=f'.*{cases("granite")}.*{cases("vision")}.*{cases("3.2")}.*',
        modelfile_prefix=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3.2-vision'),
    TemplateInfo(
        template_regex=f'.*{cases("granite")}.*{cases("3.2")}.*',
        modelfile_prefix=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3.2'),
    TemplateInfo(
        template_regex=f'.*{cases("granite-3.1")}.*{cases("2b", "8b")}.*',
        modelfile_prefix=
@@ -733,6 +754,12 @@ template_info = [
        modelfile_prefix=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/smallthinker'),

    TemplateInfo(
        template_regex=f'.*{cases("openthinker")}.*',
        modelfile_prefix=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/openthinker',
        allow_general_name=False),

    TemplateInfo(
        template_regex=
        f'.*{cases("olmo2", "olmo-2")}.*',
@@ -888,8 +915,14 @@ template_info = [
        template_regex=f'.*{cases("exaone")}.*',
        modelfile_prefix=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/exaone3.5'),

    TemplateInfo(
        template_regex=f'.*{cases("r1-1776")}.*',
        modelfile_prefix=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/r1-1776'),
    TemplateInfo(
        template_regex=f'.*{cases("deepscaler")}.*',
        modelfile_prefix=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/deepscaler'),
]


@@ -1015,33 +1048,50 @@ class TemplateLoader:
                f'Please make sure you model_id: {model_id} '
                f'and template_name: {template_name} is supported.')
        logger.info('Exporting to ollama:')
        names = []
        names = {}
        match_infos = {}
        if gguf_meta:
            gguf_header_name = gguf_meta.get("general.name", None)
            names.append(gguf_header_name)
            if gguf_header_name:
                names['gguf_header_name'] = gguf_header_name
        if model_id:
            names.append(model_id)
        for name in names:
            names['model_id'] = model_id
        for name_type, name in names.items():
            for _info in template_info:
                if re.fullmatch(_info.template_regex, name):
                    if _info.modelfile_prefix and not kwargs.get('ignore_oss_model_file', False):
                        match_infos[name_type] = name, _info
                        break

        _name = None
        _info = None
        if len(match_infos) == 1:
            _, (_name, _info) = match_infos.popitem()
        elif len(match_infos) > 1:
            if not match_infos['model_id'][1].allow_general_name:
                _name, _info = match_infos['model_id']
            else:
                _name, _info = match_infos['gguf_header_name']

        if _info:
            template_str = TemplateLoader._read_content_from_url(
                _info.modelfile_prefix + '.template')
            if not template_str:
                logger.info(f'{name} has no template file.')
                logger.info(f'{_name} has no template file.')
            params = TemplateLoader._read_content_from_url(_info.modelfile_prefix + '.params')
            if params:
                params = json.loads(params)
            else:
                logger.info(f'{name} has no params file.')
                logger.info(f'{_name} has no params file.')
            license = TemplateLoader._read_content_from_url(
                _info.modelfile_prefix + '.license')
            if not template_str:
                logger.info(f'{name} has no license file.')
                logger.info(f'{_name} has no license file.')
            format_out = TemplateLoader._format_return(template_str, params, split, license)
            if debug:
                return format_out, _info
            return format_out

        if template_name:
            template = TemplateLoader.load_by_template_name(
                template_name, **kwargs)
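
The rewritten matching above keeps separate candidates keyed by where the name came from. When both the GGUF header name and the model id match a template, the model id candidate wins only if its entry sets allow_general_name=False (as the new openthinker entry does, which is why 'lmstudio-community/OpenThinker-32B-GGUF' resolves to openthinker in the test below even though its GGUF header says 'Qwen2.5 7B Instruct'); otherwise the GGUF header name is preferred. A standalone sketch of that selection rule, not the library code:

def pick_template(match_infos):
    # match_infos maps 'gguf_header_name' / 'model_id' to (name, TemplateInfo) pairs.
    if len(match_infos) == 1:
        _, (name, info) = match_infos.popitem()
        return name, info
    if len(match_infos) > 1:
        if not match_infos['model_id'][1].allow_general_name:
            return match_infos['model_id']
        return match_infos['gguf_header_name']
    return None, None
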
@@ -1,2 +1,3 @@
from .auto_class import *
from .patcher import patch_context, patch_hub, unpatch_hub
from .pipeline_builder import hf_pipeline

@@ -27,7 +27,8 @@ def get_all_imported_modules():
    transformers_include_names = [
        'Auto.*', 'T5.*', 'BitsAndBytesConfig', 'GenerationConfig', 'Awq.*',
        'GPTQ.*', 'BatchFeature', 'Qwen.*', 'Llama.*', 'PretrainedConfig',
        'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast'
        'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast',
        'Pipeline'
    ]
    peft_include_names = ['.*PeftModel.*', '.*Config']
    diffusers_include_names = ['^(?!TF|Flax).*Pipeline$']
@@ -252,6 +253,44 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                    model_dir, *model_args, **kwargs)
                return module_obj

            def save_pretrained(
                self,
                save_directory: Union[str, os.PathLike],
                safe_serialization: bool = True,
                **kwargs,
            ):
                push_to_hub = kwargs.pop('push_to_hub', False)
                if push_to_hub:
                    from modelscope.hub.push_to_hub import push_to_hub
                    from modelscope.hub.api import HubApi
                    from modelscope.hub.repository import Repository

                    token = kwargs.get('token')
                    commit_message = kwargs.pop('commit_message', None)
                    repo_name = kwargs.pop(
                        'repo_id',
                        save_directory.split(os.path.sep)[-1])

                    api = HubApi()
                    api.login(token)
                    api.create_repo(repo_name)
                    # clone the repo
                    Repository(save_directory, repo_name)

                super().save_pretrained(
                    save_directory=save_directory,
                    safe_serialization=safe_serialization,
                    push_to_hub=False,
                    **kwargs)

                # Class members may be unpatched, so push_to_hub is done separately here
                if push_to_hub:
                    push_to_hub(
                        repo_name=repo_name,
                        output_dir=save_directory,
                        commit_message=commit_message,
                        token=token)

            if not hasattr(module_class, 'from_pretrained'):
                del ClassWrapper.from_pretrained
            else:
|
||||
if not hasattr(module_class, 'get_config_dict'):
|
||||
del ClassWrapper.get_config_dict
|
||||
|
||||
if not hasattr(module_class, 'save_pretrained'):
|
||||
del ClassWrapper.save_pretrained
|
||||
|
||||
ClassWrapper.__name__ = module_class.__name__
|
||||
ClassWrapper.__qualname__ = module_class.__qualname__
|
||||
return ClassWrapper
|
||||
@@ -289,12 +331,16 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
has_from_pretrained = hasattr(var, 'from_pretrained')
|
||||
has_get_peft_type = hasattr(var, '_get_peft_type')
|
||||
has_get_config_dict = hasattr(var, 'get_config_dict')
|
||||
has_save_pretrained = hasattr(var, 'save_pretrained')
|
||||
except: # noqa
|
||||
continue
|
||||
|
||||
if wrap:
|
||||
# save_pretrained is not a classmethod and cannot be overridden by replacing
|
||||
# the class method. It requires replacing the class object method.
|
||||
if wrap or ('pipeline' in name.lower() and has_save_pretrained):
|
||||
try:
|
||||
if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type:
|
||||
if (not has_from_pretrained and not has_get_config_dict
|
||||
and not has_get_peft_type and not has_save_pretrained):
|
||||
all_available_modules.append(var)
|
||||
else:
|
||||
all_available_modules.append(
|
||||
|
||||
54
modelscope/utils/hf_util/pipeline_builder.py
Normal file
54
modelscope/utils/hf_util/pipeline_builder.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import Pipeline as PipelineHF
|
||||
from transformers import PreTrainedModel, TFPreTrainedModel, pipeline
|
||||
from transformers.pipelines import check_task, get_task
|
||||
|
||||
from modelscope.hub import snapshot_download
|
||||
from modelscope.utils.hf_util.patcher import _patch_pretrained_class, patch_hub
|
||||
|
||||
|
||||
def _get_hf_device(device):
|
||||
if isinstance(device, str):
|
||||
device_name = device.lower()
|
||||
eles = device_name.split(':')
|
||||
if eles[0] == 'gpu':
|
||||
eles = ['cuda'] + eles[1:]
|
||||
device = ''.join(eles)
|
||||
return device
|
||||
|
||||
|
||||
def _get_hf_pipeline_class(task, model):
|
||||
if not task:
|
||||
task = get_task(model)
|
||||
normalized_task, targeted_task, task_options = check_task(task)
|
||||
pipeline_class = targeted_task['impl']
|
||||
pipeline_class = _patch_pretrained_class([pipeline_class])[0]
|
||||
return pipeline_class
|
||||
|
||||
|
||||
def hf_pipeline(
|
||||
task: str = None,
|
||||
model: Optional[Union[str, 'PreTrainedModel', 'TFPreTrainedModel']] = None,
|
||||
framework: Optional[str] = None,
|
||||
device: Optional[Union[int, str, 'torch.device']] = None,
|
||||
**kwargs,
|
||||
) -> PipelineHF:
|
||||
if isinstance(model, str):
|
||||
if not os.path.exists(model):
|
||||
model = snapshot_download(model)
|
||||
|
||||
framework = 'pt' if framework == 'pytorch' else framework
|
||||
|
||||
device = _get_hf_device(device)
|
||||
pipeline_class = _get_hf_pipeline_class(task, model)
|
||||
|
||||
return pipeline(
|
||||
task=task,
|
||||
model=model,
|
||||
framework=framework,
|
||||
device=device,
|
||||
pipeline_class=pipeline_class,
|
||||
**kwargs)
|
||||
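
The new pipeline_builder resolves a hub id through snapshot_download when the path does not exist locally, maps 'pytorch'/'gpu'-style framework and device names to the forms transformers expects, and routes the task's pipeline class through the patcher so its pretrained hooks use ModelScope. A direct-call sketch (same repo id as the commit's tests; question and context are illustrative):

from modelscope.utils.hf_util import hf_pipeline

qa = hf_pipeline(
    task='question-answering',
    model='damotestx/distilbert-base-cased-distilled-squad',
    framework='pytorch',  # normalized to 'pt'
    device='gpu')         # rewritten to its 'cuda' equivalent
result = qa(question='What is extractive QA?',
            context='Extractive QA extracts an answer span from a given text.')
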
@@ -220,7 +220,7 @@ class RepoUtils:


@dataclass
class CommitInfo(str):
class CommitInfo:
    """Data structure containing information about a newly created commit.

    Returned by any method that creates a commit on the Hub: [`create_commit`], [`upload_file`], [`upload_folder`],
@@ -240,46 +240,12 @@ class CommitInfo(str):
        oid (`str`):
            Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`.

        pr_url (`str`, *optional*):
            Url to the PR that has been created, if any. Populated when `create_pr=True`
            is passed.

        pr_revision (`str`, *optional*):
            Revision of the PR that has been created, if any. Populated when
            `create_pr=True` is passed. Example: `"refs/pr/1"`.

        pr_num (`int`, *optional*):
            Number of the PR discussion that has been created, if any. Populated when
            `create_pr=True` is passed. Can be passed as `discussion_num` in
            [`get_discussion_details`]. Example: `1`.

        _url (`str`, *optional*):
            Legacy url for `str` compatibility. Can be the url to the uploaded file on the Hub (if returned by
            [`upload_file`]), to the uploaded folder on the Hub (if returned by [`upload_folder`]) or to the commit on
            the Hub (if returned by [`create_commit`]). Defaults to `commit_url`. It is deprecated to use this
            attribute. Please use `commit_url` instead.
    """

    commit_url: str
    commit_message: str
    commit_description: str
    oid: str
    pr_url: Optional[str] = None

    # Computed from `pr_url` in `__post_init__`
    pr_revision: Optional[str] = field(init=False)
    pr_num: Optional[str] = field(init=False)

    # legacy url for `str` compatibility (ex: url to uploaded file, url to uploaded folder, url to PR, etc.)
    _url: str = field(
        repr=False, default=None)  # type: ignore # defaults to `commit_url`

    def __new__(cls,
                *args,
                commit_url: str,
                _url: Optional[str] = None,
                **kwargs):
        return str.__new__(cls, _url or commit_url)

    def to_dict(cls):
        return {
@@ -287,7 +253,6 @@ class CommitInfo(str):
            'commit_message': cls.commit_message,
            'commit_description': cls.commit_description,
            'oid': cls.oid,
            'pr_url': cls.pr_url,
        }


@@ -17,8 +17,8 @@ class FaceRecognitionTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_face_compare(self):
        img1 = 'data/test/images/face_recognition_1.png'
        img2 = 'data/test/images/face_recognition_2.png'
        img1 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_1.png'
        img2 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_2.png'

        face_recognition = pipeline(
            Tasks.face_recognition, model=self.model_id)
@@ -27,6 +27,30 @@ class FaceRecognitionTest(unittest.TestCase):
        sim = np.dot(emb1[0], emb2[0])
        print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_face_compare_use_det(self):
        img1 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_1.png'
        img2 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_2.png'

        face_recognition = pipeline(
            Tasks.face_recognition, model=self.model_id, use_det=True)
        emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
        emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
        sim = np.dot(emb1[0], emb2[0])
        print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_face_compare_not_use_det(self):
        img1 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_1.png'
        img2 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_2.png'

        face_recognition = pipeline(
            Tasks.face_recognition, model=self.model_id, use_det=False)
        emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
        emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
        sim = np.dot(emb1[0], emb2[0])
        print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')


if __name__ == '__main__':
    unittest.main()

@@ -122,6 +122,36 @@ class TestToOllama(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_check_template_type(self):
        _test_check_tmpl_type(
            'DevQuasar/CohereForAI.c4ai-command-r7b-arabic-02-2025-GGUF',
            'command-r7b-arabic',
            gguf_meta={
                'general.name': 'CohereForAI.c4ai Command R7B Arabic 02 2025'
            })
        _test_check_tmpl_type(
            'lmstudio-community/granite-vision-3.2-2b-GGUF',
            'granite3.2-vision',
            gguf_meta={'general.name': 'Granite Vision 3.2 2b'})
        _test_check_tmpl_type(
            'unsloth/Phi-4-mini-instruct-GGUF',
            'phi4-mini',
            gguf_meta={'general.name': 'Phi 4 Mini Instruct'})
        _test_check_tmpl_type(
            'lmstudio-community/granite-3.2-2b-instruct-GGUF',
            'granite3.2',
            gguf_meta={'general.name': 'Granite 3.2 2b Instruct'})
        _test_check_tmpl_type(
            'unsloth/r1-1776-GGUF',
            'r1-1776',
            gguf_meta={'general.name': 'R1 1776'})
        _test_check_tmpl_type(
            'QuantFactory/DeepScaleR-1.5B-Preview-GGUF',
            'deepscaler',
            gguf_meta={'general.name': 'DeepScaleR 1.5B Preview'})
        _test_check_tmpl_type(
            'lmstudio-community/OpenThinker-32B-GGUF',
            'openthinker',
            gguf_meta={'general.name': 'Qwen2.5 7B Instruct'})
        _test_check_tmpl_type(
            'LLM-Research/Llama-3.3-70B-Instruct',
            'llama3.3',

@@ -40,6 +40,14 @@ class HFUtilTest(unittest.TestCase):
        with open(self.test_file2, 'w') as f:
            f.write('{}')

        self.pipeline_qa_context = r"""
Extractive Question Answering is the task of extracting an answer from a text given a question. An example
of a question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would
like to fine-tune a model on a SQuAD task, you may leverage the
examples/pytorch/question-answering/run_squad.py script.
"""
        self.pipeline_qa_question = 'What is a good example of a question answering dataset?'

    def tearDown(self):
        logger.info('TearDown')
        shutil.rmtree(self.model_dir, ignore_errors=True)
@@ -235,6 +243,59 @@ class HFUtilTest(unittest.TestCase):
            'Qwen/Qwen1.5-0.5B-Chat', trust_remote_code=True)
        model.push_to_hub(self.create_model_name)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_pipeline_model_id(self):
        from modelscope import pipeline
        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
        qa = pipeline('question-answering', model=model_id)
        assert qa(
            question=self.pipeline_qa_question,
            context=self.pipeline_qa_context)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_pipeline_auto_model(self):
        from modelscope import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
        model = AutoModelForQuestionAnswering.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        qa = pipeline('question-answering', model=model, tokenizer=tokenizer)
        assert qa(
            question=self.pipeline_qa_question,
            context=self.pipeline_qa_context)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_pipeline_save_pretrained(self):
        from modelscope import pipeline
        model_id = 'damotestx/distilbert-base-cased-distilled-squad'

        pipe_ori = pipeline('question-answering', model=model_id)

        result_ori = pipe_ori(
            question=self.pipeline_qa_question,
            context=self.pipeline_qa_context)

        # save_pretrained
        repo_id = self.create_model_name
        save_dir = './tmp_test_hf_pipeline'
        try:
            os.system(f'rm -rf {save_dir}')
            self.api.delete_model(repo_id)
            # wait for delete repo
            import time
            time.sleep(5)
        except Exception:
            # if repo not exists
            pass
        pipe_ori.save_pretrained(save_dir, push_to_hub=True, repo_id=repo_id)

        # load from saved
        pipe_new = pipeline('question-answering', model=repo_id)
        result_new = pipe_new(
            question=self.pipeline_qa_question,
            context=self.pipeline_qa_context)

        assert result_new == result_ori


if __name__ == '__main__':
    unittest.main()