diff --git a/docker/Dockerfile.ubuntu b/docker/Dockerfile.ubuntu
index a294d2c0..4f6186f0 100644
--- a/docker/Dockerfile.ubuntu
+++ b/docker/Dockerfile.ubuntu
@@ -65,6 +65,8 @@ RUN sh /tmp/install.sh {version_args} && \
     pip config set install.trusted-host mirrors.aliyun.com && \
     cp /tmp/resources/ubuntu2204.aliyun /etc/apt/sources.list
 
+RUN pip install --no-cache-dir omegaconf==2.3.0 && pip cache purge
+
 ENV SETUPTOOLS_USE_DISTUTILS=stdlib
 ENV VLLM_USE_MODELSCOPE=True
 ENV LMDEPLOY_USE_MODELSCOPE=True
diff --git a/modelscope/__init__.py b/modelscope/__init__.py
index a1fbf444..2579ca71 100644
--- a/modelscope/__init__.py
+++ b/modelscope/__init__.py
@@ -56,7 +56,7 @@ if TYPE_CHECKING:
         AutoModelForPreTraining, AutoModelForTextEncoding, AutoImageProcessor,
         BatchFeature, Qwen2VLForConditionalGeneration, T5EncoderModel,
         Qwen2_5_VLForConditionalGeneration, LlamaModel,
-        LlamaPreTrainedModel, LlamaForCausalLM)
+        LlamaPreTrainedModel, LlamaForCausalLM, hf_pipeline)
 else:
     print(
         'transformer is not installed, please install it if you want to use related modules'
diff --git a/modelscope/cli/cli.py b/modelscope/cli/cli.py
index afadfa91..1cb98ea0 100644
--- a/modelscope/cli/cli.py
+++ b/modelscope/cli/cli.py
@@ -40,9 +40,6 @@ def run_cmd():
     if not hasattr(args, 'func'):
         parser.print_help()
         exit(1)
-    if args.token is not None:
-        api = HubApi()
-        api.login(args.token)
     cmd = args.func(args)
     cmd.execute()
diff --git a/modelscope/cli/download.py b/modelscope/cli/download.py
index 6b430453..723fdfe9 100644
--- a/modelscope/cli/download.py
+++ b/modelscope/cli/download.py
@@ -3,6 +3,7 @@ import os
 from argparse import ArgumentParser
 
 from modelscope.cli.base import CLICommand
+from modelscope.hub.api import HubApi
 from modelscope.hub.constants import DEFAULT_MAX_WORKERS
 from modelscope.hub.file_download import (dataset_file_download,
                                           model_file_download)
@@ -54,16 +55,21 @@ class DownloadCMD(CLICommand):
             default='model',
             help="Type of repo to download from (defaults to 'model').",
         )
+        parser.add_argument(
+            '--token',
+            type=str,
+            default=None,
+            help='Optional. Access token to download controlled entities.')
         parser.add_argument(
             '--revision',
             type=str,
             default=None,
-            help='Revision of the model.')
+            help='Revision of the entity (e.g., model).')
         parser.add_argument(
             '--cache_dir',
             type=str,
             default=None,
-            help='Cache directory to save model.')
+            help='Cache directory to save entity (e.g., model).')
         parser.add_argument(
             '--local_dir',
             type=str,
             default=None,
@@ -118,6 +124,10 @@ class DownloadCMD(CLICommand):
                 % self.args.repo_type)
         if not self.args.model and not self.args.dataset:
             raise Exception('Model or dataset must be set.')
+        cookies = None
+        if self.args.token is not None:
+            api = HubApi()
+            cookies = api.get_cookies(access_token=self.args.token)
         if self.args.model:
             if len(self.args.files) == 1:  # download single file
                 model_file_download(
@@ -125,7 +135,8 @@ class DownloadCMD(CLICommand):
                     self.args.files[0],
                     cache_dir=self.args.cache_dir,
                     local_dir=self.args.local_dir,
-                    revision=self.args.revision)
+                    revision=self.args.revision,
+                    cookies=cookies)
             elif len(
                     self.args.files) > 1:  # download specified multiple files.
                 snapshot_download(
@@ -135,7 +146,7 @@ class DownloadCMD(CLICommand):
                     local_dir=self.args.local_dir,
                     allow_file_pattern=self.args.files,
                     max_workers=self.args.max_workers,
-                )
+                    cookies=cookies)
             else:  # download repo
                 snapshot_download(
                     self.args.model,
@@ -145,7 +156,7 @@ class DownloadCMD(CLICommand):
                     allow_file_pattern=convert_patterns(self.args.include),
                     ignore_file_pattern=convert_patterns(self.args.exclude),
                     max_workers=self.args.max_workers,
-                )
+                    cookies=cookies)
         elif self.args.dataset:
             dataset_revision: str = self.args.revision if self.args.revision else DEFAULT_DATASET_REVISION
             if len(self.args.files) == 1:  # download single file
