diff --git a/.dev_scripts/build_image.sh b/.dev_scripts/build_image.sh
index 9ce2a4a8..c1e61890 100644
--- a/.dev_scripts/build_image.sh
+++ b/.dev_scripts/build_image.sh
@@ -150,7 +150,7 @@ echo -e "Building image with:\npython$python_version\npytorch$torch_version\nten
 docker_file_content=`cat docker/Dockerfile.ubuntu`
 if [ "$is_ci_test" != "True" ]; then
     echo "Building ModelScope lib, will install ModelScope lib to image"
-    docker_file_content="${docker_file_content} \nRUN pip install --no-cache-dir -U transformers && pip install --no-cache-dir https://modelscope.oss-cn-beijing.aliyuncs.com/releases/build/modelscope-$modelscope_version-py3-none-any.whl "
+    docker_file_content="${docker_file_content} \nRUN pip install --no-cache-dir numpy https://modelscope.oss-cn-beijing.aliyuncs.com/releases/build/modelscope-$modelscope_version-py3-none-any.whl && pip install --no-cache-dir -U transformers"
 fi
 echo "$is_dsw"
 if [ "$is_dsw" == "False" ]; then
diff --git a/docker/Dockerfile.ubuntu b/docker/Dockerfile.ubuntu
index 2af8994b..4ac4fd53 100644
--- a/docker/Dockerfile.ubuntu
+++ b/docker/Dockerfile.ubuntu
@@ -32,6 +32,7 @@ RUN pip install --no-cache-dir mpi4py paint_ldm \
     mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 pai-easycv ms_swift \
     ipykernel fasttext fairseq deepspeed -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html

+ARG USE_GPU
 # for cpu install cpu version faiss, faiss depends on blas lib, we install libopenblas TODO rename gpu or cpu version faiss
 RUN if [ "$USE_GPU" = "True" ] ; then \
     pip install --no-cache-dir funtextprocessing kwsbp==0.0.6 faiss==1.7.2 safetensors typeguard==2.13.3 scikit-learn librosa==0.9.2 funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
@@ -45,10 +46,14 @@ COPY examples /modelscope/examples
 # for pai-easycv setup compatiblity issue
 ENV SETUPTOOLS_USE_DISTUTILS=stdlib
-RUN CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" pip install --no-cache-dir 'git+https://github.com/facebookresearch/detectron2.git'
+RUN if [ "$USE_GPU" = "True" ] ; then \
+        CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" pip install --no-cache-dir 'git+https://github.com/facebookresearch/detectron2.git'; \
+    else \
+        echo 'detectron2 is not supported on cpu'; \
+    fi

 # torchmetrics==0.11.4 for ofa
-RUN pip install --no-cache-dir tiktoken torchmetrics==0.11.4 transformers_stream_generator 'protobuf<=3.20.0' bitsandbytes basicsr
+RUN pip install --no-cache-dir jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator 'protobuf<=3.20.0' bitsandbytes basicsr

 COPY docker/scripts/install_flash_attension.sh /tmp/install_flash_attension.sh
 RUN if [ "$USE_GPU" = "True" ] ; then \
     bash /tmp/install_flash_attension.sh; \
diff --git a/docker/scripts/install_flash_attension.sh b/docker/scripts/install_flash_attension.sh
index 6a3301c2..f37e567d 100644
--- a/docker/scripts/install_flash_attension.sh
+++ b/docker/scripts/install_flash_attension.sh
@@ -1,6 +1,4 @@
-git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention && \
-cd flash-attention && pip install . && \
-pip install csrc/layer_norm && \
-pip install csrc/rotary && \
+git clone -b v2.3.2 https://github.com/Dao-AILab/flash-attention && \
+cd flash-attention && python setup.py install && \
 cd .. && \
 rm -rf flash-attention
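Note on the Docker changes above (illustrative, not part of the patch): in a Dockerfile, an ARG must be declared in the build stage before any RUN step that reads it; without the added ARG USE_GPU, $USE_GPU expands to an empty string and every [ "$USE_GPU" = "True" ] test silently falls through to the CPU branch. With the declaration in place, the variant is selected at build time, e.g.:

    docker build -f docker/Dockerfile.ubuntu --build-arg USE_GPU=True -t modelscope:gpu .

The image tag here is a placeholder.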
diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py
index 8e6d4ae6..9beb156b 100644
--- a/modelscope/models/base/base_model.py
+++ b/modelscope/models/base/base_model.py
@@ -134,12 +134,13 @@ class Model(ABC):
                 ignore_file_pattern=ignore_file_pattern)
             logger.info(f'initialize model from {local_model_dir}')

+        configuration_path = osp.join(local_model_dir, ModelFile.CONFIGURATION)
+        cfg = None
         if cfg_dict is not None:
             cfg = cfg_dict
-        else:
-            cfg = Config.from_file(
-                osp.join(local_model_dir, ModelFile.CONFIGURATION))
-        task_name = cfg.task
+        elif os.path.exists(configuration_path):
+            cfg = Config.from_file(configuration_path)
+        task_name = getattr(cfg, 'task', None)
         if 'task' in kwargs:
             task_name = kwargs.pop('task')
         model_cfg = getattr(cfg, 'model', ConfigDict())
@@ -162,6 +163,9 @@ class Model(ABC):
             model = model.to(device)
             return model
         # use ms
+        if cfg is None:
+            raise FileNotFoundError(
+                f'`{ModelFile.CONFIGURATION}` file not found.')
         model_cfg.model_dir = local_model_dir

         # install and import remote repos before build
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index a3707918..25f948bc 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -181,8 +181,20 @@ class EpochBasedTrainer(BaseTrainer):
                 compile_options = {}
             self.model = compile_model(self.model, **compile_options)

-        if 'work_dir' in kwargs:
+        if kwargs.get('work_dir', None) is not None:
             self.work_dir = kwargs['work_dir']
+            if 'train' not in self.cfg:
+                self.cfg['train'] = ConfigDict()
+            self.cfg['train']['work_dir'] = self.work_dir
+            if 'checkpoint' in self.cfg['train']:
+                if 'period' in self.cfg['train']['checkpoint']:
+                    self.cfg['train']['checkpoint']['period'][
+                        'save_dir'] = self.work_dir
+                if 'best' in self.cfg['train']['checkpoint']:
+                    self.cfg['train']['checkpoint']['best'][
+                        'save_dir'] = self.work_dir
+            if 'logging' in self.cfg['train']:
+                self.cfg['train']['logging']['out_dir'] = self.work_dir
         else:
             self.work_dir = self.cfg.train.get('work_dir', './work_dir')
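A minimal sketch of what the trainer.py change enables (the model id and directory below are placeholders): a work_dir passed as a keyword now propagates into cfg.train, including the checkpoint save_dir for both the period and best hooks and the logging out_dir, instead of only updating self.work_dir:

    from modelscope.trainers import build_trainer

    # 'damo/some-model' is a hypothetical model id used for illustration.
    trainer = build_trainer(
        default_args=dict(
            model='damo/some-model',
            work_dir='./my_work_dir'))
    # cfg.train.work_dir, cfg.train.checkpoint.{period,best}.save_dir and
    # cfg.train.logging.out_dir (when those sections exist) all now point
    # at ./my_work_dir, so every artifact lands in one directory.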
diff --git a/modelscope/utils/automodel_utils.py b/modelscope/utils/automodel_utils.py
index afd83817..1f5de3b6 100644
--- a/modelscope/utils/automodel_utils.py
+++ b/modelscope/utils/automodel_utils.py
@@ -6,8 +6,11 @@ from modelscope.utils.ast_utils import INDEX_KEY
 from modelscope.utils.import_utils import LazyImportModule


-def can_load_by_ms(model_dir: str, tast_name: str, model_type: str) -> bool:
-    if ('MODELS', tast_name,
+def can_load_by_ms(model_dir: str, task_name: Optional[str],
+                   model_type: Optional[str]) -> bool:
+    if model_type is None or task_name is None:
+        return False
+    if ('MODELS', task_name,
             model_type) in LazyImportModule.AST_INDEX[INDEX_KEY]:
         return True
     ms_wrapper_path = os.path.join(model_dir, 'ms_wrapper.py')
@@ -25,11 +28,27 @@ def _can_load_by_hf_automodel(automodel_class: type, config) -> bool:
     return False


-def get_hf_automodel_class(model_dir: str, task_name: str) -> Optional[type]:
-    from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
-                            AutoModelForSeq2SeqLM,
-                            AutoModelForTokenClassification,
-                            AutoModelForSequenceClassification)
+def get_default_automodel(config) -> Optional[type]:
+    import modelscope.utils.hf_util as hf_util
+    if not hasattr(config, 'auto_map'):
+        return None
+    auto_map = config.auto_map
+    automodel_list = [k for k in auto_map.keys() if k.startswith('AutoModel')]
+    if len(automodel_list) == 1:
+        return getattr(hf_util, automodel_list[0])
+    if len(automodel_list) > 1 and len(
+            set([auto_map[k] for k in automodel_list])) == 1:
+        return getattr(hf_util, automodel_list[0])
+    return None
+
+
+def get_hf_automodel_class(model_dir: str,
+                           task_name: Optional[str]) -> Optional[type]:
+    from modelscope.utils.hf_util import (AutoConfig, AutoModel,
+                                          AutoModelForCausalLM,
+                                          AutoModelForSeq2SeqLM,
+                                          AutoModelForTokenClassification,
+                                          AutoModelForSequenceClassification)
     automodel_mapping = {
         Tasks.backbone: AutoModel,
         Tasks.chat: AutoModelForCausalLM,
@@ -37,19 +56,18 @@ def get_hf_automodel_class(model_dir: str, task_name: str) -> Optional[type]:
         Tasks.text_classification: AutoModelForSequenceClassification,
         Tasks.token_classification: AutoModelForTokenClassification,
     }
-    automodel_class = automodel_mapping.get(task_name, None)
-    if automodel_class is None:
-        return None
     config_path = os.path.join(model_dir, 'config.json')
     if not os.path.exists(config_path):
         return None
     try:
-        try:
-            config = AutoConfig.from_pretrained(
-                model_dir, trust_remote_code=True)
-        except (FileNotFoundError, ValueError):
-            return None
+        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
+        if task_name is None:
+            automodel_class = get_default_automodel(config)
+        else:
+            automodel_class = automodel_mapping.get(task_name, None)
+        if automodel_class is None:
+            return None
         if _can_load_by_hf_automodel(automodel_class, config):
             return automodel_class
         if (automodel_class is AutoModelForCausalLM
@@ -71,14 +89,5 @@ def try_to_load_hf_model(model_dir: str, task_name: str,
     model = None
     if automodel_class is not None:
         # use hf
-        device_map = kwargs.get('device_map', None)
-        torch_dtype = kwargs.get('torch_dtype', None)
-        config = kwargs.get('config', None)
-
-        model = automodel_class.from_pretrained(
-            model_dir,
-            device_map=device_map,
-            torch_dtype=torch_dtype,
-            config=config,
-            trust_remote_code=True)
+        model = automodel_class.from_pretrained(model_dir, **kwargs)
     return model
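To illustrate the new selection logic (the auto_map content below is invented): when no task name is available, get_default_automodel inspects auto_map from config.json and returns a wrapper class only when the AutoModel* entries are unambiguous:

    from types import SimpleNamespace

    # Toy stand-in for a transformers config carrying a custom-code auto_map;
    # the module paths are hypothetical.
    config = SimpleNamespace(auto_map={
        'AutoConfig': 'configuration_foo.FooConfig',
        'AutoModelForCausalLM': 'modeling_foo.FooModelForCausalLM',
    })
    automodel_list = [k for k in config.auto_map if k.startswith('AutoModel')]
    # ['AutoModelForCausalLM'] -> exactly one candidate, so the matching
    # hf_util wrapper class is returned; several AutoModel* keys mapping to
    # different classes would be ambiguous and yield None instead.

And since try_to_load_hf_model now forwards **kwargs straight to from_pretrained, callers are no longer limited to device_map, torch_dtype and config; any Hugging Face keyword argument passes through unchanged.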
diff --git a/modelscope/utils/hf_util.py b/modelscope/utils/hf_util.py
index e3e8cac8..463dcea7 100644
--- a/modelscope/utils/hf_util.py
+++ b/modelscope/utils/hf_util.py
@@ -21,7 +21,7 @@ from transformers.models.auto.tokenization_auto import (
     TOKENIZER_MAPPING_NAMES, get_tokenizer_config)

 from modelscope import snapshot_download
-from modelscope.utils.constant import Invoke
+from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke

 try:
     from transformers import GPTQConfig as GPTQConfigHF
@@ -84,69 +84,6 @@ patch_tokenizer_base()
 patch_model_base()


-def check_hf_code(model_dir: str, auto_class: type,
-                  trust_remote_code: bool) -> None:
-    config_path = os.path.join(model_dir, 'config.json')
-    if not os.path.exists(config_path):
-        raise FileNotFoundError(f'{config_path} is not found')
-    config_dict = PretrainedConfig.get_config_dict(config_path)[0]
-    auto_class_name = auto_class.__name__
-    if auto_class is AutoTokenizerHF:
-        tokenizer_config = get_tokenizer_config(model_dir)
-    # load from repo
-    if trust_remote_code:
-        has_remote_code = False
-        if auto_class is AutoTokenizerHF:
-            auto_map = tokenizer_config.get('auto_map', None)
-            if auto_map is not None:
-                module_name = auto_map.get(auto_class_name, None)
-                if module_name is not None:
-                    module_name = module_name[0]
-                    has_remote_code = True
-        else:
-            auto_map = config_dict.get('auto_map', None)
-            if auto_map is not None:
-                module_name = auto_map.get(auto_class_name, None)
-                has_remote_code = module_name is not None
-
-        if has_remote_code:
-            module_path = os.path.join(model_dir,
-                                       module_name.split('.')[0] + '.py')
-            if not os.path.exists(module_path):
-                raise FileNotFoundError(f'{module_path} is not found')
-            return
-
-    # trust_remote_code is False or has_remote_code is False
-    model_type = config_dict.get('model_type', None)
-    if model_type is None:
-        raise ValueError(f'`model_type` key is not found in {config_path}.')
-
-    trust_remote_code_info = '.'
-    if not trust_remote_code:
-        trust_remote_code_info = ', You can try passing `trust_remote_code=True`.'
-    if auto_class is AutoConfigHF:
-        if model_type not in CONFIG_MAPPING:
-            raise ValueError(
-                f'{model_type} not found in HF `CONFIG_MAPPING`{trust_remote_code_info}'
-            )
-    elif auto_class is AutoTokenizerHF:
-        tokenizer_class = tokenizer_config.get('tokenizer_class')
-        if tokenizer_class is not None:
-            return
-        if model_type not in TOKENIZER_MAPPING_NAMES:
-            raise ValueError(
-                f'{model_type} not found in HF `TOKENIZER_MAPPING_NAMES`{trust_remote_code_info}'
-            )
-    else:
-        mapping_names = [
-            m.model_type for m in auto_class._model_mapping.keys()
-        ]
-        if model_type not in mapping_names:
-            raise ValueError(
-                f'{model_type} not found in HF `auto_class._model_mapping`{trust_remote_code_info}'
-            )
-
-
 def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
     """Get a custom wrapper class for auto classes to download the models from the ModelScope hub
     Args:
@@ -166,7 +103,7 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
             ignore_file_pattern = kwargs.pop('ignore_file_pattern',
                                              default_ignore_file_pattern)
             if not os.path.exists(pretrained_model_name_or_path):
-                revision = kwargs.pop('revision', None)
+                revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION)
                 model_dir = snapshot_download(
                     pretrained_model_name_or_path,
                     revision=revision,
@@ -175,9 +112,6 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
             else:
                 model_dir = pretrained_model_name_or_path

-            if module_class is not GenerationConfigHF:
-                trust_remote_code = kwargs.get('trust_remote_code', False)
-                check_hf_code(model_dir, module_class, trust_remote_code)
             module_obj = module_class.from_pretrained(model_dir, *model_args,
                                                       **kwargs)
diff --git a/modelscope/version.py b/modelscope/version.py
index 23ef0243..0ec59aaa 100644
--- a/modelscope/version.py
+++ b/modelscope/version.py
@@ -1,5 +1,5 @@
 # Make sure to modify __release_datetime__ to release time when making official release.
-__version__ = '1.9.1'
+__version__ = '1.9.3'
 # default release datetime for branches under active development is set
 # to be a time far-far-away-into-the-future
 __release_datetime__ = '2099-09-06 00:00:00'
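A hedged usage sketch for the hf_util.py change above (the model id is a placeholder): with check_hf_code removed, the wrapped auto classes defer remote-code validation to transformers itself, and an omitted revision now resolves to DEFAULT_MODEL_REVISION rather than None:

    from modelscope import AutoModelForCausalLM, AutoTokenizer

    # 'damo/some-causal-lm' is hypothetical; since no revision is passed,
    # snapshot_download now receives DEFAULT_MODEL_REVISION instead of None.
    tokenizer = AutoTokenizer.from_pretrained(
        'damo/some-causal-lm', trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        'damo/some-causal-lm', trust_remote_code=True)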