diff --git a/examples/pytorch/baichuan/finetune_baichuan.py b/examples/pytorch/baichuan/finetune_baichuan.py
index 353f5023..ed8821c6 100644
--- a/examples/pytorch/baichuan/finetune_baichuan.py
+++ b/examples/pytorch/baichuan/finetune_baichuan.py
@@ -3,14 +3,13 @@
 import sys
 import types
 from dataclasses import dataclass, field
+from swift import LoRAConfig, Swift
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from modelscope import (EpochBasedTrainer, MsDataset, TorchModel, TrainingArgs,
                         build_dataset_from_file, snapshot_download)
 from modelscope.metainfo import Trainers
 from modelscope.preprocessors import TextGenerationTransformersPreprocessor
-from modelscope.swift import Swift
-from modelscope.swift.lora import LoRAConfig
 from modelscope.trainers import build_trainer
 DEFAULT_PAD_TOKEN = '[PAD]'
@@ -205,12 +204,12 @@
 preprocessor = TextGenerationTransformersPreprocessor(
 if args.use_lora != 0:
     lora_config = LoRAConfig(
-        replace_modules=['pack'],
-        rank=args.lora_rank,
+        target_modules=['pack'],
+        r=args.lora_rank,
         lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout)
     model = model.bfloat16()
-    Swift.prepare_model(model, lora_config)
+    model = Swift.prepare_model(model, lora_config)
 kwargs = dict(
     model=model,
diff --git a/examples/pytorch/baichuan/lora_inference.py b/examples/pytorch/baichuan/lora_inference.py
index 661e8493..7458c572 100644
--- a/examples/pytorch/baichuan/lora_inference.py
+++ b/examples/pytorch/baichuan/lora_inference.py
@@ -1,10 +1,9 @@
 import os.path as osp
 import torch
+from swift import LoRAConfig, Swift
 from modelscope.pipelines import pipeline
-from modelscope.swift import Swift
-from modelscope.swift.lora import LoRAConfig
 from modelscope.utils.constant import Tasks
 # Initialize the pipeline from the source model_id
@@ -12,11 +11,11 @@
 model_id = 'baichuan-inc/baichuan-7B'
 pipe = pipeline(
     task=Tasks.text_generation, model=model_id, model_revision='v1.0.2')
 # LoRA config: replace_modules, rank and alpha must match the training settings
-lora_config = LoRAConfig(replace_modules=['pack'], rank=32, lora_alpha=32)
+lora_config = LoRAConfig(target_modules=['pack'], r=32, lora_alpha=32)
 # Cast to bf16; must match the training precision
 model = pipe.model.bfloat16()
 # Wrap the model with LoRA
-Swift.prepare_model(model, lora_config)
+model = Swift.prepare_model(model, lora_config)
 # Load the LoRA weights, which are linked under the output/model path by default
 work_dir = './tmp'
 state_dict = torch.load(osp.join(work_dir, 'output/pytorch_model.bin'))
diff --git a/examples/pytorch/chatglm6b/finetune.py b/examples/pytorch/chatglm6b/finetune.py
index 5e7ff6a5..1f419770 100644
--- a/examples/pytorch/chatglm6b/finetune.py
+++ b/examples/pytorch/chatglm6b/finetune.py
@@ -4,6 +4,7 @@
 from dataclasses import dataclass, field
 import numpy as np
 import torch
 from chatglm_trainer import Seq2SeqTrainer
+from swift import LoRAConfig, Swift
 from text_generation_metric import TextGenerationMetric
 from transformers import DataCollatorForSeq2Seq
@@ -11,8 +12,6 @@
 from modelscope import build_dataset_from_file, snapshot_download
 from modelscope.metainfo import Models
 from modelscope.models import Model
 from modelscope.msdatasets import MsDataset
-from modelscope.swift import Swift
-from modelscope.swift.lora import LoRAConfig
 from modelscope.trainers.training_args import TrainingArgs
 from modelscope.utils.config import ConfigDict
 from modelscope.utils.hub import read_config
@@ -243,15 +242,15 @@
 elif not args.use_lora:
 if args.use_lora != 0:
     lora_config = LoRAConfig(
-        replace_modules=['attention.query_key_value'],
-        rank=args.lora_rank,
+        target_modules=['attention.query_key_value'],
+
r=args.lora_rank, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout) if args.use_amp: model = model.float() else: model = model.bfloat16() - Swift.prepare_model(model, lora_config) + model = Swift.prepare_model(model, lora_config) prefix = args.source_prefix if args.source_prefix is not None else '' diff --git a/examples/pytorch/chatglm6b/lora_inference.py b/examples/pytorch/chatglm6b/lora_inference.py index d31ee50d..2975a5da 100644 --- a/examples/pytorch/chatglm6b/lora_inference.py +++ b/examples/pytorch/chatglm6b/lora_inference.py @@ -1,15 +1,17 @@ +import os.path as osp + +import torch +from swift import LoRAConfig, Swift + from modelscope import Model, pipeline, read_config from modelscope.metainfo import Models -from modelscope.swift import Swift -from modelscope.swift.lora import LoRAConfig from modelscope.utils.config import ConfigDict lora_config = LoRAConfig( - replace_modules=['attention.query_key_value'], - rank=32, + target_modules=['attention.query_key_value'], + r=32, lora_alpha=32, - lora_dropout=0.05, - pretrained_weights='./lora_dureader_target/iter_600.pth') + lora_dropout=0.05) model_dir = 'ZhipuAI/ChatGLM-6B' model_config = read_config(model_dir) @@ -19,8 +21,12 @@ model_config['model'] = ConfigDict({ model = Model.from_pretrained(model_dir, cfg_dict=model_config) model = model.bfloat16() -Swift.prepare_model(model, lora_config) - +model = Swift.prepare_model(model, lora_config) +work_dir = './tmp' +state_dict = torch.load(osp.join(work_dir, 'iter_600.pth')) +model = Swift.from_pretrained( + model, osp.join(work_dir, 'output_best'), device_map='auto') +model.load_state_dict(state_dict) pipe = pipeline('chat', model, pipeline_name='chatglm6b-text-generation') print( diff --git a/examples/pytorch/chatglm6b/lora_inference_v2.py b/examples/pytorch/chatglm6b/lora_inference_v2.py index aa86e890..9be481f1 100644 --- a/examples/pytorch/chatglm6b/lora_inference_v2.py +++ b/examples/pytorch/chatglm6b/lora_inference_v2.py @@ -1,15 +1,17 @@ +import os.path as osp + +import torch +from swift import LoRAConfig, Swift + from modelscope import Model, pipeline, read_config from modelscope.metainfo import Models -from modelscope.swift import Swift -from modelscope.swift.lora import LoRAConfig from modelscope.utils.config import ConfigDict lora_config = LoRAConfig( - replace_modules=['attention.query_key_value'], - rank=32, + target_modules=['attention.query_key_value'], + r=32, lora_alpha=32, - lora_dropout=0.05, - pretrained_weights='./lora_dureader_target/iter_600.pth') + lora_dropout=0.05) model_dir = 'ZhipuAI/chatglm2-6b' model_config = read_config(model_dir) @@ -19,7 +21,12 @@ model_config['model'] = ConfigDict({ model = Model.from_pretrained(model_dir, cfg_dict=model_config) model = model.bfloat16() -Swift.prepare_model(model, lora_config) +model = Swift.prepare_model(model, lora_config) +work_dir = './tmp' +state_dict = torch.load(osp.join(work_dir, 'iter_600.pth')) +model = Swift.from_pretrained( + model, osp.join(work_dir, 'output_best'), device_map='auto') +model.load_state_dict(state_dict) pipe = pipeline('chat', model, pipeline_name='chatglm2_6b-text-generation') diff --git a/examples/pytorch/llama/finetune_llama.py b/examples/pytorch/llama/finetune_llama.py index 41606a62..cb98662e 100644 --- a/examples/pytorch/llama/finetune_llama.py +++ b/examples/pytorch/llama/finetune_llama.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field import json import torch +from swift import LoRAConfig, Swift from modelscope import TrainingArgs from 
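Across the fine-tuning scripts in these hunks the change is the same: LoRAConfig now comes from the standalone swift package, whose arguments are named target_modules and r rather than replace_modules and rank, and Swift.prepare_model returns a wrapped model instead of patching in place, so the return value must be kept. A minimal sketch of the new call pattern, using a toy module rather than a real checkpoint (the submodule name 'pack' only stands in for baichuan's packed attention projection):

import torch.nn as nn
from swift import LoRAConfig, Swift

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.pack = nn.Linear(64, 64)  # stand-in for the projection targeted above

    def forward(self, x):
        return self.pack(x)

# target_modules / r replace the old replace_modules / rank arguments.
lora_config = LoRAConfig(
    target_modules=['pack'],
    r=8,
    lora_alpha=32,
    lora_dropout=0.05)
# prepare_model no longer modifies the model in place; keep the returned wrapper.
model = Swift.prepare_model(ToyModel(), lora_config)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))  # trainable (LoRA) parameters only
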
modelscope.hub.snapshot_download import snapshot_download @@ -15,8 +16,6 @@ from modelscope.metainfo import Trainers from modelscope.models.nlp.llama import LlamaForTextGeneration, LlamaTokenizer from modelscope.msdatasets.dataset_cls.custom_datasets.torch_custom_dataset import \ TorchCustomDataset -from modelscope.swift import Swift -from modelscope.swift.lora import LoRAConfig from modelscope.trainers import build_trainer IGNORE_INDEX = -100 @@ -255,12 +254,12 @@ if __name__ == '__main__': if args.use_lora != 0: lora_config = LoRAConfig( - replace_modules=['q_proj', 'k_proj', 'v_proj'], - rank=args.lora_rank, + target_modules=['q_proj', 'k_proj', 'v_proj'], + r=args.lora_rank, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout) model = model.bfloat16() - Swift.prepare_model(model, lora_config) + model = Swift.prepare_model(model, lora_config) tokenizer = LlamaTokenizer.from_pretrained( model_path, diff --git a/examples/pytorch/llm/llm_infer.py b/examples/pytorch/llm/llm_infer.py index c2b8922b..e417f6f5 100644 --- a/examples/pytorch/llm/llm_infer.py +++ b/examples/pytorch/llm/llm_infer.py @@ -7,13 +7,13 @@ from functools import partial from typing import List, Optional import torch +from swift import LoRAConfig, Swift from transformers import GenerationConfig, TextStreamer from utils import (DATASET_MAPPING, DEFAULT_PROMPT, MODEL_MAPPING, get_dataset, get_model_tokenizer, inference, parse_args, process_dataset, tokenize_function) from modelscope import get_logger -from modelscope.swift import LoRAConfig, Swift warnings.warn( 'This directory has been migrated to ' @@ -76,13 +76,15 @@ def llm_infer(args: InferArguments) -> None: # ### Preparing lora if args.sft_type == 'lora': lora_config = LoRAConfig( - replace_modules=args.lora_target_modules, - rank=args.lora_rank, + target_modules=args.lora_target_modules, + r=args.lora_rank, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout_p, pretrained_weights=args.ckpt_path) logger.info(f'lora_config: {lora_config}') model = Swift.prepare_model(model, lora_config) + state_dict = torch.load(args.ckpt_path, map_location='cpu') + model.load_state_dict(state_dict) elif args.sft_type == 'full': state_dict = torch.load(args.ckpt_path, map_location='cpu') model.load_state_dict(state_dict) diff --git a/examples/pytorch/llm/llm_sft.py b/examples/pytorch/llm/llm_sft.py index 827f9d80..8eaa6040 100644 --- a/examples/pytorch/llm/llm_sft.py +++ b/examples/pytorch/llm/llm_sft.py @@ -20,6 +20,7 @@ from functools import partial from typing import List, Optional import torch +from swift import LoRAConfig, Swift from torch import Tensor from utils import (DATASET_MAPPING, DEFAULT_PROMPT, MODEL_MAPPING, data_collate_fn, get_dataset, get_model_tokenizer, @@ -29,7 +30,6 @@ from utils import (DATASET_MAPPING, DEFAULT_PROMPT, MODEL_MAPPING, tokenize_function) from modelscope import get_logger -from modelscope.swift import LoRAConfig, Swift from modelscope.trainers import EpochBasedTrainer from modelscope.utils.config import Config @@ -141,8 +141,8 @@ def llm_sft(args: SftArguments) -> None: # ### Preparing lora if args.sft_type == 'lora': lora_config = LoRAConfig( - replace_modules=args.lora_target_modules, - rank=args.lora_rank, + target_modules=args.lora_target_modules, + r=args.lora_rank, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout_p) logger.info(f'lora_config: {lora_config}') diff --git a/examples/pytorch/llm_agent/_common.py b/examples/pytorch/llm_agent/_common.py index dd07ef31..384e8106 100644 --- 
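The inference-side hunks above load the trained LoRA weights explicitly after Swift.prepare_model instead of relying only on a pretrained_weights field in the config. A sketch of that loading sequence, assuming model and lora_config are built exactly as in the corresponding training script and that the checkpoint paths below exist:

import os.path as osp
import torch
from swift import Swift

# Wrap the bf16 base model first, then restore the tuned weights.
model = Swift.prepare_model(model, lora_config)
work_dir = './tmp'
state_dict = torch.load(osp.join(work_dir, 'iter_600.pth'), map_location='cpu')
model = Swift.from_pretrained(
    model, osp.join(work_dir, 'output_best'), device_map='auto')
model.load_state_dict(state_dict)
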
a/examples/pytorch/llm_agent/_common.py +++ b/examples/pytorch/llm_agent/_common.py @@ -5,47 +5,33 @@ import os import random import re import sys -from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import json import matplotlib.pyplot as plt import numpy as np # import torch -import torch.nn as nn -import torch.optim as optim -from matplotlib.axes import Axes from matplotlib.figure import Figure -from numpy import ndarray +from swift import LoRAConfig, Swift from tensorboard.backend.event_processing.event_accumulator import \ EventAccumulator from torch import Tensor from torch import device as Device -from torch import dtype as Dtype from torch.nn import Module -from torch.nn.parameter import Parameter from torch.nn.utils.rnn import pad_sequence -from torch.optim import Optimizer -from torch.optim import lr_scheduler as lrs -from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils.data import Dataset # from torchmetrics import Accuracy, MeanMetric # from tqdm import tqdm # -from modelscope import (Model, MsDataset, get_logger, read_config, - snapshot_download) +from modelscope import Model, MsDataset, get_logger, read_config from modelscope.metrics.base import Metric from modelscope.metrics.builder import METRICS from modelscope.models.nlp.chatglm2 import ChatGLM2Tokenizer from modelscope.msdatasets.dataset_cls.custom_datasets import \ TorchCustomDataset -from modelscope.swift import LoRAConfig, Swift -from modelscope.trainers import EpochBasedTrainer -from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.config import ConfigDict from modelscope.utils.registry import default_group # diff --git a/examples/pytorch/llm_agent/baichuan_infer.ipynb b/examples/pytorch/llm_agent/baichuan_infer.ipynb index 7ef29951..2ab8e64c 100644 --- a/examples/pytorch/llm_agent/baichuan_infer.ipynb +++ b/examples/pytorch/llm_agent/baichuan_infer.ipynb @@ -209,8 +209,8 @@ "LORA_ALPHA = 32\n", "LORA_DROPOUT_P = 0 # Arbitrary value\n", "lora_config = LoRAConfig(\n", - " replace_modules=LORA_TARGET_MODULES,\n", - " rank=LORA_RANK,\n", + " target_modules=LORA_TARGET_MODULES,\n", + " r=LORA_RANK,\n", " lora_alpha=LORA_ALPHA,\n", " lora_dropout=LORA_DROPOUT_P,\n", " pretrained_weights=CKPT_FAPTH)\n", diff --git a/examples/pytorch/llm_agent/baichuan_sft.ipynb b/examples/pytorch/llm_agent/baichuan_sft.ipynb index 6c41ff25..69bb2887 100644 --- a/examples/pytorch/llm_agent/baichuan_sft.ipynb +++ b/examples/pytorch/llm_agent/baichuan_sft.ipynb @@ -224,8 +224,8 @@ "LORA_ALPHA = 32\n", "LORA_DROPOUT_P = 0.1\n", "lora_config = LoRAConfig(\n", - " replace_modules=LORA_TARGET_MODULES,\n", - " rank=LORA_RANK,\n", + " target_modules=LORA_TARGET_MODULES,\n", + " r=LORA_RANK,\n", " lora_alpha=LORA_ALPHA,\n", " lora_dropout=LORA_DROPOUT_P)\n", "logger.info(f'lora_config: {lora_config}')\n", diff --git a/examples/pytorch/llm_agent/chatglm2_infer.ipynb b/examples/pytorch/llm_agent/chatglm2_infer.ipynb index 821da5e6..84a25d3e 100644 --- a/examples/pytorch/llm_agent/chatglm2_infer.ipynb +++ b/examples/pytorch/llm_agent/chatglm2_infer.ipynb @@ -212,8 +212,8 @@ "LORA_ALPHA = 32\n", "LORA_DROPOUT_P = 0 # Arbitrary value\n", "lora_config = LoRAConfig(\n", - " replace_modules=LORA_TARGET_MODULES,\n", - " rank=LORA_RANK,\n", + " target_modules=LORA_TARGET_MODULES,\n", + " r=LORA_RANK,\n", " lora_alpha=LORA_ALPHA,\n", " lora_dropout=LORA_DROPOUT_P,\n", " pretrained_weights=CKPT_FAPTH)\n", diff --git 
a/examples/pytorch/llm_agent/chatglm2_sft.ipynb b/examples/pytorch/llm_agent/chatglm2_sft.ipynb index f1943086..b32f2f36 100644 --- a/examples/pytorch/llm_agent/chatglm2_sft.ipynb +++ b/examples/pytorch/llm_agent/chatglm2_sft.ipynb @@ -234,8 +234,8 @@ "LORA_ALPHA = 32\n", "LORA_DROPOUT_P = 0.1\n", "lora_config = LoRAConfig(\n", - " replace_modules=LORA_TARGET_MODULES,\n", - " rank=LORA_RANK,\n", + " target_modules=LORA_TARGET_MODULES,\n", + " r=LORA_RANK,\n", " lora_alpha=LORA_ALPHA,\n", " lora_dropout=LORA_DROPOUT_P)\n", "logger.info(f'lora_config: {lora_config}')\n", diff --git a/modelscope/metrics/prediction_saving_wrapper.py b/modelscope/metrics/prediction_saving_wrapper.py index c7aee4e1..9cabf559 100644 --- a/modelscope/metrics/prediction_saving_wrapper.py +++ b/modelscope/metrics/prediction_saving_wrapper.py @@ -2,16 +2,10 @@ from typing import Dict -import numpy as np -from sklearn.metrics import accuracy_score, f1_score - from modelscope.metainfo import Metrics -from modelscope.outputs import OutputKeys from modelscope.utils.registry import default_group -from modelscope.utils.tensor_utils import (torch_nested_detach, - torch_nested_numpify) from .base import Metric -from .builder import METRICS, MetricKeys +from .builder import METRICS @METRICS.register_module( diff --git a/modelscope/swift/control_sd_lora.py b/modelscope/models/multi_modal/efficient_diffusion_tuning/control_sd_lora.py similarity index 100% rename from modelscope/swift/control_sd_lora.py rename to modelscope/models/multi_modal/efficient_diffusion_tuning/control_sd_lora.py diff --git a/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py b/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py index f253ebbe..901c44d9 100644 --- a/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py +++ b/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py @@ -13,22 +13,20 @@ from diffusers import (AutoencoderKL, DDPMScheduler, DiffusionPipeline, utils) from diffusers.models import cross_attention from diffusers.utils import deprecation_utils +from swift import AdapterConfig, LoRAConfig, PromptConfig, Swift from transformers import CLIPTextModel, CLIPTokenizer from modelscope import snapshot_download from modelscope.metainfo import Models from modelscope.models import TorchModel from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.efficient_diffusion_tuning.sd_lora import \ + LoRATuner from modelscope.outputs import OutputKeys -from modelscope.swift import Swift -from modelscope.swift.adapter import AdapterConfig -from modelscope.swift.control_sd_lora import ControlLoRATuner -from modelscope.swift.lora import LoRAConfig -from modelscope.swift.prompt import PromptConfig -from modelscope.swift.sd_lora import LoRATuner from modelscope.utils.checkpoint import save_checkpoint, save_configuration from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks +from .control_sd_lora import ControlLoRATuner utils.deprecate = lambda *arg, **kwargs: None deprecation_utils.deprecate = lambda *arg, **kwargs: None diff --git a/modelscope/swift/sd_lora.py b/modelscope/models/multi_modal/efficient_diffusion_tuning/sd_lora.py similarity index 100% rename from modelscope/swift/sd_lora.py rename to modelscope/models/multi_modal/efficient_diffusion_tuning/sd_lora.py diff --git a/modelscope/swift/__init__.py b/modelscope/swift/__init__.py deleted file mode 
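For the efficient diffusion tuning model changed below, the tuner configurations now come from swift, while the stable-diffusion LoRA tuners move under the efficient_diffusion_tuning package. The resulting import surface, written here with absolute paths inferred from the renames (the patch itself uses a relative import for ControlLoRATuner):

from swift import AdapterConfig, LoRAConfig, PromptConfig, Swift
from modelscope.models.multi_modal.efficient_diffusion_tuning.sd_lora import LoRATuner
from modelscope.models.multi_modal.efficient_diffusion_tuning.control_sd_lora import ControlLoRATuner
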
100644 index bd8ea75e..00000000 --- a/modelscope/swift/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -from typing import TYPE_CHECKING - -from modelscope.utils.import_utils import LazyImportModule - -if TYPE_CHECKING: - from .optimizers.child_tuning_adamw_optimizer import calculate_fisher, ChildTuningAdamW - from .adapter import Adapter, AdapterConfig, AdapterModule - from .lora import LoRA, LoRAConfig, Linear, MergedLinear, Embedding, Conv2d - from .prompt import Prompt, PromptConfig, PromptModule - from .control_sd_lora import ControlLoRACrossAttnProcessor, ControlLoRACrossAttnProcessorV2, ControlLoRATuner - from .base import SwiftConfig, Swift -else: - _import_structure = { - 'optimizers.child_tuning_adamw_optimizer': - ['calculate_fisher', 'ChildTuningAdamW'], - 'adapter': ['Adapter', 'AdapterConfig', 'AdapterModule'], - 'lora': [ - 'LoRA', 'LoRAConfig', 'Linear', 'MergedLinear', 'Embedding', - 'Conv2d' - ], - 'prompt': ['Prompt', 'PromptConfig', 'PromptModule'], - 'control_sd_lora': [ - 'ControlLoRACrossAttnProcessor', 'ControlLoRACrossAttnProcessorV2', - 'ControlLoRATuner' - ], - 'base': ['SwiftConfig', 'Swift'] - } - - import sys - - sys.modules[__name__] = LazyImportModule( - __name__, - globals()['__file__'], - _import_structure, - module_spec=__spec__, - extra_objects={}, - ) diff --git a/modelscope/swift/adapter.py b/modelscope/swift/adapter.py deleted file mode 100644 index 2f1f729b..00000000 --- a/modelscope/swift/adapter.py +++ /dev/null @@ -1,201 +0,0 @@ -import inspect -import os -import re -import types -from dataclasses import dataclass, field -from typing import Union - -import torch -from torch import nn - -from modelscope import snapshot_download -from modelscope.utils.constant import ModelFile -from .base import SwiftConfig - - -@dataclass -class AdapterConfig(SwiftConfig): - """ - The configuration class for the adapter module. - - Adapters project input tokens by an MLP layer. - 'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019) - See http://arxiv.org/abs/1902.00751 - - Args: - dim: The dimension of the hidden states - module_name: The feedforward module to be replaced, in regex format - hidden_pos: The position of the hidden state to passed into the adapter, can be int (args) or str (kwargs) - method_name: The method to be replaced, default to replace the forward method - adapter_length: The length of the adapter length (intermediate length) - act_layer: The activation layer of the adapter - only_adapter_trainable: Whether to train only adapters - pretrained_weights: The pretrained adapter weights. 
- Can be a local dir, local file, or a model id from modelscope - """ - - dim: int = field(metadata={'help': 'The dimension of the hidden states'}) - - module_name: str = field( - metadata={ - 'help': 'The feedforward module to be replaced, in regex format' - }) - - hidden_pos: Union[str, int] = field( - metadata={ - 'help': - 'The position of the hidden state to passed into the adapter, can be int (args) or str (kwargs)' - }) - - method_name: str = field( - default='forward', - metadata={ - 'help': - 'The method to be replaced, default to replace the forward method' - }) - - adapter_length: int = field( - default=128, - metadata={ - 'help': 'The length of the adapter length (intermediate length)' - }) - - act_layer: nn.Module = field( - default=nn.GELU, - metadata={'help': 'The activation layer of the adapter'}) - - only_adapter_trainable: bool = field( - default=True, metadata={'help': 'Whether to train only adapters'}) - - pretrained_weights: str = field( - default=None, - metadata={ - 'help': - 'The pretrained adapter weights. Can be a local dir, local file, or a model id from modelscope' - }) - - -class Adapter: - - @staticmethod - def prepare_model(model: nn.Module, config: AdapterConfig): - module_keys = [key for key, _ in model.named_modules()] - - for module_key in module_keys: - if re.fullmatch(config.module_name, module_key): # noqa - module = model.get_submodule(module_key) - - def _forward(self, *args, **kwargs): - args = self.forward_origin(*args, **kwargs) - if isinstance(args, (tuple, list, dict)): - if isinstance(config.hidden_pos, int): - return args[0:config.hidden_pos] + args[ - config.hidden_pos] + getattr(self, 'adapter')(args[config.hidden_pos]) \ - + args[config.hidden_pos + 1:] # noqa - else: - kwargs[config.hidden_pos] = args[ - config.hidden_pos] + getattr(self, 'adapter')( - args[config.hidden_pos]) - elif isinstance(args, torch.Tensor): - args = getattr(self, 'adapter')(args) - return args - - def _feed_forward_chunk(self, attention_output): - return _forward(self, attention_output) - - module.forward_origin = getattr(module, config.method_name) - num_args_in_forward_chunk_fn = len( - inspect.signature(module.forward_origin).parameters) - if config.method_name == 'feed_forward_chunk' and num_args_in_forward_chunk_fn == 1: - setattr(module, config.method_name, - types.MethodType(_feed_forward_chunk, module)) - else: - setattr(module, config.method_name, - types.MethodType(_forward, module)) - - if isinstance(module, torch.nn.Linear): - input_dim = module.out_features - else: - input_dim = config.dim - - adapter_module = AdapterModule(input_dim, - config.adapter_length, - config.act_layer) - setattr(module, 'adapter', adapter_module) - - if config.only_adapter_trainable: - for n, p in model.named_parameters(): - if 'adapter' not in n: - p.requires_grad = False - - def state_dict_hook(module, destination, prefix, local_metadata): - return { - key: value - for key, value in destination.items() if 'adapter' in key - } - - model.state_dict_hook_handle = model._register_state_dict_hook( - state_dict_hook) - - def load_state_dict(self, state_dict, strict=True): - return self.load_state_dict_origin(state_dict, False) - - model.load_state_dict_origin = model.load_state_dict - model.load_state_dict = types.MethodType(load_state_dict, model) - - if config.pretrained_weights is not None: - if not os.path.exists(config.pretrained_weights): - model_dir = snapshot_download(config.pretrained_weights) - pretrained_weights = os.path.join( - model_dir, 
ModelFile.TORCH_MODEL_BIN_FILE) - elif os.path.isfile(config.pretrained_weights): - pretrained_weights = config.pretrained_weights - else: - pretrained_weights = os.path.join( - config.pretrained_weights, ModelFile.TORCH_MODEL_BIN_FILE) - model.load_state_dict(torch.load(pretrained_weights)) - return model - - -class AdapterModule(nn.Module): - """The implementation of adapter tuning method. - - Adapters project input tokens by an MLP layer. - 'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019) - See http://arxiv.org/abs/1902.00751 - - Attributes: - dim: An integer indicating the embedding dimension. - adapter_length: An integer indicating the length of adapter tuning. - """ - - def __init__( - self, - dim, - adapter_length=None, - act_layer=nn.GELU, - ): - super(AdapterModule, self).__init__() - self.dim = dim - self.adapter_length = adapter_length - # self.adapter_type = adapter_type - self.ln1 = nn.Linear(dim, adapter_length) - self.activate = act_layer() - self.ln2 = nn.Linear(adapter_length, dim) - self.init_weights() - - def init_weights(self): - - def _init_weights(m): - if isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight) - nn.init.normal_(m.bias, std=1e-6) - - self.apply(_init_weights) - - def forward(self, x, identity=None): - out = self.ln2(self.activate(self.ln1(x))) - if identity is None: - identity = x - out = identity + out - return out diff --git a/modelscope/swift/base.py b/modelscope/swift/base.py deleted file mode 100644 index 441521ca..00000000 --- a/modelscope/swift/base.py +++ /dev/null @@ -1,31 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class SwiftConfig: - pass - - -class Swift: - - @staticmethod - def prepare_model(model, config: SwiftConfig): - """Prepare the module and returns the new module. - - Args: - model: The model to tune. - config: The config of the tuner. - - Returns: - The tuned model. - """ - from .lora import LoRA, LoRAConfig - from .adapter import Adapter, AdapterConfig - from .prompt import Prompt, PromptConfig - if isinstance(config, LoRAConfig): - return LoRA.prepare_model(model, config) - elif isinstance(config, AdapterConfig): - return Adapter.prepare_model(model, config) - elif isinstance(config, PromptConfig): - return Prompt.prepare_model(model, config) - return None diff --git a/modelscope/swift/lora.py b/modelscope/swift/lora.py deleted file mode 100644 index 3c0be6ba..00000000 --- a/modelscope/swift/lora.py +++ /dev/null @@ -1,700 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. -import logging -import math -import os.path -import re -import types -from dataclasses import dataclass, field -from typing import Dict, List - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from modelscope import snapshot_download -from modelscope.utils.constant import ModelFile -from .base import SwiftConfig - -logger = logging.getLogger(__name__) - - -@dataclass -class LoRAConfig(SwiftConfig): - """ - The configuration class for the loRA module. 
- - Args: - rank: The rank of the LoRA module - replace_modules: The modules to be replaced by LoRA, can be the end of the module name or a regex string - lora_alpha: The factor to add the lora weights - lora_dropout: The dropout rate of the lora module - merge_weights: Whether to merge weights when validating - use_merged_linear: Whether to replace with merged linear layer - enable_lora: The modules need to be turned on when using the merged linear layer - fan_in_fan_out: Set this to True if the layer to replace stores weight like (fan_in, fan_out) - bias: Bias type. Values ca be "none", "all" or "lora_only" - only_lora_trainable: Whether to train only lora - pretrained_weights: The pretrained lora weights. - Can be a local dir, local file, or a model id from modelscope - """ - - rank: int = field( - default=6, metadata={'help': 'The rank of the LoRA module'}) - replace_modules: List = field( - default=None, - metadata={ - 'help': - 'The modules to be replaced by LoRA, can be the end of the module name or a regex string' - }) - lora_alpha: float = field( - default=1., metadata={'help': 'The factor to add the lora weights'}) - lora_dropout: float = field( - default=0., metadata={'help': 'The dropout rate of the lora module'}) - merge_weights: bool = field( - default=True, - metadata={'help': 'Whether to merge weights when validating'}) - use_merged_linear: bool = field( - default=False, - metadata={'help': 'Whether to replace with merged linear layer'}) - enable_lora: List = field( - default=None, - metadata={ - 'help': - 'The modules need to be turned on when using the merged linear layer' - }) - fan_in_fan_out: bool = field( - default=False, - metadata={ - 'help': - 'Set this to True if the layer to replace stores weight like (fan_in, fan_out)' - }) - bias: str = field( - default='none', - metadata={ - 'help': 'Bias type. Values ca be "none", "all" or "lora_only"' - }) - only_lora_trainable: bool = field( - default=True, metadata={'help': 'Whether to train only lora'}) - pretrained_weights: str = field( - default=None, - metadata={ - 'help': - 'The pretrained lora weights. Can be a local dir, local file, or a model id from modelscope' - }) - - -class LoRA: - - @staticmethod - def prepare_model(model: nn.Module, config: LoRAConfig): - """Tune a model with LoRA. - - Args: - config: The LoRAConfig instance. 
- - Returns: - The lora modules - """ - LoRA._dynamic_patch_lora( - model, - replace_modules=config.replace_modules, - r=config.rank, - lora_alpha=config.lora_alpha, - lora_dropout=config.lora_dropout, - merge_weights=config.merge_weights, - use_merged_linear=config.use_merged_linear, - enable_lora=config.enable_lora, - fan_in_fan_out=config.fan_in_fan_out) - - if config.only_lora_trainable: - mark_only_lora_as_trainable(model, config.bias) - - def state_dict_hook(module, destination, prefix, local_metadata): - return lora_state_dict(destination, config.bias) - - model.state_dict_hook_handle = model._register_state_dict_hook( - state_dict_hook) - - def load_state_dict(self, state_dict, strict=True): - return self.load_state_dict_origin(state_dict, False) - - model.load_state_dict_origin = model.load_state_dict - model.load_state_dict = types.MethodType(load_state_dict, model) - - if config.pretrained_weights is not None: - if not os.path.exists(config.pretrained_weights): - model_dir = snapshot_download(config.pretrained_weights) - pretrained_weights = os.path.join( - model_dir, ModelFile.TORCH_MODEL_BIN_FILE) - elif os.path.isfile(config.pretrained_weights): - pretrained_weights = config.pretrained_weights - else: - pretrained_weights = os.path.join( - config.pretrained_weights, ModelFile.TORCH_MODEL_BIN_FILE) - model.load_state_dict(torch.load(pretrained_weights)) - - return model - - @staticmethod - def _dynamic_patch_lora(model, replace_modules, use_merged_linear, - **kwargs): - """Dynamic patch lora to model - - Args: - model: The torch.nn.Module containing the target module to be patched. - replace_modules: The module names to be replaced, the replacing strategy is `end with`. - use_merged_linear: Whether to replace with merged linear layer - **kwargs: The arguments passed from `tune` which are needed by lora. 
- - Returns: - The lora modules - """ - modules = [] - module_keys = [key for key, _ in model.named_modules()] - assert isinstance(replace_modules, (str, list)) - if isinstance(replace_modules, str): - replace_modules = [replace_modules] - - for module_key in module_keys: - if isinstance(replace_modules, str): - target_module_found = re.fullmatch(replace_modules, module_key) - else: - target_module_found = any( - module_key.endswith(target_key) - for target_key in replace_modules) - if target_module_found: # noqa - parts = module_key.split('.') - module = model.get_submodule('.'.join(parts[:-1])) - sub_module = model.get_submodule(module_key) - _key = parts[-1] - - lora_module = None - if isinstance(sub_module, torch.nn.Linear): - if use_merged_linear: - lora_module = MergedLinear( - sub_module.in_features, - sub_module.out_features, - bias=sub_module.bias is not None, - **kwargs) - else: - kwargs.pop('enable_lora', None) - lora_module = Linear( - sub_module.in_features, - sub_module.out_features, - bias=sub_module.bias is not None, - **kwargs) - elif isinstance(sub_module, torch.nn.Conv2d): - kwargs.pop('fan_in_fan_out', None) - lora_module = Conv2d( - sub_module.in_channels, - sub_module.out_channels, - kernel_size=sub_module.kernel_size, - stride=sub_module.stride, - padding=sub_module.padding, - dilation=sub_module.dilation, - groups=sub_module.groups, - **kwargs) - - if lora_module is not None: - lora_module.weight = sub_module.weight - if sub_module.bias is not None: - lora_module.bias = sub_module.bias - lora_module.to(sub_module.weight.device).to( - sub_module.weight.dtype) - setattr(module, _key, lora_module) - modules.append(lora_module) - return modules - - @staticmethod - def unpatch_lora(model, config: LoRAConfig): - """Unpatch lora modules and merge the weights to original modules. - - LoRA constructs an additional layer with low-rank decomposition matrices of the weights in the network. - 'LoRA: Low-Rank Adaptation of Large Language Models' by Hu et al.(2021) - See https://arxiv.org/abs/2106.09685 - - Args: - model: The model called with `tune` function. - replace_modules: The module names to be replaced, the replacing strategy is `end with`. - - Returns: - The lora modules. 
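The Linear layer defined further down in this deleted file implements the usual low-rank update: the frozen base projection plus a trainable A/B pair scaled by lora_alpha / r. A compact, standalone restatement of that unmerged forward pass:

import torch.nn.functional as F

def lora_linear_forward(x, weight, bias, lora_A, lora_B, scaling, dropout):
    # Frozen base projection plus the low-rank update (dropout(x) @ A^T @ B^T) * scaling,
    # matching the unmerged branch of the deleted Linear.forward.
    result = F.linear(x, weight, bias=bias)
    return result + (dropout(x) @ lora_A.T @ lora_B.T) * scaling
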
- """ - modules = [] - module_keys = [key for key, _ in model.named_modules()] - assert isinstance(config.replace_modules, (str, list)) - replace_modules = config.replace_modules - - for module_key in module_keys: - if isinstance(replace_modules, str): - target_module_found = re.fullmatch(replace_modules, module_key) - else: - target_module_found = any( - module_key.endswith(target_key) - for target_key in replace_modules) - if target_module_found: # noqa - parts = module_key.split('.') - module = model.get_submodule('.'.join(parts[:-1])) - sub_module = model.get_submodule(module_key) - _key = parts[-1] - - origin_module = None - if isinstance(sub_module, Linear): - origin_module = torch.nn.Linear( - sub_module.in_features, - sub_module.out_features, - bias=sub_module.bias is not None) - elif isinstance(sub_module, Conv2d): - origin_module = torch.nn.Conv2d( - sub_module.in_channels, - sub_module.out_channels, - kernel_size=sub_module.kernel_size, - stride=sub_module.stride, - padding=sub_module.padding, - dilation=sub_module.dilation, - groups=sub_module.groups) - - if origin_module is not None: - sub_module.merge_weights = True - sub_module.eval() - origin_module.weight = sub_module.weight - if sub_module.bias is not None: - origin_module.bias = sub_module.bias - origin_module.to(sub_module.weight.device).to( - sub_module.weight.dtype) - setattr(module, _key, origin_module) - modules.append(sub_module) - - model.state_dict_hook_handle.remove() - if hasattr(model, 'load_state_dict_hook_handle'): - model.load_state_dict_hook_handle.remove() - else: - model.load_state_dict = model.load_state_dict_origin - return modules - - -class LoRALayer: - - def __init__( - self, - r: int, - lora_alpha: int, - lora_dropout: float, - merge_weights: bool, - ): - self.r = r - self.lora_alpha = lora_alpha - # Optional dropout - if lora_dropout > 0.: - self.lora_dropout = nn.Dropout(p=lora_dropout) - else: - self.lora_dropout = lambda x: x - # Mark the weight as unmerged - self.merged = False - self.merge_weights = merge_weights - - -class Embedding(nn.Embedding, LoRALayer): - # LoRA implemented in a dense layer - def __init__(self, - num_embeddings: int, - embedding_dim: int, - r: int = 0, - lora_alpha: int = 1, - merge_weights: bool = True, - **kwargs): - nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) - LoRALayer.__init__( - self, - r=r, - lora_alpha=lora_alpha, - lora_dropout=0, - merge_weights=merge_weights) - # Actual trainable parameters - if r > 0: - self.lora_A = nn.Parameter( - self.weight.new_zeros((r, num_embeddings))) - self.lora_B = nn.Parameter( - self.weight.new_zeros((embedding_dim, r))) - self.scaling = self.lora_alpha / self.r - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - self.reset_parameters() - - def reset_parameters(self): - nn.Embedding.reset_parameters(self) - if hasattr(self, 'lora_A'): - # initialize A the same way as the default for nn.Linear and B to zero - nn.init.zeros_(self.lora_A) - nn.init.normal_(self.lora_B) - - def train(self, mode: bool = True): - nn.Embedding.train(self, mode) - self.lora_A.requires_grad = mode - self.lora_B.requires_grad = mode - if mode and self.merge_weights and self.merged: - # Make sure that the weights are not merged - if self.r > 0: - self.weight.data -= (self.lora_B - @ self.lora_A).T * self.scaling - self.merged = False - if not mode and self.merge_weights and not self.merged: - # Merge the weights and mark it - if self.r > 0: - self.weight.data += (self.lora_B - @ self.lora_A).T * 
self.scaling - self.merged = True - - def eval(self): - nn.Embedding.eval(self) - self.lora_A.requires_grad = False - self.lora_B.requires_grad = False - if self.merge_weights and not self.merged: - # Merge the weights and mark it - if self.r > 0: - self.weight.data += (self.lora_B @ self.lora_A) * self.scaling - self.merged = True - - def forward(self, x: torch.Tensor): - if self.r > 0 and not self.merged: - result = nn.Embedding.forward(self, x) - if self.r > 0: - after_A = F.embedding(x, self.lora_A.T, self.padding_idx, - self.max_norm, self.norm_type, - self.scale_grad_by_freq, self.sparse) - result += (after_A @ self.lora_B.T) * self.scaling - return result - else: - return nn.Embedding.forward(self, x) - - -class Linear(nn.Linear, LoRALayer): - # LoRA implemented in a dense layer - def __init__( - self, - in_features: int, - out_features: int, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0., - fan_in_fan_out: bool = False, - # Set this to True if the layer to replace stores weight like (fan_in, fan_out) - merge_weights: bool = True, - **kwargs): - nn.Linear.__init__(self, in_features, out_features, **kwargs) - LoRALayer.__init__( - self, - r=r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - merge_weights=merge_weights) - - self.fan_in_fan_out = fan_in_fan_out - # Actual trainable parameters - if r > 0: - self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) - self.lora_B = nn.Parameter( - self.weight.new_zeros((out_features, r))) - self.scaling = self.lora_alpha / self.r - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - self.reset_parameters() - if fan_in_fan_out: - self.weight.data = self.weight.data.T - - def reset_parameters(self): - nn.Linear.reset_parameters(self) - if hasattr(self, 'lora_A'): - # initialize A the same way as the default for nn.Linear and B to zero - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B) - - def train(self, mode: bool = True): - - def T(w): - return w.T if self.fan_in_fan_out else w - - nn.Linear.train(self, mode) - self.lora_A.requires_grad = mode - self.lora_B.requires_grad = mode - if mode and self.merge_weights and self.merged: - # Make sure that the weights are not merged - if self.r > 0: - self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling - self.merged = False - if not mode and self.merge_weights and not self.merged: - # Merge the weights and mark it - if self.r > 0: - self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling - self.merged = True - - def eval(self): - - def T(w): - return w.T if self.fan_in_fan_out else w - - nn.Linear.eval(self) - self.lora_A.requires_grad = False - self.lora_B.requires_grad = False - if self.merge_weights and not self.merged: - # Merge the weights and mark it - if self.r > 0: - self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling - self.merged = True - - def forward(self, x: torch.Tensor): - - def T(w): - return w.T if self.fan_in_fan_out else w - - if self.r > 0 and not self.merged: - result = F.linear(x, T(self.weight), bias=self.bias) - if self.r > 0: - result += (self.lora_dropout(x) @ self.lora_A.T - @ self.lora_B.T) * self.scaling - return result - else: - return F.linear(x, T(self.weight), bias=self.bias) - - -class MergedLinear(nn.Linear, LoRALayer): - # LoRA implemented in a dense layer - def __init__(self, - in_features: int, - out_features: int, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0., - enable_lora: List[bool] = [False], - fan_in_fan_out: 
bool = False, - merge_weights: bool = True, - **kwargs): - nn.Linear.__init__(self, in_features, out_features, **kwargs) - LoRALayer.__init__( - self, - r=r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - merge_weights=merge_weights) - assert out_features % len(enable_lora) == 0, \ - 'The length of enable_lora must divide out_features' - self.enable_lora = enable_lora - self.fan_in_fan_out = fan_in_fan_out - # Actual trainable parameters - if r > 0 and any(enable_lora): - self.lora_A = nn.Parameter( - self.weight.new_zeros((r * sum(enable_lora), in_features))) - self.lora_B = nn.Parameter( - self.weight.new_zeros( - (out_features // len(enable_lora) * sum(enable_lora), - r))) # weights for Conv1D with groups=sum(enable_lora) - self.scaling = self.lora_alpha / self.r - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - # Compute the indices - self.lora_ind = self.weight.new_zeros( - (out_features, ), dtype=torch.bool).view(len(enable_lora), -1) - self.lora_ind[enable_lora, :] = True - self.lora_ind = self.lora_ind.view(-1) - self.reset_parameters() - if fan_in_fan_out: - self.weight.data = self.weight.data.T - - def reset_parameters(self): - nn.Linear.reset_parameters(self) - if hasattr(self, 'lora_A'): - # initialize A the same way as the default for nn.Linear and B to zero - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B) - - def zero_pad(self, x): - result = x.new_zeros((*x.shape[:-1], self.out_features)) - result = result.view(-1, self.out_features) - result[:, self.lora_ind] = x.reshape( - -1, - self.out_features // len(self.enable_lora) * sum(self.enable_lora)) - return result.view((*x.shape[:-1], self.out_features)) - - def train(self, mode: bool = True): - - def T(w): - return w.T if self.fan_in_fan_out else w - - nn.Linear.train(self, mode) - self.lora_A.requires_grad = mode - self.lora_B.requires_grad = mode - if mode and self.merge_weights and self.merged: - # Make sure that the weights are not merged - if self.r > 0 and any(self.enable_lora): - delta_w = F.conv1d( - self.lora_A.data.unsqueeze(0), - self.lora_B.data.unsqueeze(-1), - groups=sum(self.enable_lora)).squeeze(0) - self.weight.data -= self.zero_pad(T(delta_w * self.scaling)) - self.merged = False - if not mode and self.merge_weights and not self.merged: - if self.r > 0 and any(self.enable_lora): - delta_w = F.conv1d( - self.lora_A.data.unsqueeze(0), - self.lora_B.data.unsqueeze(-1), - groups=sum(self.enable_lora)).squeeze(0) - self.weight.data += self.zero_pad(T(delta_w * self.scaling)) - self.merged = True - - def eval(self): - - def T(w): - return w.T if self.fan_in_fan_out else w - - nn.Linear.eval(self) - self.lora_A.requires_grad = False - self.lora_B.requires_grad = False - if self.merge_weights and not self.merged: - # Merge the weights and mark it - if self.r > 0 and any(self.enable_lora): - delta_w = F.conv1d( - self.lora_A.data.unsqueeze(0), - self.lora_B.data.unsqueeze(-1), - groups=sum(self.enable_lora)).squeeze(0) - self.weight.data += self.zero_pad(T(delta_w * self.scaling)) - self.merged = True - - def forward(self, x: torch.Tensor): - - def T(w): - return w.T if self.fan_in_fan_out else w - - if self.merged: - return F.linear(x, T(self.weight), bias=self.bias) - else: - result = F.linear(x, T(self.weight), bias=self.bias) - if self.r > 0: - after_A = F.linear(self.lora_dropout(x), self.lora_A) - after_B = F.conv1d( - after_A.transpose(-2, -1), - self.lora_B.unsqueeze(-1), - groups=sum(self.enable_lora)).transpose(-2, -1) - result += 
self.zero_pad(after_B) * self.scaling - return result - - -class Conv2d(nn.Conv2d, LoRALayer): - # LoRA implemented in a dense layer - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: int, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0., - merge_weights: bool = True, - **kwargs): - nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, - **kwargs) - LoRALayer.__init__( - self, - r=r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - merge_weights=merge_weights) - assert type(kernel_size) is int - # Actual trainable parameters - if r > 0: - self.lora_A = nn.Parameter( - self.weight.new_zeros( - (r * kernel_size, in_channels * kernel_size))) - self.lora_B = nn.Parameter( - self.weight.new_zeros( - (out_channels * kernel_size, r * kernel_size))) - self.scaling = self.lora_alpha / self.r - # Freezing the pre-trained weight matrix - self.weight.requires_grad = False - self.reset_parameters() - - def reset_parameters(self): - nn.Conv2d.reset_parameters(self) - if hasattr(self, 'lora_A'): - # initialize A the same way as the default for nn.Linear and B to zero - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B) - - def train(self, mode: bool = True): - nn.Conv2d.train(self, mode) - self.lora_A.requires_grad = mode - self.lora_B.requires_grad = mode - if mode and self.merge_weights and self.merged: - # Make sure that the weights are not merged - self.weight.data -= (self.lora_B @ self.lora_A).view( - self.weight.shape) * self.scaling - self.merged = False - if not mode and self.merge_weights and not self.merged: - self.weight.data += (self.lora_B @ self.lora_A).view( - self.weight.shape) * self.scaling - self.merged = True - - def eval(self): - nn.Conv2d.eval(self) - self.lora_A.requires_grad = False - self.lora_B.requires_grad = False - if self.merge_weights and not self.merged: - # Merge the weights and mark it - self.weight.data += (self.lora_B @ self.lora_A).view( - self.weight.shape) * self.scaling - self.merged = True - - def forward(self, x: torch.Tensor): - if self.r > 0 and not self.merged: - return F.conv2d( - x, - self.weight + # noqa - (self.lora_B @ self.lora_A).view(self.weight.shape) # noqa - * self.scaling, - self.bias, - self.stride, - self.padding, - self.dilation, - self.groups) - return nn.Conv2d.forward(self, x) - - -def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None: - for n, p in model.named_parameters(): - if 'lora_' not in n: - p.requires_grad = False - if bias == 'none': - return - elif bias == 'all': - for n, p in model.named_parameters(): - if 'bias' in n: - p.requires_grad = True - elif bias == 'lora_only': - for m in model.modules(): - if isinstance(m, LoRALayer) and \ - hasattr(m, 'bias') and \ - m.bias is not None: - m.bias.requires_grad = True - else: - raise NotImplementedError - - -def lora_state_dict(state_dict, bias: str = 'none') -> Dict[str, torch.Tensor]: - if bias == 'none': - return {k: state_dict[k] for k in state_dict if 'lora_' in k} - elif bias == 'all': - return { - k: state_dict[k] - for k in state_dict if 'lora_' in k or 'bias' in k - } - elif bias == 'lora_only': - to_return = {} - for k in state_dict: - if 'lora_' in k: - to_return[k] = state_dict[k] - bias_name = k.split('lora_')[0] + 'bias' - if bias_name in state_dict: - to_return[bias_name] = state_dict[bias_name] - return to_return - else: - raise NotImplementedError diff --git a/modelscope/swift/optimizers/__init__.py b/modelscope/swift/optimizers/__init__.py deleted file 
mode 100644 index e69de29b..00000000 diff --git a/modelscope/swift/prompt.py b/modelscope/swift/prompt.py deleted file mode 100644 index 9fa207be..00000000 --- a/modelscope/swift/prompt.py +++ /dev/null @@ -1,241 +0,0 @@ -import os -import re -import types -from dataclasses import dataclass, field -from typing import Union - -import torch -from torch import nn - -from modelscope import snapshot_download -from modelscope.utils.constant import ModelFile -from .base import SwiftConfig - - -@dataclass -class PromptConfig(SwiftConfig): - """ - The configuration class for the prompt module. - - Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens - and prepend to the original tokens in the first layer or multiple layers. - 'Visual Prompt Tuning' by Jia et al.(2022) - See https://arxiv.org/abs/2203.12119 - - Here we apply the VPT to other fields. - - Args: - dim: The dimension of the hidden states - module_layer_name: The layer module to be replaced, in regex format - embedding_pos: The position of the embedding tensor - attention_mask_pos: The position of the attention mask - attention_mask_value: The value to pad to the attention mask - prompt_length: The length of the prompt tokens - only_prompt_trainable: Whether to train only prompt - attach_front: When set to True, prompt is attached in front of the embedding - extract_embedding: Whether the embedding is extracted at final stage to keep the same dims with inputs - pretrained_weights: The pretrained prompt weights. Can be a local dir, local file, - or a model id from modelscope - """ - - dim: int = field(metadata={'help': 'The dimension of the hidden states'}) - - module_layer_name: str = field( - metadata={'help': 'The layer module to be replaced, in regex format'}) - - embedding_pos: Union[str, int] = field( - metadata={'help': 'The position of the embedding tensor'}) - - attention_mask_pos: Union[str, int] = field( - default=None, metadata={'help': 'The position of the attention mask'}) - - attention_mask_value: Union[float, int, bool] = field( - default=0., - metadata={'help': 'The value to pad to the attention mask'}) - - prompt_length: int = field( - default=16, metadata={'help': 'The length of the prompt tokens'}) - - only_prompt_trainable: bool = field( - default=True, metadata={'help': 'Whether to train only prompt'}) - - attach_front: bool = field( - default=True, - metadata={ - 'help': - 'When set to True, prompt is attached in front of the embedding' - }) - - extract_embedding: bool = field( - default=False, - metadata={ - 'help': - 'Whether the embedding is extracted at final stage to keep the same dims with inputs' - }) - - pretrained_weights: str = field( - default=None, - metadata={ - 'help': - 'The pretrained prompt weights. 
Can be a local dir, local file, or a model id from modelscope' - }) - - -class Prompt: - - @staticmethod - def prepare_model(model: nn.Module, config: PromptConfig): - module_keys = [key for key, _ in model.named_modules()] - match_module_keys = [] - for module_key in module_keys: - if re.fullmatch(config.module_layer_name, module_key): # noqa - module = model.get_submodule(module_key) - - def _forward(self, *args, **kwargs): - if isinstance(config.embedding_pos, int): - input_embedding = args[config.embedding_pos] - else: - input_embedding = kwargs[config.embedding_pos] - - input_embedding = getattr( - self, 'prompt').forward(input_embedding) - if isinstance(config.embedding_pos, int): - args = type(args)( - args[0:config.embedding_pos] + (input_embedding, ) - + args[config.embedding_pos + 1:]) - else: - kwargs[config.embedding_pos] = input_embedding - - if config.attention_mask_pos: - attention_mask = None - if isinstance(config.attention_mask_pos, int): - attention_mask = args[config.attention_mask_pos] - elif isinstance(config.attention_mask_pos, str): - attention_mask = kwargs[config.attention_mask_pos] - - if attention_mask is not None: - attention_mask = getattr( - self, - 'prompt').patch_attention_mask(attention_mask) - if isinstance(config.attention_mask_pos, int): - args = type(args)( - args[0:config.attention_mask_pos] - + (attention_mask, ) - + args[config.attention_mask_pos + 1:]) - else: - kwargs[config.attention_mask_pos] = attention_mask - - forward_output = self.forward_origin(*args, **kwargs) - if config.extract_embedding: - forward_output = getattr( - self, 'prompt').extract(forward_output) - - return forward_output - - module.forward_origin = module.forward - module.forward = types.MethodType(_forward, module) - - if isinstance(config.dim, list): - input_dim = config.dim[len(match_module_keys)] - else: - input_dim = config.dim - - prompt_module = PromptModule(input_dim, - int(module_key.rsplit('.')[-1]), - config.prompt_length, - config.attention_mask_value, - config.attach_front) - setattr(module, 'prompt', prompt_module) - match_module_keys.append(module_key) - - if config.only_prompt_trainable: - for n, p in model.named_parameters(): - if 'prompt' not in n: - p.requires_grad = False - - def state_dict_hook(module, destination, prefix, local_metadata): - return { - key: value - for key, value in destination.items() if 'prompt' in key - } - - model.state_dict_hook_handle = model._register_state_dict_hook( - state_dict_hook) - - def load_state_dict(self, state_dict, strict=True): - return self.load_state_dict_origin(state_dict, False) - - model.load_state_dict_origin = model.load_state_dict - model.load_state_dict = types.MethodType(load_state_dict, model) - - if config.pretrained_weights is not None: - if not os.path.exists(config.pretrained_weights): - model_dir = snapshot_download(config.pretrained_weights) - pretrained_weights = os.path.join( - model_dir, ModelFile.TORCH_MODEL_BIN_FILE) - elif os.path.isfile(config.pretrained_weights): - pretrained_weights = config.pretrained_weights - else: - pretrained_weights = os.path.join( - config.pretrained_weights, ModelFile.TORCH_MODEL_BIN_FILE) - model.load_state_dict(torch.load(pretrained_weights)) - return model - - -class PromptModule(nn.Module): - """The implementation of vision prompt tuning method. - - Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens - and prepend to the original tokens in the first layer or multiple layers. 
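PromptModule below realizes visual prompt tuning by prepending learned prompt tokens to the embedded sequence. A standalone restatement of the first-layer, attach-front case for reference:

import torch

def prepend_prompt(x, prompt_token):
    # x: (batch, seq_len, dim); prompt_token: (1, prompt_length, dim), learned.
    # Mirrors PromptModule.forward at layer 0 with attach_front=True.
    prompt = prompt_token.expand(x.shape[0], -1, -1)
    return torch.cat((prompt, x), dim=1)
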
- 'Visual Prompt Tuning' by Jia et al.(2022) - See https://arxiv.org/abs/2203.12119 - - Attributes: - dim: An integer indicating the embedding dimension. - layer_num: An integer indicating number of layers. - prompt_length: An integer indicating the length of vision prompt tuning. - """ - - def __init__(self, - dim, - layer_num, - prompt_length=None, - mask_values=0., - attach_front=True): - super(PromptModule, self).__init__() - self.dim = dim - self.layer_num = layer_num - self.prompt_length = prompt_length - self.mask_values = mask_values - self.attach_front = attach_front - - self.prompt_token = nn.Parameter(torch.zeros(1, prompt_length, dim)) - nn.init.xavier_uniform_(self.prompt_token) - - def forward(self, x): - prompt_token = self.prompt_token.expand(x.shape[0], -1, -1) - - if self.layer_num == 0: - if self.attach_front: - x = torch.cat((prompt_token, x), dim=1) - else: - x = torch.cat((x, prompt_token), dim=1) - else: - if self.attach_front: - x = torch.cat((prompt_token, x[:, self.prompt_length:, :]), - dim=1) - else: - x = torch.cat((x[:, :-self.prompt_length, :], prompt_token), - dim=1) - return x - - def patch_attention_mask(self, m): - prefix_attention_mask = torch.full((*m.shape[:-1], self.prompt_length), - self.mask_values).to(m.device) - return torch.cat((prefix_attention_mask, m), dim=-1) - - def extract(self, x): - if self.attach_front: - return x[:, self.prompt_length:, :] - else: - return x[:, :-self.prompt_length, :] diff --git a/modelscope/trainers/cv/vision_efficient_tuning_trainer.py b/modelscope/trainers/cv/vision_efficient_tuning_trainer.py index 4c7dca73..ac32cc89 100644 --- a/modelscope/trainers/cv/vision_efficient_tuning_trainer.py +++ b/modelscope/trainers/cv/vision_efficient_tuning_trainer.py @@ -6,9 +6,7 @@ from torch import nn from modelscope.metainfo import Trainers from modelscope.models.base import Model, TorchModel from modelscope.trainers.builder import TRAINERS -from modelscope.trainers.default_config import merge_hooks from modelscope.trainers.trainer import EpochBasedTrainer -from modelscope.utils.constant import ModeKeys @TRAINERS.register_module(module_name=Trainers.vision_efficient_tuning) diff --git a/modelscope/trainers/nlp/siamese_uie_trainer.py b/modelscope/trainers/nlp/siamese_uie_trainer.py index e3289976..782fd360 100644 --- a/modelscope/trainers/nlp/siamese_uie_trainer.py +++ b/modelscope/trainers/nlp/siamese_uie_trainer.py @@ -328,7 +328,8 @@ class SiameseUIETrainer(EpochBasedTrainer): Example: {"accuracy": 0.5091743119266054, "f1": 0.673780487804878} """ - pipeline_uie = pipeline(Tasks.siamese_uie, self.model) + pipeline_uie = pipeline( + Tasks.siamese_uie, self.model, device=self.device) if checkpoint_path is not None and os.path.isfile(checkpoint_path): from modelscope.trainers.hooks import LoadCheckpointHook LoadCheckpointHook.load_checkpoint(checkpoint_path, self) diff --git a/modelscope/trainers/optimizer/__init__.py b/modelscope/trainers/optimizer/__init__.py index cd59c072..9962c2c2 100644 --- a/modelscope/trainers/optimizer/__init__.py +++ b/modelscope/trainers/optimizer/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
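ChildTuningAdamW itself is unchanged by this patch; only its module moves from modelscope.swift.optimizers into the optimizer package, and the re-export below keeps the public path stable. The imports that remain valid afterwards:

from modelscope.trainers.optimizer import ChildTuningAdamW  # still listed in __all__
# The concrete module path after the move (the old modelscope.swift path is gone):
from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import ChildTuningAdamW
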
-from modelscope.swift import ChildTuningAdamW from .builder import OPTIMIZERS, build_optimizer +from .child_tuning_adamw_optimizer import ChildTuningAdamW __all__ = ['OPTIMIZERS', 'build_optimizer', 'ChildTuningAdamW'] diff --git a/modelscope/swift/optimizers/child_tuning_adamw_optimizer.py b/modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py similarity index 100% rename from modelscope/swift/optimizers/child_tuning_adamw_optimizer.py rename to modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index f618ffc7..65c238da 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -28,7 +28,6 @@ from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \ from modelscope.msdatasets.ms_dataset import MsDataset from modelscope.outputs import ModelOutputBase from modelscope.preprocessors.base import Preprocessor -from modelscope.swift import Swift from modelscope.trainers.hooks.builder import HOOKS from modelscope.trainers.hooks.priority import Priority, get_priority from modelscope.trainers.lrscheduler.builder import build_lr_scheduler @@ -41,6 +40,7 @@ from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, from modelscope.utils.data_utils import to_device from modelscope.utils.device import create_device from modelscope.utils.file_utils import func_receive_dict_inputs +from modelscope.utils.import_utils import is_swift_available from modelscope.utils.logger import get_logger from modelscope.utils.registry import build_from_cfg from modelscope.utils.torch_utils import (compile_model, get_dist_info, @@ -54,6 +54,8 @@ from .hooks.hook import Hook from .parallel.builder import build_parallel from .parallel.utils import is_parallel +TunerConfig = Union['swift.SwiftConfig', 'swift.PeftConfig'] + @TRAINERS.register_module(module_name=Trainers.default) class EpochBasedTrainer(BaseTrainer): @@ -118,7 +120,8 @@ class EpochBasedTrainer(BaseTrainer): seed: int = 42, callbacks: Optional[List[Hook]] = None, samplers: Optional[Union[Sampler, Dict[str, Sampler]]] = None, - efficient_tuners: List[Dict] = None, + efficient_tuners: Union[Dict[str, TunerConfig], + TunerConfig] = None, **kwargs): self._seed = seed @@ -270,8 +273,12 @@ class EpochBasedTrainer(BaseTrainer): def tune_module(self, efficient_tuners): if efficient_tuners is not None: - for tuner in efficient_tuners: - self.model = Swift.prepare_model(self.model, tuner) + if not is_swift_available(): + raise ValueError( + 'Please install swift by `pip install ms-swift` to use efficient_tuners.' + ) + from swift import Swift + self.model = Swift.prepare_model(self.model, efficient_tuners) def place_model(self): """Place model to device, or to DDP diff --git a/modelscope/tuners/sd_lora.py b/modelscope/tuners/sd_lora.py deleted file mode 100644 index feff05f4..00000000 --- a/modelscope/tuners/sd_lora.py +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright 2023-2024 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
-# The implementation is adopted from HighCWu, -# made pubicly available under the Apache License 2.0 License at https://github.com/HighCWu/ControlLoRA -import os -from dataclasses import dataclass -from typing import List, Tuple, Union - -import torch -import torch.nn as nn -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.cross_attention import CrossAttention, LoRALinearLayer -from diffusers.models.modeling_utils import ModelMixin -from diffusers.utils.outputs import BaseOutput - - -@dataclass -class TunerOutput(BaseOutput): - lora_states: Tuple[torch.FloatTensor] - - -class LoRACrossAttnProcessor(nn.Module): - """ The implementation of lora attention module. - """ - - def __init__(self, - hidden_size, - cross_attention_dim=None, - rank=4, - post_add=False, - key_states_skipped=False, - value_states_skipped=False, - output_states_skipped=False): - """ Initialize a lora attn instance. - Args: - hidden_size (`int`): The number of channels in embedding. - cross_attention_dim (`int`, *optional*): - The number of channels in the hidden_states. If not given, defaults to `hidden_size`. - rank (`int`, *optional*, defaults to 4): The number of rank of lora. - post_add (`bool`, *optional*, defaults to False): Set to `True`, conduct weighted - adding operation after lora. - key_states_skipped (`bool`, *optional*, defaults to False): - Set to `True` for skip to perform lora on key value. - value_states_skipped (`bool`, *optional*, defaults to False): - Set to `True` for skip to perform lora on value. - output_states_skipped (`bool`, *optional*, defaults to False): - Set to `True` for skip to perform lora on output value. - """ - super().__init__() - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.rank = rank - self.post_add = post_add - - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - if not key_states_skipped: - self.to_k_lora = LoRALinearLayer( - hidden_size if post_add else - (cross_attention_dim or hidden_size), hidden_size, rank) - if not value_states_skipped: - self.to_v_lora = LoRALinearLayer( - hidden_size if post_add else - (cross_attention_dim or hidden_size), hidden_size, rank) - if not output_states_skipped: - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - - self.key_states_skipped: bool = key_states_skipped - self.value_states_skipped: bool = value_states_skipped - self.output_states_skipped: bool = output_states_skipped - - def skip_key_states(self, is_skipped: bool = True): - if not is_skipped: - assert hasattr(self, 'to_k_lora') - self.key_states_skipped = is_skipped - - def skip_value_states(self, is_skipped: bool = True): - if not is_skipped: - assert hasattr(self, 'to_q_lora') - self.value_states_skipped = is_skipped - - def skip_output_states(self, is_skipped: bool = True): - if not is_skipped: - assert hasattr(self, 'to_out_lora') - self.output_states_skipped = is_skipped - - def __call__(self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - scale=1.0): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask( - attention_mask=attention_mask, - target_length=sequence_length, - batch_size=batch_size) - - query = attn.to_q(hidden_states) - query = query + scale * self.to_q_lora( - query if self.post_add else hidden_states) - query = attn.head_to_batch_dim(query) - - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else 
hidden_states - - key = attn.to_k(encoder_hidden_states) - if not self.key_states_skipped: - key = key + scale * self.to_k_lora( - key if self.post_add else encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - if not self.value_states_skipped: - value = value + scale * self.to_v_lora( - value if self.post_add else encoder_hidden_states) - - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - out = attn.to_out[0](hidden_states) - if not self.output_states_skipped: - out = out + scale * self.to_out_lora( - out if self.post_add else hidden_states) - hidden_states = out - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class LoRATuner(ModelMixin, ConfigMixin): - - @staticmethod - def tune( - model: nn.Module, - tuner_config=None, - pretrained_tuner=None, - ): - tuner = LoRATuner.from_config(tuner_config) - if pretrained_tuner is not None and os.path.exists(pretrained_tuner): - tuner.load_state_dict( - torch.load(pretrained_tuner, map_location='cpu'), strict=True) - tune_layers_list = list( - [list(layer_list) for layer_list in tuner.lora_layers]) - assert hasattr(model, 'unet') - unet = model.unet - tuner.to(unet.device) - tune_attn_procs = tuner.set_tune_layers(unet, tune_layers_list) - unet.set_attn_processor(tune_attn_procs) - return tuner - - def set_tune_layers(self, unet, tune_layers_list): - n_ch = len(unet.config.block_out_channels) - control_ids = [i for i in range(n_ch)] - tune_attn_procs = {} - - for name in unet.attn_processors.keys(): - if name.startswith('mid_block'): - control_id = control_ids[-1] - elif name.startswith('up_blocks'): - block_id = int(name[len('up_blocks.')]) - control_id = list(reversed(control_ids))[block_id] - elif name.startswith('down_blocks'): - block_id = int(name[len('down_blocks.')]) - control_id = control_ids[block_id] - - tune_layers = tune_layers_list[control_id] - if len(tune_layers) != 0: - tune_layer = tune_layers.pop(0) - tune_attn_procs[name] = tune_layer - return tune_attn_procs - - @register_to_config - def __init__( - self, - lora_block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - lora_cross_attention_dims: Tuple[List[int]] = ([ - None, 768, None, 768, None, 768, None, 768, None, 768 - ], [None, 768, None, 768, None, 768, None, 768, None, - 768], [None, 768, None, 768, None, 768, None, 768, None, - 768], [None, 768]), - lora_rank: int = 4, - lora_post_add: bool = False, - lora_key_states_skipped: bool = False, - lora_value_states_skipped: bool = False, - lora_output_states_skipped: bool = False, - ): - super().__init__() - - lora_cls = LoRACrossAttnProcessor - - self.lora_layers = nn.ModuleList([]) - - for i, lora_cross_attention_dim in enumerate( - lora_cross_attention_dims): - self.lora_layers.append( - nn.ModuleList([ - lora_cls( - lora_block_out_channels[i], - cross_attention_dim=cross_attention_dim, - rank=lora_rank, - post_add=lora_post_add, - key_states_skipped=lora_key_states_skipped, - value_states_skipped=lora_value_states_skipped, - output_states_skipped=lora_output_states_skipped) - for cross_attention_dim in lora_cross_attention_dim - ])) - - def forward(self) -> Union[TunerOutput, Tuple]: - lora_states_list = [] - tune_layers_list = list( - [list(layer_list) for layer_list in self.lora_layers]) - for tune_list in tune_layers_list: - for 
tune_layer in tune_list: - lora_states_list.append(tune_layer.to_q_lora.down.weight) - return TunerOutput(lora_states=tuple(lora_states_list)) diff --git a/modelscope/utils/import_utils.py b/modelscope/utils/import_utils.py index 88f256a7..2ce9d55d 100644 --- a/modelscope/utils/import_utils.py +++ b/modelscope/utils/import_utils.py @@ -246,6 +246,10 @@ def is_wenetruntime_available(): return importlib.util.find_spec('wenetruntime') is not None +def is_swift_available(): + return importlib.util.find_spec('swift') is not None + + def is_tf_available(): return _tf_available diff --git a/requirements/framework.txt b/requirements/framework.txt index 83e69a00..e9dc08c4 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -4,6 +4,7 @@ datasets>=2.8.0,<=2.13.0 einops filelock>=3.3.0 gast>=0.2.2 +ms-swift numpy oss2 pandas diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py index ef0ccae7..f9a6ec66 100644 --- a/tests/trainers/test_finetune_sequence_classification.py +++ b/tests/trainers/test_finetune_sequence_classification.py @@ -8,12 +8,12 @@ from modelscope.metainfo import Preprocessors, Trainers from modelscope.models import Model from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline -from modelscope.swift.optimizers.child_tuning_adamw_optimizer import \ - calculate_fisher from modelscope.trainers import build_trainer from modelscope.trainers.hooks import Hook from modelscope.trainers.nlp_trainer import (EpochBasedTrainer, NlpEpochBasedTrainer) +from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \ + calculate_fisher from modelscope.trainers.training_args import TrainingArgs from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.data_utils import to_device diff --git a/tests/trainers/test_finetune_vision_efficient_tuning_swift.py b/tests/trainers/test_finetune_vision_efficient_tuning_swift.py index d8733024..56a5b6fc 100644 --- a/tests/trainers/test_finetune_vision_efficient_tuning_swift.py +++ b/tests/trainers/test_finetune_vision_efficient_tuning_swift.py @@ -6,11 +6,8 @@ import unittest from modelscope.metainfo import Trainers from modelscope.msdatasets import MsDataset -from modelscope.swift import Swift -from modelscope.swift.adapter import AdapterConfig -from modelscope.swift.lora import LoRAConfig -from modelscope.swift.prompt import PromptConfig from modelscope.trainers import build_trainer +from modelscope.utils.import_utils import is_swift_available from modelscope.utils.test_utils import test_level @@ -43,8 +40,10 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase): shutil.rmtree(self.tmp_dir) super().tearDown() - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0 and is_swift_available(), + 'skip test in current test level') def test_vision_efficient_tuning_swift_lora_train(self): + from swift import LoRAConfig model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora' def cfg_modify_fn(cfg): @@ -56,10 +55,9 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase): return cfg lora_config = LoRAConfig( - rank=self.tune_length, - replace_modules=['qkv'], + r=self.tune_length, + target_modules=['qkv'], merge_weights=False, - only_lora_trainable=False, use_merged_linear=True, enable_lora=[True]) @@ -69,7 +67,7 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase): train_dataset=self.train_dataset, 
eval_dataset=self.eval_dataset, cfg_modify_fn=cfg_modify_fn, - efficient_tuners=[lora_config]) + efficient_tuners=lora_config) trainer = build_trainer( name=Trainers.vision_efficient_tuning, default_args=kwargs) @@ -82,8 +80,10 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase): for i in range(self.max_epochs): self.assertIn(f'epoch_{i+1}.pth', results_files) - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0 and is_swift_available(), + 'skip test in current test level') def test_vision_efficient_tuning_swift_adapter_train(self): + from swift import AdapterConfig model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter' def cfg_modify_fn(cfg): @@ -97,9 +97,8 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase): adapter_config = AdapterConfig( dim=768, hidden_pos=0, - module_name=r'.*blocks\.\d+\.mlp$', - adapter_length=self.tune_length, - only_adapter_trainable=False) + target_modules=r'.*blocks\.\d+\.mlp$', + adapter_length=self.tune_length) kwargs = dict( model=model_id, @@ -107,7 +106,7 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase): train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, cfg_modify_fn=cfg_modify_fn, - efficient_tuners=[adapter_config]) + efficient_tuners=adapter_config) trainer = build_trainer( name=Trainers.vision_efficient_tuning, default_args=kwargs) @@ -120,8 +119,10 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase): for i in range(self.max_epochs): self.assertIn(f'epoch_{i+1}.pth', results_files) - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0 and is_swift_available(), + 'skip test in current test level') def test_vision_efficient_tuning_swift_prompt_train(self): + from swift import PromptConfig model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt' def cfg_modify_fn(cfg): @@ -134,10 +135,9 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase): prompt_config = PromptConfig( dim=768, - module_layer_name=r'.*blocks\.\d+$', + target_modules=r'.*blocks\.\d+$', embedding_pos=0, prompt_length=self.tune_length, - only_prompt_trainable=False, attach_front=False) kwargs = dict( @@ -146,7 +146,7 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase): train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, cfg_modify_fn=cfg_modify_fn, - efficient_tuners=[prompt_config]) + efficient_tuners=prompt_config) trainer = build_trainer( name=Trainers.vision_efficient_tuning, default_args=kwargs) diff --git a/tests/tuners/__init__.py b/tests/tuners/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/tuners/test_adapter.py b/tests/tuners/test_adapter.py deleted file mode 100644 index a110591a..00000000 --- a/tests/tuners/test_adapter.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
-import os -import shutil -import tempfile -import unittest - -import numpy as np -import torch - -from modelscope import read_config -from modelscope.hub.snapshot_download import snapshot_download -from modelscope.models.base import Model -from modelscope.msdatasets import MsDataset -from modelscope.pipelines import pipeline -from modelscope.swift import Swift -from modelscope.swift.adapter import AdapterConfig -from modelscope.trainers import build_trainer -from modelscope.utils.constant import ModelFile, Tasks -from modelscope.utils.test_utils import test_level - - -class TestAdapter(unittest.TestCase): - - def setUp(self): - print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) - self.tmp_dir = tempfile.TemporaryDirectory().name - if not os.path.exists(self.tmp_dir): - os.makedirs(self.tmp_dir) - - def tearDown(self): - shutil.rmtree(self.tmp_dir) - super().tearDown() - - @unittest.skipUnless(test_level() >= 0, 'skip in this level') - def test_adapter_smoke_test(self): - dataset = MsDataset.load( - 'clue', subset_name='afqmc', - split='train').to_hf_dataset().select(range(2)) - - model_dir = snapshot_download( - 'damo/nlp_structbert_sentence-similarity_chinese-tiny') - model = Model.from_pretrained(model_dir, adv_grad_factor=None) - - cfg_file = os.path.join(model_dir, 'configuration.json') - - model_cfg = os.path.join(model_dir, 'config.json') - model_cfg = read_config(model_cfg) - - adapter_config = AdapterConfig( - dim=model_cfg.hidden_size, - module_name=r'.*layer\.\d+$', - method_name='feed_forward_chunk', - hidden_pos=0) - model = Swift.prepare_model(model, adapter_config) - kwargs = dict( - model=model, - cfg_file=cfg_file, - train_dataset=dataset, - eval_dataset=dataset, - work_dir=self.tmp_dir) - - trainer = build_trainer(default_args=kwargs) - trainer.train() - output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) - - def pipeline_sentence_similarity(model_dir): - model = Model.from_pretrained(model_dir) - adapter_config.pretrained_weights = output_dir - Swift.prepare_model(model, adapter_config) - model.eval() - pipeline_ins = pipeline( - task=Tasks.sentence_similarity, model=model) - return pipeline_ins(input=('test', 'this is a test')) - - output1 = pipeline_sentence_similarity( - 'damo/nlp_structbert_sentence-similarity_chinese-tiny') - print(output1) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/tuners/test_lora.py b/tests/tuners/test_lora.py deleted file mode 100644 index b3238dad..00000000 --- a/tests/tuners/test_lora.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
-import os -import shutil -import tempfile -import unittest - -import numpy as np -import torch - -from modelscope.hub.snapshot_download import snapshot_download -from modelscope.models.base import Model -from modelscope.msdatasets import MsDataset -from modelscope.pipelines import pipeline -from modelscope.swift import Swift -from modelscope.swift.lora import (Linear, LoRA, LoRAConfig, - mark_only_lora_as_trainable) -from modelscope.trainers import build_trainer -from modelscope.utils.constant import ModelFile, Tasks -from modelscope.utils.test_utils import test_level - - -class TestLora(unittest.TestCase): - - def setUp(self): - print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) - self.tmp_dir = tempfile.TemporaryDirectory().name - if not os.path.exists(self.tmp_dir): - os.makedirs(self.tmp_dir) - - def tearDown(self): - shutil.rmtree(self.tmp_dir) - super().tearDown() - - @unittest.skipUnless(test_level() >= 0, 'skip in this level') - def test_lora_base(self): - - class TestModel(torch.nn.Module): - - def __init__(self): - super().__init__() - self.lora = Linear(16, 16, r=4) - - model = TestModel() - mark_only_lora_as_trainable(model) - model.train() - loss = model.lora(torch.ones(16, 16)) - loss = loss.sum() - loss.backward() - - model = TestModel() - mark_only_lora_as_trainable(model) - model.eval() - loss = model.lora(torch.ones(16, 16)) - loss = loss.sum() - try: - loss.backward() - except Exception: - pass - else: - raise Exception('No tensor needs grad, should throw en error here') - - @unittest.skipUnless(test_level() >= 0, 'skip in this level') - def test_lora_smoke_test(self): - dataset = MsDataset.load( - 'clue', subset_name='afqmc', - split='train').to_hf_dataset().select(range(2)) - - model_dir = snapshot_download( - 'damo/nlp_structbert_sentence-similarity_chinese-tiny') - model = Model.from_pretrained(model_dir, adv_grad_factor=None) - - cfg_file = os.path.join(model_dir, 'configuration.json') - lora_config = LoRAConfig(replace_modules=['query', 'key', 'value']) - model = Swift.prepare_model(model, lora_config) - - kwargs = dict( - model=model, - cfg_file=cfg_file, - train_dataset=dataset, - eval_dataset=dataset, - work_dir=self.tmp_dir) - - trainer = build_trainer(default_args=kwargs) - trainer.train() - output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) - - def pipeline_sentence_similarity(model_dir): - model = Model.from_pretrained(model_dir) - lora_config.pretrained_weights = output_dir - Swift.prepare_model(model, lora_config) - model.load_state_dict( - torch.load(os.path.join(output_dir, 'pytorch_model.bin'))) - model.eval() - pipeline_ins = pipeline( - task=Tasks.sentence_similarity, model=model) - return pipeline_ins(input=('test', 'this is a test')) - - output1 = pipeline_sentence_similarity( - 'damo/nlp_structbert_sentence-similarity_chinese-tiny') - - LoRA.unpatch_lora(model, lora_config) - model.save_pretrained( - output_dir, save_checkpoint_names='pytorch_model.bin') - - def pipeline_sentence_similarity_origin(): - model = Model.from_pretrained(output_dir) - model.eval() - pipeline_ins = pipeline( - task=Tasks.sentence_similarity, model=model) - return pipeline_ins(input=('test', 'this is a test')) - - output2 = pipeline_sentence_similarity_origin() - print(output1, output2) - self.assertTrue(all(np.isclose(output1['scores'], output2['scores']))) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/tuners/test_prompt.py b/tests/tuners/test_prompt.py deleted file mode 100644 index c338162f..00000000 --- 
a/tests/tuners/test_prompt.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -import os -import shutil -import tempfile -import unittest - -import numpy as np -import torch - -from modelscope import read_config -from modelscope.hub.snapshot_download import snapshot_download -from modelscope.models.base import Model -from modelscope.msdatasets import MsDataset -from modelscope.pipelines import pipeline -from modelscope.swift import Swift -from modelscope.swift.adapter import AdapterConfig -from modelscope.swift.prompt import PromptConfig -from modelscope.trainers import build_trainer -from modelscope.utils.constant import ModelFile, Tasks -from modelscope.utils.test_utils import test_level - - -class TestPrompt(unittest.TestCase): - - def setUp(self): - print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) - self.tmp_dir = tempfile.TemporaryDirectory().name - if not os.path.exists(self.tmp_dir): - os.makedirs(self.tmp_dir) - - def tearDown(self): - shutil.rmtree(self.tmp_dir) - super().tearDown() - - @unittest.skipUnless(test_level() >= 0, 'skip in this level') - def test_prompt_smoke_test(self): - dataset = MsDataset.load( - 'clue', subset_name='afqmc', - split='train').to_hf_dataset().select(range(2)) - - model_dir = snapshot_download( - 'damo/nlp_structbert_sentence-similarity_chinese-tiny') - model = Model.from_pretrained(model_dir, adv_grad_factor=None) - - cfg_file = os.path.join(model_dir, 'configuration.json') - model_cfg = os.path.join(model_dir, 'config.json') - model_cfg = read_config(model_cfg) - - prompt_config = PromptConfig( - dim=model_cfg.hidden_size, - module_layer_name=r'.*layer\.\d+$', - embedding_pos=0, - attention_mask_pos=1) - - model = Swift.prepare_model(model, prompt_config) - - kwargs = dict( - model=model, - cfg_file=cfg_file, - train_dataset=dataset, - eval_dataset=dataset, - work_dir=self.tmp_dir) - - trainer = build_trainer(default_args=kwargs) - trainer.train() - output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) - - def pipeline_sentence_similarity(model_dir): - model = Model.from_pretrained(model_dir) - prompt_config.pretrained_weights = output_dir - Swift.prepare_model(model, prompt_config) - model.eval() - pipeline_ins = pipeline( - task=Tasks.sentence_similarity, model=model) - return pipeline_ins(input=('test', 'this is a test')) - - output1 = pipeline_sentence_similarity( - 'damo/nlp_structbert_sentence-similarity_chinese-tiny') - print(output1) - - -if __name__ == '__main__': - unittest.main()
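
A minimal sketch of the call pattern this patch migrates to, for reference only (a hedged example, not part of the patch): it assumes ms-swift is installed (`pip install ms-swift`), and it reuses the structbert-tiny model id and the query/key/value target modules from the deleted tests/tuners/test_lora.py purely as placeholders.

# Post-migration LoRA setup with the external swift package (illustrative sketch).
from swift import LoRAConfig, Swift

from modelscope.models import Model

# Placeholder model taken from the deleted tuner tests.
model = Model.from_pretrained(
    'damo/nlp_structbert_sentence-similarity_chinese-tiny')

# Renamed arguments: replace_modules -> target_modules, rank -> r.
lora_config = LoRAConfig(
    target_modules=['query', 'key', 'value'],
    r=32,
    lora_alpha=32,
    lora_dropout=0.05)

# Swift.prepare_model now returns the wrapped model, so the result must be
# reassigned rather than relying on in-place patching of the module.
model = Swift.prepare_model(model, lora_config)

# EpochBasedTrainer likewise now accepts a single tuner config (or a dict of
# configs) instead of a list, e.g.:
#   efficient_tuners=lora_config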