@@ -154,7 +165,8 @@ class DownloadCMD(CLICommand):
                     self.args.files[0],
                     cache_dir=self.args.cache_dir,
                     local_dir=self.args.local_dir,
-                    revision=dataset_revision)
+                    revision=dataset_revision,
+                    cookies=cookies)
             elif len(
                     self.args.files) > 1:  # download specified multiple files.
                 dataset_snapshot_download(
@@ -164,7 +176,7 @@ class DownloadCMD(CLICommand):
                     local_dir=self.args.local_dir,
                     allow_file_pattern=self.args.files,
                     max_workers=self.args.max_workers,
-                )
+                    cookies=cookies)
             else:  # download repo
                 dataset_snapshot_download(
                     self.args.dataset,
@@ -174,6 +186,6 @@ class DownloadCMD(CLICommand):
                     allow_file_pattern=convert_patterns(self.args.include),
                     ignore_file_pattern=convert_patterns(self.args.exclude),
                     max_workers=self.args.max_workers,
-                )
+                    cookies=cookies)
         else:
             pass  # noop
diff --git a/modelscope/cli/upload.py b/modelscope/cli/upload.py
index 453a6314..a050c5b2 100644
--- a/modelscope/cli/upload.py
+++ b/modelscope/cli/upload.py
@@ -91,7 +91,7 @@ class UploadCMD(CLICommand):
             '--endpoint',
             type=str,
             default=get_endpoint(),
-            help='Endpoint for Modelscope service.')
+            help='Endpoint for ModelScope service.')
 
         parser.set_defaults(func=subparser_func)
 
@@ -137,14 +137,15 @@ class UploadCMD(CLICommand):
         # Check token and login
         # The cookies will be reused if the user has logged in before.
+        cookies = None
         api = HubApi(endpoint=self.args.endpoint)
         if self.args.token:
-            api.login(access_token=self.args.token)
-            cookies = ModelScopeConfig.get_cookies()
+            cookies = api.get_cookies(access_token=self.args.token)
+        else:
+            cookies = ModelScopeConfig.get_cookies()
         if cookies is None:
             raise ValueError(
-                'The `token` is not provided! '
+                'No credential found for entity upload. '
                 'You can pass the `--token` argument, '
                 'or use api.login(access_token=`your_sdk_token`). '
                 'Your token is available at https://modelscope.cn/my/myaccesstoken'
@@ -158,6 +159,7 @@ class UploadCMD(CLICommand):
                 repo_type=self.args.repo_type,
                 commit_message=self.args.commit_message,
                 commit_description=self.args.commit_description,
+                token=self.args.token,
             )
         elif os.path.isdir(self.local_path):
             api.upload_folder(
@@ -170,6 +172,7 @@ class UploadCMD(CLICommand):
                 allow_patterns=convert_patterns(self.args.include),
                 ignore_patterns=convert_patterns(self.args.exclude),
                 max_workers=self.args.max_workers,
+                token=self.args.token,
             )
         else:
             raise ValueError(f'{self.local_path} is not a valid local path')
diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py
index f5a2f39b..3a249f26 100644
--- a/modelscope/hub/api.py
+++ b/modelscope/hub/api.py
@@ -34,6 +34,7 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
                                       API_RESPONSE_FIELD_USERNAME,
                                       DEFAULT_CREDENTIALS_PATH,
                                       DEFAULT_MAX_WORKERS,
+                                      DEFAULT_MODELSCOPE_DOMAIN,
                                       MODELSCOPE_CLOUD_ENVIRONMENT,
                                       MODELSCOPE_CLOUD_USERNAME,
                                       MODELSCOPE_REQUEST_ID, ONE_YEAR_SECONDS,
@@ -112,9 +113,19 @@ class HubApi:
 
         self.upload_checker = UploadingCheck()
 
