Merge remote-tracking branch 'origin' into fallback

This commit is contained in:
Yingda Chen
2025-03-11 14:44:03 +08:00
9 changed files with 181 additions and 96 deletions

View File

@@ -26,8 +26,12 @@ on:
other_params:
description: 'Other params in --xxx xxx'
required: false
python_version:
description: 'Python version to use, default is 3.10.14'
required: false
default: '3.10.14'
run-name: Docker-${{ inputs.modelscope_branch }}-${{ inputs.image_type }}-${{ inputs.workflow_name }}-by-@${{ github.actor }}
run-name: Docker-${{ inputs.modelscope_branch }}-${{ inputs.image_type }}-${{ inputs.workflow_name }}-${{ inputs.python_version }}-by-@${{ github.actor }}
jobs:
build:
@@ -51,4 +55,4 @@ jobs:
run: |
set -e
source ~/.bashrc
python docker/build_image.py --image_type ${{ github.event.inputs.image_type }} --modelscope_branch ${{ github.event.inputs.modelscope_branch }} --modelscope_version ${{ github.event.inputs.modelscope_version }} --swift_branch ${{ github.event.inputs.swift_branch }} --ci_image ${{ github.event.inputs.ci_image }} ${{ github.event.inputs.other_params }}
python docker/build_image.py --image_type ${{ github.event.inputs.image_type }} --python_version ${{ github.event.inputs.python_version }} --modelscope_branch ${{ github.event.inputs.modelscope_branch }} --modelscope_version ${{ github.event.inputs.modelscope_version }} --swift_branch ${{ github.event.inputs.swift_branch }} --ci_image ${{ github.event.inputs.ci_image }} ${{ github.event.inputs.other_params }}

View File

@@ -22,7 +22,7 @@ logger = get_logger()
class FaceProcessingBasePipeline(Pipeline):
def __init__(self, model: str, **kwargs):
def __init__(self, model: str, use_det=True, **kwargs):
"""
use `model` to create a face processing pipeline and output cropped img, scores, bbox and lmks.
@@ -30,11 +30,13 @@ class FaceProcessingBasePipeline(Pipeline):
model: model id on modelscope hub.
"""
self.use_det = use_det
super().__init__(model=model, **kwargs)
# face detect pipeline
det_model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd'
self.face_detection = pipeline(
Tasks.face_detection, model=det_model_id)
if use_det:
det_model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd'
self.face_detection = pipeline(
Tasks.face_detection, model=det_model_id)
def _choose_face(self,
det_result,
@@ -94,21 +96,27 @@ class FaceProcessingBasePipeline(Pipeline):
def preprocess(self, input: Input) -> Dict[str, Any]:
img = LoadImage.convert_to_ndarray(input)
img = img[:, :, ::-1]
det_result = self.face_detection(img.copy())
rtn = self._choose_face(det_result, img_shape=img.shape)
if rtn is not None:
scores, bboxes, face_lmks = rtn
face_lmks = face_lmks.reshape(5, 2)
align_img, _ = align_face(img, (112, 112), face_lmks)
if self.use_det:
det_result = self.face_detection(img.copy())
rtn = self._choose_face(det_result, img_shape=img.shape)
if rtn is not None:
scores, bboxes, face_lmks = rtn
face_lmks = face_lmks.reshape(5, 2)
align_img, _ = align_face(img, (112, 112), face_lmks)
result = {}
result['img'] = np.ascontiguousarray(align_img)
result['scores'] = [scores]
result['bbox'] = bboxes
result['lmks'] = face_lmks
return result
result = {}
result['img'] = np.ascontiguousarray(align_img)
result['scores'] = [scores]
result['bbox'] = bboxes
result['lmks'] = face_lmks
return result
else:
return None
else:
return None
result = {}
resized_img = cv2.resize(img, (112, 112))
result['img'] = np.ascontiguousarray(resized_img)
return result
def align_face_padding(self, img, rect, padding_size=16, pad_pixel=127):
rect = np.reshape(rect, (-1, 4))

View File

@@ -14,10 +14,11 @@ from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.face_processing_base_pipeline import \
FaceProcessingBasePipeline
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from . import FaceProcessingBasePipeline
logger = get_logger()
@@ -26,15 +27,14 @@ logger = get_logger()
Tasks.face_recognition, module_name=Pipelines.face_recognition)
class FaceRecognitionPipeline(FaceProcessingBasePipeline):
def __init__(self, model: str, **kwargs):
def __init__(self, model: str, use_det=True, **kwargs):
"""
use `model` to create a face recognition pipeline for prediction
Args:
model: model id on modelscope hub.
"""
# face recong model
super().__init__(model=model, **kwargs)
super().__init__(model=model, use_det=use_det, **kwargs)
device = torch.device(
f'cuda:{0}' if torch.cuda.is_available() else 'cpu')
self.device = device