+    def get_cookies(self, access_token):
+        from requests.cookies import RequestsCookieJar
+        jar = RequestsCookieJar()
+        jar.set('m_session_id',
+                access_token,
+                domain=os.getenv('MODELSCOPE_DOMAIN',
+                                 DEFAULT_MODELSCOPE_DOMAIN),
+                path='/')
+        return jar
+
     def login(
             self,
-            access_token: Optional[str] = None,
+            access_token: Optional[str] = None
     ):
         """Login with your SDK access token, which can be obtained from
         https://www.modelscope.cn user center.
diff --git a/modelscope/hub/check_model.py b/modelscope/hub/check_model.py
index e41a0a17..6d39c275 100644
--- a/modelscope/hub/check_model.py
+++ b/modelscope/hub/check_model.py
@@ -71,6 +71,10 @@ def check_local_model_is_latest(
             headers=snapshot_header,
             use_cookies=cookies,
         )
+        model_cache = None
+        # download via non-git method
+        if not os.path.exists(os.path.join(model_root_path, '.git')):
+            model_cache = ModelFileSystemCache(model_root_path)
         for model_file in model_files:
             if model_file['Type'] == 'tree':
                 continue
diff --git a/modelscope/hub/git.py b/modelscope/hub/git.py
index 144d9d69..d03ca773 100644
--- a/modelscope/hub/git.py
+++ b/modelscope/hub/git.py
@@ -46,6 +46,7 @@ class GitCommandWrapper(metaclass=Singleton):
         git_env = os.environ.copy()
         git_env['GIT_TERMINAL_PROMPT'] = '0'
         command = [self.git_path, *args]
+        command = [item for item in command if item]
         response = subprocess.run(
             command,
             stdout=subprocess.PIPE,
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 596d6d22..7e5cc6b5 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -10,6 +10,8 @@ from modelscope.utils.config import ConfigDict, check_config
 from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, Tasks,
                                        ThirdParty)
 from modelscope.utils.hub import read_config
+from modelscope.utils.import_utils import is_transformers_available
+from modelscope.utils.logger import get_logger
 from modelscope.utils.plugins import (register_modelhub_repo,
                                       register_plugins_repo)
 from modelscope.utils.registry import Registry, build_from_cfg
@@ -17,6 +19,7 @@ from .base import Pipeline
 from .util import is_official_hub_path
 
 PIPELINES = Registry('pipelines')
+logger = get_logger()
 
 
 def normalize_model_input(model,
@@ -72,7 +75,7 @@ def pipeline(task: str = None,
              config_file: str = None,
              pipeline_name: str = None,
              framework: str = None,
-             device: str = 'gpu',
+             device: str = None,
              model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
              ignore_file_pattern: List[str] = None,
              **kwargs) -> Pipeline:
@@ -109,6 +112,7 @@ def pipeline(task: str = None,
 
     if task is None and pipeline_name is None:
         raise ValueError('task or pipeline_name is required')
+    pipeline_props = None
     if pipeline_name is None:
         # get default pipeline for this task
         if isinstance(model, str) \
@@ -157,8 +161,11 @@ def pipeline(task: str = None,
             if pipeline_name:
                 pipeline_props = {'type': pipeline_name}
             else:
-                check_config(cfg)
-                pipeline_props = cfg.pipeline
+                try:
+                    check_config(cfg)
+                    pipeline_props = cfg.pipeline
+                except AssertionError as e:
+                    logger.info(str(e))
     elif model is not None:
         # get pipeline info from Model object
 
@@ -166,9 +173,13 @@ def pipeline(task: str = None,
         if not hasattr(first_model, 'pipeline'):
             # model is instantiated by user, we should parse config again
             cfg = read_config(first_model.model_dir)
-            check_config(cfg)
-            first_model.pipeline = cfg.pipeline
-        pipeline_props = first_model.pipeline
+            try:
+                check_config(cfg)
+                first_model.pipeline = cfg.pipeline
+            except AssertionError as e:
+                logger.info(str(e))
+        if first_model.__dict__.get('pipeline'):
+            pipeline_props = first_model.pipeline
     else:
         pipeline_name, default_model_repo = get_default_pipeline_info(task)
         model = normalize_model_input(default_model_repo, model_revision)
@@ -176,6 +187,23 @@ def pipeline(task: str = None,
     else:
         pipeline_props = {'type': pipeline_name}
 
+    if not pipeline_props and is_transformers_available():
+        try:
+            from modelscope.utils.hf_util import hf_pipeline
+            return hf_pipeline(
+                task=task,
+                model=model,
+                framework=framework,
+                device=device,
+                **kwargs)
+        except Exception as e:
+            logger.error(
+                'We couldn\'t find a suitable pipeline from ms, so we tried to load it using the transformers pipeline,'
+                ' but that also failed.')
+            raise e
+
+    if not device:
+        device = 'gpu'
     pipeline_props['model'] = model
     pipeline_props['device'] = device
     cfg = ConfigDict(pipeline_props)
diff --git a/modelscope/pipelines/cv/face_processing_base_pipeline.py b/modelscope/pipelines/cv/face_processing_base_pipeline.py
index b9b81c9c..b2c376ca 100644
--- a/modelscope/pipelines/cv/face_processing_base_pipeline.py
+++ b/modelscope/pipelines/cv/face_processing_base_pipeline.py
@@ -22,7 +22,7 @@ logger = get_logger()
 
 class FaceProcessingBasePipeline(Pipeline):
 
-    def __init__(self, model: str, **kwargs):
+    def __init__(self, model: str, use_det=True, **kwargs):
         """
         use `model` to create a face processing pipeline and output cropped img, scores, bbox and lmks.
 
@@ -30,11 +30,13 @@ class FaceProcessingBasePipeline(Pipeline):
             model: model id on modelscope hub.
""" + self.use_det = use_det super().__init__(model=model, **kwargs) # face detect pipeline - det_model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd' - self.face_detection = pipeline( - Tasks.face_detection, model=det_model_id) + if use_det: + det_model_id = 'damo/cv_ddsar_face-detection_iclr23-damofd' + self.face_detection = pipeline( + Tasks.face_detection, model=det_model_id) def _choose_face(self, det_result, @@ -94,21 +96,27 @@ class FaceProcessingBasePipeline(Pipeline): def preprocess(self, input: Input) -> Dict[str, Any]: img = LoadImage.convert_to_ndarray(input) img = img[:, :, ::-1] - det_result = self.face_detection(img.copy()) - rtn = self._choose_face(det_result, img_shape=img.shape) - if rtn is not None: - scores, bboxes, face_lmks = rtn - face_lmks = face_lmks.reshape(5, 2) - align_img, _ = align_face(img, (112, 112), face_lmks) + if self.use_det: + det_result = self.face_detection(img.copy()) + rtn = self._choose_face(det_result, img_shape=img.shape) + if rtn is not None: + scores, bboxes, face_lmks = rtn + face_lmks = face_lmks.reshape(5, 2) + align_img, _ = align_face(img, (112, 112), face_lmks) - result = {} - result['img'] = np.ascontiguousarray(align_img) - result['scores'] = [scores] - result['bbox'] = bboxes - result['lmks'] = face_lmks - return result + result = {} + result['img'] = np.ascontiguousarray(align_img) + result['scores'] = [scores] + result['bbox'] = bboxes + result['lmks'] = face_lmks + return result + else: + return None else: - return None + result = {} + resized_img = cv2.resize(img, (112, 112)) + result['img'] = np.ascontiguousarray(resized_img) + return result def align_face_padding(self, img, rect, padding_size=16, pad_pixel=127): rect = np.reshape(rect, (-1, 4)) diff --git a/modelscope/pipelines/cv/face_recognition_pipeline.py b/modelscope/pipelines/cv/face_recognition_pipeline.py index 4af5a04f..8f595aef 100644 --- a/modelscope/pipelines/cv/face_recognition_pipeline.py +++ b/modelscope/pipelines/cv/face_recognition_pipeline.py @@ -14,10 +14,11 @@ from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES +from modelscope.pipelines.cv.face_processing_base_pipeline import \ + FaceProcessingBasePipeline from modelscope.preprocessors import LoadImage from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger -from . import FaceProcessingBasePipeline logger = get_logger() @@ -26,15 +27,14 @@ logger = get_logger() Tasks.face_recognition, module_name=Pipelines.face_recognition) class FaceRecognitionPipeline(FaceProcessingBasePipeline): - def __init__(self, model: str, **kwargs): + def __init__(self, model: str, use_det=True, **kwargs): """ use `model` to create a face recognition pipeline for prediction Args: model: model id on modelscope hub. 
""" - # face recong model - super().__init__(model=model, **kwargs) + super().__init__(model=model, use_det=use_det, **kwargs) device = torch.device( f'cuda:{0}' if torch.cuda.is_available() else 'cpu') self.device = device diff --git a/modelscope/preprocessors/templates/loader.py b/modelscope/preprocessors/templates/loader.py index 4e61feb0..0d97ecf4 100644 --- a/modelscope/preprocessors/templates/loader.py +++ b/modelscope/preprocessors/templates/loader.py @@ -18,6 +18,7 @@ class TemplateInfo: template: str = None template_regex: str = None modelfile_prefix: str = None + allow_general_name: bool = True def cases(*names): @@ -255,6 +256,12 @@ template_info = [ modelfile_prefix= 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/phi3', ), + TemplateInfo( + template_regex= + f'.*{cases("phi4-mini", "phi-4-mini")}.*', + modelfile_prefix= + 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/phi4-mini', + ), TemplateInfo( template_regex= f'.*{cases("phi4", "phi-4")}{no_multi_modal()}.*', @@ -470,6 +477,12 @@ template_info = [ modelfile_prefix= 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/command-r-plus', ), + TemplateInfo( + template_regex= + f'.*{cases("command-r7b-arabic")}.*', + modelfile_prefix= + 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/command-r7b-arabic', + ), TemplateInfo( template_regex= f'.*{cases("command-r7b")}.*', @@ -663,9 +676,17 @@ template_info = [ modelfile_prefix= 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3-guardian'), TemplateInfo( - template_regex=f'.*{cases("granite")}.*{cases("code")}.*', + template_regex=f'.*{cases("granite")}.*{cases("code")}.*', modelfile_prefix= 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite-code'), + TemplateInfo( + template_regex=f'.*{cases("granite")}.*{cases("vision")}.*{cases("3.2")}.*', + modelfile_prefix= + 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3.2-vision'), + TemplateInfo( + template_regex=f'.*{cases("granite")}.*{cases("3.2")}.*', + modelfile_prefix= + 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/granite3.2'), TemplateInfo( template_regex=f'.*{cases("granite-3.1")}.*{cases("2b", "8b")}.*', modelfile_prefix= @@ -733,6 +754,12 @@ template_info = [ modelfile_prefix= 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/smallthinker'), + TemplateInfo( + template_regex=f'.*{cases("openthinker")}.*', + modelfile_prefix= + 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/openthinker', + allow_general_name=False), + TemplateInfo( template_regex= f'.*{cases("olmo2", "olmo-2")}.*', @@ -888,8 +915,14 @@ template_info = [ template_regex=f'.*{cases("exaone")}.*', modelfile_prefix= 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/exaone3.5'), - - + TemplateInfo( + template_regex=f'.*{cases("r1-1776")}.*', + modelfile_prefix= + 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/r1-1776'), + TemplateInfo( + template_regex=f'.*{cases("deepscaler")}.*', + modelfile_prefix= + 'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/deepscaler'), ] @@ -1015,33 +1048,50 @@ class TemplateLoader: f'Please make sure you model_id: {model_id} ' f'and template_name: {template_name} is supported.') logger.info('Exporting to ollama:') - names = [] + names = {} + match_infos = {} if gguf_meta: gguf_header_name = gguf_meta.get("general.name", None) - names.append(gguf_header_name) + if 
gguf_header_name: + names['gguf_header_name'] = gguf_header_name if model_id: - names.append(model_id) - for name in names: + names['model_id'] = model_id + for name_type, name in names.items(): for _info in template_info: if re.fullmatch(_info.template_regex, name): if _info.modelfile_prefix and not kwargs.get('ignore_oss_model_file', False): - template_str = TemplateLoader._read_content_from_url( - _info.modelfile_prefix + '.template') - if not template_str: - logger.info(f'{name} has no template file.') - params = TemplateLoader._read_content_from_url(_info.modelfile_prefix + '.params') - if params: - params = json.loads(params) - else: - logger.info(f'{name} has no params file.') - license = TemplateLoader._read_content_from_url( - _info.modelfile_prefix + '.license') - if not template_str: - logger.info(f'{name} has no license file.') - format_out = TemplateLoader._format_return(template_str, params, split, license) - if debug: - return format_out, _info - return format_out + match_infos[name_type] = name, _info + break + + _name = None + _info = None + if len(match_infos) == 1: + _, (_name, _info) = match_infos.popitem() + elif len(match_infos) > 1: + if not match_infos['model_id'][1].allow_general_name: + _name, _info = match_infos['model_id'] + else: + _name, _info = match_infos['gguf_header_name'] + + if _info: + template_str = TemplateLoader._read_content_from_url( + _info.modelfile_prefix + '.template') + if not template_str: + logger.info(f'{_name} has no template file.') + params = TemplateLoader._read_content_from_url(_info.modelfile_prefix + '.params') + if params: + params = json.loads(params) + else: + logger.info(f'{_name} has no params file.') + license = TemplateLoader._read_content_from_url( + _info.modelfile_prefix + '.license') + if not template_str: + logger.info(f'{_name} has no license file.') + format_out = TemplateLoader._format_return(template_str, params, split, license) + if debug: + return format_out, _info + return format_out + if template_name: template = TemplateLoader.load_by_template_name( template_name, **kwargs) diff --git a/modelscope/utils/hf_util/__init__.py b/modelscope/utils/hf_util/__init__.py index a138ff7a..ac8349c9 100644 --- a/modelscope/utils/hf_util/__init__.py +++ b/modelscope/utils/hf_util/__init__.py @@ -1,2 +1,3 @@ from .auto_class import * from .patcher import patch_context, patch_hub, unpatch_hub +from .pipeline_builder import hf_pipeline diff --git a/modelscope/utils/hf_util/patcher.py b/modelscope/utils/hf_util/patcher.py index 28f8eeb5..6a41a5ce 100644 --- a/modelscope/utils/hf_util/patcher.py +++ b/modelscope/utils/hf_util/patcher.py @@ -27,7 +27,8 @@ def get_all_imported_modules(): transformers_include_names = [ 'Auto.*', 'T5.*', 'BitsAndBytesConfig', 'GenerationConfig', 'Awq.*', 'GPTQ.*', 'BatchFeature', 'Qwen.*', 'Llama.*', 'PretrainedConfig', - 'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast' + 'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast', + 'Pipeline' ] peft_include_names = ['.*PeftModel.*', '.*Config'] diffusers_include_names = ['^(?!TF|Flax).*Pipeline$'] @@ -252,6 +253,44 @@ def _patch_pretrained_class(all_imported_modules, wrap=False): model_dir, *model_args, **kwargs) return module_obj + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = True, + **kwargs, + ): + push_to_hub = kwargs.pop('push_to_hub', False) + if push_to_hub: + from modelscope.hub.push_to_hub import push_to_hub + from modelscope.hub.api import HubApi + from 
modelscope.hub.repository import Repository + + token = kwargs.get('token') + commit_message = kwargs.pop('commit_message', None) + repo_name = kwargs.pop( + 'repo_id', + save_directory.split(os.path.sep)[-1]) + + api = HubApi() + api.login(token) + api.create_repo(repo_name) + # clone the repo + Repository(save_directory, repo_name) + + super().save_pretrained( + save_directory=save_directory, + safe_serialization=safe_serialization, + push_to_hub=False, + **kwargs) + + # Class members may be unpatched, so push_to_hub is done separately here + if push_to_hub: + push_to_hub( + repo_name=repo_name, + output_dir=save_directory, + commit_message=commit_message, + token=token) + if not hasattr(module_class, 'from_pretrained'): del ClassWrapper.from_pretrained else: @@ -266,6 +305,9 @@ def _patch_pretrained_class(all_imported_modules, wrap=False): if not hasattr(module_class, 'get_config_dict'): del ClassWrapper.get_config_dict + if not hasattr(module_class, 'save_pretrained'): + del ClassWrapper.save_pretrained + ClassWrapper.__name__ = module_class.__name__ ClassWrapper.__qualname__ = module_class.__qualname__ return ClassWrapper @@ -289,12 +331,16 @@ def _patch_pretrained_class(all_imported_modules, wrap=False): has_from_pretrained = hasattr(var, 'from_pretrained') has_get_peft_type = hasattr(var, '_get_peft_type') has_get_config_dict = hasattr(var, 'get_config_dict') + has_save_pretrained = hasattr(var, 'save_pretrained') except: # noqa continue - if wrap: + # save_pretrained is not a classmethod and cannot be overridden by replacing + # the class method. It requires replacing the class object method. + if wrap or ('pipeline' in name.lower() and has_save_pretrained): try: - if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type: + if (not has_from_pretrained and not has_get_config_dict + and not has_get_peft_type and not has_save_pretrained): all_available_modules.append(var) else: all_available_modules.append( diff --git a/modelscope/utils/hf_util/pipeline_builder.py b/modelscope/utils/hf_util/pipeline_builder.py new file mode 100644 index 00000000..5386bead --- /dev/null +++ b/modelscope/utils/hf_util/pipeline_builder.py @@ -0,0 +1,54 @@ +import os +from typing import Optional, Union + +import torch +from transformers import Pipeline as PipelineHF +from transformers import PreTrainedModel, TFPreTrainedModel, pipeline +from transformers.pipelines import check_task, get_task + +from modelscope.hub import snapshot_download +from modelscope.utils.hf_util.patcher import _patch_pretrained_class, patch_hub + + +def _get_hf_device(device): + if isinstance(device, str): + device_name = device.lower() + eles = device_name.split(':') + if eles[0] == 'gpu': + eles = ['cuda'] + eles[1:] + device = ''.join(eles) + return device + + +def _get_hf_pipeline_class(task, model): + if not task: + task = get_task(model) + normalized_task, targeted_task, task_options = check_task(task) + pipeline_class = targeted_task['impl'] + pipeline_class = _patch_pretrained_class([pipeline_class])[0] + return pipeline_class + + +def hf_pipeline( + task: str = None, + model: Optional[Union[str, 'PreTrainedModel', 'TFPreTrainedModel']] = None, + framework: Optional[str] = None, + device: Optional[Union[int, str, 'torch.device']] = None, + **kwargs, +) -> PipelineHF: + if isinstance(model, str): + if not os.path.exists(model): + model = snapshot_download(model) + + framework = 'pt' if framework == 'pytorch' else framework + + device = _get_hf_device(device) + pipeline_class = 
_get_hf_pipeline_class(task, model) + + return pipeline( + task=task, + model=model, + framework=framework, + device=device, + pipeline_class=pipeline_class, + **kwargs) diff --git a/modelscope/utils/repo_utils.py b/modelscope/utils/repo_utils.py index 85ddc2f7..55d16251 100644 --- a/modelscope/utils/repo_utils.py +++ b/modelscope/utils/repo_utils.py @@ -220,7 +220,7 @@ class RepoUtils: @dataclass -class CommitInfo(str): +class CommitInfo: """Data structure containing information about a newly created commit. Returned by any method that creates a commit on the Hub: [`create_commit`], [`upload_file`], [`upload_folder`], @@ -240,46 +240,12 @@ class CommitInfo(str): oid (`str`): Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`. - pr_url (`str`, *optional*): - Url to the PR that has been created, if any. Populated when `create_pr=True` - is passed. - - pr_revision (`str`, *optional*): - Revision of the PR that has been created, if any. Populated when - `create_pr=True` is passed. Example: `"refs/pr/1"`. - - pr_num (`int`, *optional*): - Number of the PR discussion that has been created, if any. Populated when - `create_pr=True` is passed. Can be passed as `discussion_num` in - [`get_discussion_details`]. Example: `1`. - - _url (`str`, *optional*): - Legacy url for `str` compatibility. Can be the url to the uploaded file on the Hub (if returned by - [`upload_file`]), to the uploaded folder on the Hub (if returned by [`upload_folder`]) or to the commit on - the Hub (if returned by [`create_commit`]). Defaults to `commit_url`. It is deprecated to use this - attribute. Please use `commit_url` instead. """ commit_url: str commit_message: str commit_description: str oid: str - pr_url: Optional[str] = None - - # Computed from `pr_url` in `__post_init__` - pr_revision: Optional[str] = field(init=False) - pr_num: Optional[str] = field(init=False) - - # legacy url for `str` compatibility (ex: url to uploaded file, url to uploaded folder, url to PR, etc.) 
-    _url: str = field(
-        repr=False, default=None)  # type: ignore # defaults to `commit_url`
-
-    def __new__(cls,
-                *args,
-                commit_url: str,
-                _url: Optional[str] = None,
-                **kwargs):
-        return str.__new__(cls, _url or commit_url)
 
     def to_dict(cls):
         return {
@@ -287,7 +253,6 @@ class CommitInfo(str):
             'commit_message': cls.commit_message,
             'commit_description': cls.commit_description,
             'oid': cls.oid,
-            'pr_url': cls.pr_url,
         }
diff --git a/tests/pipelines/test_face_recognition.py b/tests/pipelines/test_face_recognition.py
index 7b84590c..bdbdb849 100644
--- a/tests/pipelines/test_face_recognition.py
+++ b/tests/pipelines/test_face_recognition.py
@@ -17,8 +17,8 @@ class FaceRecognitionTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_face_compare(self):
-        img1 = 'data/test/images/face_recognition_1.png'
-        img2 = 'data/test/images/face_recognition_2.png'
+        img1 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_1.png'
+        img2 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_2.png'
 
         face_recognition = pipeline(
             Tasks.face_recognition, model=self.model_id)
@@ -27,6 +27,30 @@ class FaceRecognitionTest(unittest.TestCase):
         sim = np.dot(emb1[0], emb2[0])
         print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_face_compare_use_det(self):
+        img1 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_1.png'
+        img2 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_2.png'
+
+        face_recognition = pipeline(
+            Tasks.face_recognition, model=self.model_id, use_det=True)
+        emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
+        emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
+        sim = np.dot(emb1[0], emb2[0])
+        print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_face_compare_not_use_det(self):
+        img1 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_1.png'
+        img2 = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/face_recognition_2.png'
+
+        face_recognition = pipeline(
+            Tasks.face_recognition, model=self.model_id, use_det=False)
+        emb1 = face_recognition(img1)[OutputKeys.IMG_EMBEDDING]
+        emb2 = face_recognition(img2)[OutputKeys.IMG_EMBEDDING]
+        sim = np.dot(emb1[0], emb2[0])
+        print(f'Cos similarity={sim:.3f}, img1:{img1} img2:{img2}')
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/tools/test_to_ollama.py b/tests/tools/test_to_ollama.py
index e7a686d6..a7f8275d 100644
--- a/tests/tools/test_to_ollama.py
+++ b/tests/tools/test_to_ollama.py
@@ -122,6 +122,36 @@ class TestToOllama(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_check_template_type(self):
+        _test_check_tmpl_type(
+            'DevQuasar/CohereForAI.c4ai-command-r7b-arabic-02-2025-GGUF',
+            'command-r7b-arabic',
+            gguf_meta={
+                'general.name': 'CohereForAI.c4ai Command R7B Arabic 02 2025'
+            })
+        _test_check_tmpl_type(
+            'lmstudio-community/granite-vision-3.2-2b-GGUF',
+            'granite3.2-vision',
+            gguf_meta={'general.name': 'Granite Vision 3.2 2b'})
+        _test_check_tmpl_type(
+            'unsloth/Phi-4-mini-instruct-GGUF',
+            'phi4-mini',
+            gguf_meta={'general.name': 'Phi 4 Mini Instruct'})
+        _test_check_tmpl_type(
+            'lmstudio-community/granite-3.2-2b-instruct-GGUF',
+            'granite3.2',
+            gguf_meta={'general.name': 'Granite 3.2 2b Instruct'})
+        _test_check_tmpl_type(
+            'unsloth/r1-1776-GGUF',
+            'r1-1776',
+            gguf_meta={'general.name': 'R1 1776'})
+        _test_check_tmpl_type(
+            'QuantFactory/DeepScaleR-1.5B-Preview-GGUF',
+            'deepscaler',
+            gguf_meta={'general.name': 'DeepScaleR 1.5B Preview'})
+        _test_check_tmpl_type(
+            'lmstudio-community/OpenThinker-32B-GGUF',
+            'openthinker',
+            gguf_meta={'general.name': 'Qwen2.5 7B Instruct'})
         _test_check_tmpl_type(
             'LLM-Research/Llama-3.3-70B-Instruct',
             'llama3.3',
diff --git a/tests/utils/test_hf_util.py b/tests/utils/test_hf_util.py
index 9826d991..058e92c7 100644
--- a/tests/utils/test_hf_util.py
+++ b/tests/utils/test_hf_util.py
@@ -40,6 +40,14 @@ class HFUtilTest(unittest.TestCase):
         with open(self.test_file2, 'w') as f:
             f.write('{}')
 
+        self.pipeline_qa_context = r"""
+        Extractive Question Answering is the task of extracting an answer from a text given a question. An example
+        of a question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would
+        like to fine-tune a model on a SQuAD task, you may leverage the
+        examples/pytorch/question-answering/run_squad.py script.
+        """
+        self.pipeline_qa_question = 'What is a good example of a question answering dataset?'
+
     def tearDown(self):
         logger.info('TearDown')
         shutil.rmtree(self.model_dir, ignore_errors=True)
@@ -235,6 +243,59 @@ class HFUtilTest(unittest.TestCase):
             'Qwen/Qwen1.5-0.5B-Chat', trust_remote_code=True)
         model.push_to_hub(self.create_model_name)
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_pipeline_model_id(self):
+        from modelscope import pipeline
+        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
+        qa = pipeline('question-answering', model=model_id)
+        assert qa(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_pipeline_auto_model(self):
+        from modelscope import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
+        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
+        model = AutoModelForQuestionAnswering.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        qa = pipeline('question-answering', model=model, tokenizer=tokenizer)
+        assert qa(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_pipeline_save_pretrained(self):
+        from modelscope import pipeline
+        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
+
+        pipe_ori = pipeline('question-answering', model=model_id)
+
+        result_ori = pipe_ori(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+        # save_pretrained
+        repo_id = self.create_model_name
+        save_dir = './tmp_test_hf_pipeline'
+        try:
+            os.system(f'rm -rf {save_dir}')
+            self.api.delete_model(repo_id)
+            # wait for delete repo
+            import time
+            time.sleep(5)
+        except Exception:
+            # if repo not exists
+            pass
+        pipe_ori.save_pretrained(save_dir, push_to_hub=True, repo_id=repo_id)
+
+        # load from saved
+        pipe_new = pipeline('question-answering', model=repo_id)
+        result_new = pipe_new(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+        assert result_new == result_ori
+
 
 if __name__ == '__main__':
     unittest.main()