View File

@@ -18,6 +18,7 @@ class TemplateInfo:
template: str = None
template_regex: str = None
modelfile_prefix: str = None
allow_general_name: bool = True
def cases(*names):
@@ -255,6 +256,12 @@ template_info = [
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/phi3',
),
TemplateInfo(
template_regex=
f'.*{cases("phi4-mini", "phi-4-mini")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/phi4-mini',
),
TemplateInfo(
template_regex=
f'.*{cases("phi4", "phi-4")}{no_multi_modal()}.*',
@@ -470,6 +477,12 @@ template_info = [
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/command-r-plus',
),
TemplateInfo(
template_regex=
f'.*{cases("command-r7b-arabic")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/command-r7b-arabic',
),
TemplateInfo(
template_regex=
f'.*{cases("command-r7b")}.*',
@@ -663,9 +676,17 @@ template_info = [
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3-guardian'),
TemplateInfo(
template_regex=f'.*{cases("granite")}.*{cases("code")}.*',
template_regex=f'.*{cases("granite")}.*{cases("code")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite-code'),
TemplateInfo(
template_regex=f'.*{cases("granite")}.*{cases("vision")}.*{cases("3.2")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3.2-vision'),
TemplateInfo(
template_regex=f'.*{cases("granite")}.*{cases("3.2")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3.2'),
TemplateInfo(
template_regex=f'.*{cases("granite-3.1")}.*{cases("2b", "8b")}.*',
modelfile_prefix=
@@ -733,6 +754,12 @@ template_info = [
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/smallthinker'),
TemplateInfo(
template_regex=f'.*{cases("openthinker")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/openthinker',
allow_general_name=False),
TemplateInfo(
template_regex=
f'.*{cases("olmo2", "olmo-2")}.*',
@@ -888,8 +915,14 @@ template_info = [
template_regex=f'.*{cases("exaone")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/exaone3.5'),
TemplateInfo(
template_regex=f'.*{cases("r1-1776")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/r1-1776'),
TemplateInfo(
template_regex=f'.*{cases("deepscaler")}.*',
modelfile_prefix=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/deepscaler'),
]
@@ -1014,33 +1047,50 @@ class TemplateLoader:
f'Please make sure you model_id: {model_id} '
f'and template_name: {template_name} is supported.')
logger.info('Exporting to ollama:')
names = []
names = {}
match_infos = {}
if gguf_meta:
gguf_header_name = gguf_meta.get("general.name", None)
names.append(gguf_header_name)
if gguf_header_name:
names['gguf_header_name'] = gguf_header_name
if model_id:
names.append(model_id)
for name in names:
names['model_id'] = model_id
for name_type, name in names.items():
for _info in template_info:
if re.fullmatch(_info.template_regex, name):
if _info.modelfile_prefix and not kwargs.get('ignore_oss_model_file', False):
template_str = TemplateLoader._read_content_from_url(
_info.modelfile_prefix + '.template')
if not template_str:
logger.info(f'{name} has no template file.')
params = TemplateLoader._read_content_from_url(_info.modelfile_prefix + '.params')
if params:
params = json.loads(params)
else:
logger.info(f'{name} has no params file.')
license = TemplateLoader._read_content_from_url(
_info.modelfile_prefix + '.license')
if not template_str:
logger.info(f'{name} has no license file.')
format_out = TemplateLoader._format_return(template_str, params, split, license)
if debug:
return format_out, _info
return format_out
match_infos[name_type] = name, _info
break
_name = None
_info = None
if len(match_infos) == 1:
_, (_name, _info) = match_infos.popitem()
elif len(match_infos) > 1:
if not match_infos['model_id'][1].allow_general_name:
_name, _info = match_infos['model_id']
else:
_name, _info = match_infos['gguf_header_name']
if _info:
template_str = TemplateLoader._read_content_from_url(
_info.modelfile_prefix + '.template')
if not template_str:
logger.info(f'{_name} has no template file.')
params = TemplateLoader._read_content_from_url(_info.modelfile_prefix + '.params')
if params:
params = json.loads(params)
else:
logger.info(f'{_name} has no params file.')
license = TemplateLoader._read_content_from_url(
_info.modelfile_prefix + '.license')
if not template_str:
logger.info(f'{_name} has no license file.')
format_out = TemplateLoader._format_return(template_str, params, split, license)
if debug:
return format_out, _info
return format_out
if template_name:
template = TemplateLoader.load_by_template_name(
template_name, **kwargs)

View File

@@ -345,7 +345,7 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
else:
all_available_modules.append(
get_wrapped_class(var, **ignore_file_pattern_kwargs))
except: # noqa
except Exception:
all_available_modules.append(var)
else:
if has_from_pretrained and not hasattr(var,
@@ -370,9 +370,10 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
if has_get_config_dict and not hasattr(var,
'_get_config_dict_origin'):
var._get_config_dict_origin = var.get_config_dict
var.get_config_dict = classmethod(
partial(patch_get_config_dict,
**ignore_file_pattern_kwargs))
var.get_config_dict = partial(
patch_pretrained_model_name_or_path,
ori_func=var._get_config_dict_origin,
**ignore_file_pattern_kwargs)
all_available_modules.append(var)
return all_available_modules
@@ -618,6 +619,11 @@ def _patch_hub():
# Patch repocard.validate
from huggingface_hub import repocard
if not hasattr(repocard.RepoCard, '_validate_origin'):
def load(*args, **kwargs): # noqa
from huggingface_hub.errors import EntryNotFoundError
raise EntryNotFoundError(message='API not supported.')
repocard.RepoCard._validate_origin = repocard.RepoCard.validate
repocard.RepoCard.validate = lambda *args, **kwargs: None
repocard.RepoCard._load_origin = repocard.RepoCard.load

View File

@@ -1,13 +1,8 @@
import os
from typing import Optional, Union
import torch
from transformers import Pipeline as PipelineHF
from transformers import PreTrainedModel, TFPreTrainedModel, pipeline
from transformers.pipelines import check_task, get_task
from modelscope.hub import snapshot_download
from modelscope.utils.hf_util.patcher import _patch_pretrained_class, patch_hub
from modelscope.utils.hf_util.patcher import _patch_pretrained_class
def _get_hf_device(device):
@@ -21,6 +16,7 @@ def _get_hf_device(device):
def _get_hf_pipeline_class(task, model):
from transformers.pipelines import check_task, get_task
if not task:
task = get_task(model)
normalized_task, targeted_task, task_options = check_task(task)
@@ -35,7 +31,9 @@ def hf_pipeline(
framework: Optional[str] = None,
device: Optional[Union[int, str, 'torch.device']] = None,
**kwargs,
) -> PipelineHF:
) -> 'transformers.Pipeline':
from transformers import pipeline
if isinstance(model, str):
if not os.path.exists(model):
model = snapshot_download(model)

View File

@@ -220,7 +220,7 @@ class RepoUtils:
@dataclass
class CommitInfo(str):
class CommitInfo:
"""Data structure containing information about a newly created commit.
Returned by any method that creates a commit on the Hub: [`create_commit`], [`upload_file`], [`upload_folder`],
@@ -240,46 +240,12 @@ class CommitInfo(str):
oid (`str`):
Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`.
pr_url (`str`, *optional*):
Url to the PR that has been created, if any. Populated when `create_pr=True`
is passed.
pr_revision (`str`, *optional*):
Revision of the PR that has been created, if any. Populated when
`create_pr=True` is passed. Example: `"refs/pr/1"`.
pr_num (`int`, *optional*):
Number of the PR discussion that has been created, if any. Populated when
`create_pr=True` is passed. Can be passed as `discussion_num` in
[`get_discussion_details`]. Example: `1`.
_url (`str`, *optional*):
Legacy url for `str` compatibility. Can be the url to the uploaded file on the Hub (if returned by
[`upload_file`]), to the uploaded folder on the Hub (if returned by [`upload_folder`]) or to the commit on
the Hub (if returned by [`create_commit`]). Defaults to `commit_url`. It is deprecated to use this
attribute. Please use `commit_url` instead.
"""
commit_url: str
commit_message: str
commit_description: str
oid: str
pr_url: Optional[str] = None
# Computed from `pr_url` in `__post_init__`
pr_revision: Optional[str] = field(init=False)
pr_num: Optional[str] = field(init=False)
# legacy url for `str` compatibility (ex: url to uploaded file, url to uploaded folder, url to PR, etc.)
_url: str = field(
repr=False, default=None) # type: ignore # defaults to `commit_url`
def __new__(cls,
*args,
commit_url: str,
_url: Optional[str] = None,
**kwargs):
return str.__new__(cls, _url or commit_url)
def to_dict(cls):
return {
@@ -287,7 +253,6 @@ class CommitInfo(str):
'commit_message': cls.commit_message,
'commit_description': cls.commit_description,
'oid': cls.oid,
'pr_url': cls.pr_url,
}

View File

@@ -17,8 +17,8 @@ class FaceRecognitionTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_face_compare(self):
img1 = 'data/test/images/face_recognition_1.png'
img2 = 'data/test/images/face_recognition_2.png'
img1 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_1.png'
img2 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_2.png'
face_recognition = pipeline(
Tasks.face_recognition, model=self.model_id)
@@ -27,6 +27,30 @@ class FaceRecognitionTest(unittest.TestCase):
sim = np.dot(emb1[0], emb2[0])
print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_face_compare_use_det(self):
img1 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_1.png'
img2 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_2.png'
face_recognition = pipeline(
Tasks.face_recognition, model=self.model_id, use_det=True)
emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
sim = np.dot(emb1[0], emb2[0])
print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_face_compare_not_use_det(self):
img1 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_1.png'
img2 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_2.png'
face_recognition = pipeline(
Tasks.face_recognition, model=self.model_id, use_det=False)
emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
sim = np.dot(emb1[0], emb2[0])
print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')
if __name__ == '__main__':
unittest.main()

View File

@@ -122,6 +122,36 @@ class TestToOllama(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_check_template_type(self):
_test_check_tmpl_type(
'DevQuasar/CohereForAI.c4ai-command-r7b-arabic-02-2025-GGUF',
'command-r7b-arabic',
gguf_meta={
'general.name': 'CohereForAI.c4ai Command R7B Arabic 02 2025'
})
_test_check_tmpl_type(
'lmstudio-community/granite-vision-3.2-2b-GGUF',
'granite3.2-vision',
gguf_meta={'general.name': 'Granite Vision 3.2 2b'})
_test_check_tmpl_type(
'unsloth/Phi-4-mini-instruct-GGUF',
'phi4-mini',
gguf_meta={'general.name': 'Phi 4 Mini Instruct'})
_test_check_tmpl_type(
'lmstudio-community/granite-3.2-2b-instruct-GGUF',
'granite3.2',
gguf_meta={'general.name': 'Granite 3.2 2b Instruct'})
_test_check_tmpl_type(
'unsloth/r1-1776-GGUF',
'r1-1776',
gguf_meta={'general.name': 'R1 1776'})
_test_check_tmpl_type(
'QuantFactory/DeepScaleR-1.5B-Preview-GGUF',
'deepscaler',
gguf_meta={'general.name': 'DeepScaleR 1.5B Preview'})
_test_check_tmpl_type(
'lmstudio-community/OpenThinker-32B-GGUF',
'openthinker',
gguf_meta={'general.name': 'Qwen2.5 7B Instruct'})
_test_check_tmpl_type(
'LLM-Research/Llama-3.3-70B-Instruct',
'llama3.3',