Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)

Commit: Replace code with swift wheel (#467)
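The change repeated across the hunks below is mechanical: the tuners now come from the standalone `swift` wheel instead of `modelscope.swift`, two `LoRAConfig` arguments are renamed (`replace_modules` -> `target_modules`, `rank` -> `r`), and the result of `Swift.prepare_model` is assigned back, since it returns the wrapped model. A minimal sketch of the new usage follows; the tiny `nn.Sequential` stand-in and the module name `'0'` are illustrative, not taken from the commit.

import torch.nn as nn
from swift import LoRAConfig, Swift

# Stand-in module; the examples touched by this commit wrap real pretrained LLMs instead.
model = nn.Sequential(nn.Linear(16, 16))

# Old API (removed by this commit):
#   from modelscope.swift import Swift
#   from modelscope.swift.lora import LoRAConfig
#   lora_config = LoRAConfig(replace_modules=['0'], rank=4, lora_alpha=32)
#   Swift.prepare_model(model, lora_config)      # patched the model in place
# New API: renamed arguments, and the wrapped model is returned.
lora_config = LoRAConfig(
    target_modules=['0'],  # was `replace_modules`
    r=4,                   # was `rank`
    lora_alpha=32,
    lora_dropout=0.05)
model = Swift.prepare_model(model, lora_config)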
@@ -3,14 +3,13 @@ import sys
 import types
 from dataclasses import dataclass, field

+from swift import LoRAConfig, Swift
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from modelscope import (EpochBasedTrainer, MsDataset, TorchModel, TrainingArgs,
                         build_dataset_from_file, snapshot_download)
 from modelscope.metainfo import Trainers
 from modelscope.preprocessors import TextGenerationTransformersPreprocessor
-from modelscope.swift import Swift
-from modelscope.swift.lora import LoRAConfig
 from modelscope.trainers import build_trainer

 DEFAULT_PAD_TOKEN = '[PAD]'
@@ -205,12 +204,12 @@ preprocessor = TextGenerationTransformersPreprocessor(

 if args.use_lora != 0:
     lora_config = LoRAConfig(
-        replace_modules=['pack'],
-        rank=args.lora_rank,
+        target_modules=['pack'],
+        r=args.lora_rank,
         lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout)
     model = model.bfloat16()
-    Swift.prepare_model(model, lora_config)
+    model = Swift.prepare_model(model, lora_config)

 kwargs = dict(
     model=model,
@@ -1,10 +1,9 @@
 import os.path as osp

 import torch
+from swift import LoRAConfig, Swift

 from modelscope.pipelines import pipeline
-from modelscope.swift import Swift
-from modelscope.swift.lora import LoRAConfig
 from modelscope.utils.constant import Tasks

 # Initialize the pipeline with the source model's model_id
@@ -12,11 +11,11 @@ model_id = 'baichuan-inc/baichuan-7B'
 pipe = pipeline(
     task=Tasks.text_generation, model=model_id, model_revision='v1.0.2')
 # LoRA config: replace_modules, rank and alpha must match the training settings
-lora_config = LoRAConfig(replace_modules=['pack'], rank=32, lora_alpha=32)
+lora_config = LoRAConfig(target_modules=['pack'], r=32, lora_alpha=32)
 # Cast to bf16, matching the training precision
 model = pipe.model.bfloat16()
 # Wrap the model with LoRA
-Swift.prepare_model(model, lora_config)
+model = Swift.prepare_model(model, lora_config)
 # Load the LoRA parameters; by default they are linked under the output/model path
 work_dir = './tmp'
 state_dict = torch.load(osp.join(work_dir, 'output/pytorch_model.bin'))
@@ -4,6 +4,7 @@ from dataclasses import dataclass, field
 import numpy as np
 import torch
 from chatglm_trainer import Seq2SeqTrainer
+from swift import LoRAConfig, Swift
 from text_generation_metric import TextGenerationMetric
 from transformers import DataCollatorForSeq2Seq

@@ -11,8 +12,6 @@ from modelscope import build_dataset_from_file, snapshot_download
 from modelscope.metainfo import Models
 from modelscope.models import Model
 from modelscope.msdatasets import MsDataset
-from modelscope.swift import Swift
-from modelscope.swift.lora import LoRAConfig
 from modelscope.trainers.training_args import TrainingArgs
 from modelscope.utils.config import ConfigDict
 from modelscope.utils.hub import read_config
@@ -243,15 +242,15 @@ elif not args.use_lora:

 if args.use_lora != 0:
     lora_config = LoRAConfig(
-        replace_modules=['attention.query_key_value'],
-        rank=args.lora_rank,
+        target_modules=['attention.query_key_value'],
+        r=args.lora_rank,
         lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout)
     if args.use_amp:
         model = model.float()
     else:
         model = model.bfloat16()
-    Swift.prepare_model(model, lora_config)
+    model = Swift.prepare_model(model, lora_config)

 prefix = args.source_prefix if args.source_prefix is not None else ''

@@ -1,15 +1,17 @@
+import os.path as osp
+
+import torch
+from swift import LoRAConfig, Swift
+
 from modelscope import Model, pipeline, read_config
 from modelscope.metainfo import Models
-from modelscope.swift import Swift
-from modelscope.swift.lora import LoRAConfig
 from modelscope.utils.config import ConfigDict

 lora_config = LoRAConfig(
-    replace_modules=['attention.query_key_value'],
-    rank=32,
+    target_modules=['attention.query_key_value'],
+    r=32,
     lora_alpha=32,
-    lora_dropout=0.05,
-    pretrained_weights='./lora_dureader_target/iter_600.pth')
+    lora_dropout=0.05)

 model_dir = 'ZhipuAI/ChatGLM-6B'
 model_config = read_config(model_dir)
@@ -19,8 +21,12 @@ model_config['model'] = ConfigDict({

 model = Model.from_pretrained(model_dir, cfg_dict=model_config)
 model = model.bfloat16()
-Swift.prepare_model(model, lora_config)
-
+model = Swift.prepare_model(model, lora_config)
+work_dir = './tmp'
+state_dict = torch.load(osp.join(work_dir, 'iter_600.pth'))
+model = Swift.from_pretrained(
+    model, osp.join(work_dir, 'output_best'), device_map='auto')
+model.load_state_dict(state_dict)
 pipe = pipeline('chat', model, pipeline_name='chatglm6b-text-generation')

 print(
@@ -1,15 +1,17 @@
+import os.path as osp
+
+import torch
+from swift import LoRAConfig, Swift
+
 from modelscope import Model, pipeline, read_config
 from modelscope.metainfo import Models
-from modelscope.swift import Swift
-from modelscope.swift.lora import LoRAConfig
 from modelscope.utils.config import ConfigDict

 lora_config = LoRAConfig(
-    replace_modules=['attention.query_key_value'],
-    rank=32,
+    target_modules=['attention.query_key_value'],
+    r=32,
     lora_alpha=32,
-    lora_dropout=0.05,
-    pretrained_weights='./lora_dureader_target/iter_600.pth')
+    lora_dropout=0.05)

 model_dir = 'ZhipuAI/chatglm2-6b'
 model_config = read_config(model_dir)
@@ -19,7 +21,12 @@ model_config['model'] = ConfigDict({

 model = Model.from_pretrained(model_dir, cfg_dict=model_config)
 model = model.bfloat16()
-Swift.prepare_model(model, lora_config)
+model = Swift.prepare_model(model, lora_config)
+work_dir = './tmp'
+state_dict = torch.load(osp.join(work_dir, 'iter_600.pth'))
+model = Swift.from_pretrained(
+    model, osp.join(work_dir, 'output_best'), device_map='auto')
+model.load_state_dict(state_dict)

 pipe = pipeline('chat', model, pipeline_name='chatglm2_6b-text-generation')

@@ -8,6 +8,7 @@ from dataclasses import dataclass, field

 import json
 import torch
+from swift import LoRAConfig, Swift

 from modelscope import TrainingArgs
 from modelscope.hub.snapshot_download import snapshot_download
@@ -15,8 +16,6 @@ from modelscope.metainfo import Trainers
 from modelscope.models.nlp.llama import LlamaForTextGeneration, LlamaTokenizer
 from modelscope.msdatasets.dataset_cls.custom_datasets.torch_custom_dataset import \
     TorchCustomDataset
-from modelscope.swift import Swift
-from modelscope.swift.lora import LoRAConfig
 from modelscope.trainers import build_trainer

 IGNORE_INDEX = -100
@@ -255,12 +254,12 @@ if __name__ == '__main__':

     if args.use_lora != 0:
         lora_config = LoRAConfig(
-            replace_modules=['q_proj', 'k_proj', 'v_proj'],
-            rank=args.lora_rank,
+            target_modules=['q_proj', 'k_proj', 'v_proj'],
+            r=args.lora_rank,
             lora_alpha=args.lora_alpha,
             lora_dropout=args.lora_dropout)
         model = model.bfloat16()
-        Swift.prepare_model(model, lora_config)
+        model = Swift.prepare_model(model, lora_config)

     tokenizer = LlamaTokenizer.from_pretrained(
         model_path,
@@ -7,13 +7,13 @@ from functools import partial
 from typing import List, Optional

 import torch
+from swift import LoRAConfig, Swift
 from transformers import GenerationConfig, TextStreamer
 from utils import (DATASET_MAPPING, DEFAULT_PROMPT, MODEL_MAPPING, get_dataset,
                    get_model_tokenizer, inference, parse_args, process_dataset,
                    tokenize_function)

 from modelscope import get_logger
-from modelscope.swift import LoRAConfig, Swift

 warnings.warn(
     'This directory has been migrated to '
@@ -76,13 +76,15 @@ def llm_infer(args: InferArguments) -> None:
     # ### Preparing lora
     if args.sft_type == 'lora':
         lora_config = LoRAConfig(
-            replace_modules=args.lora_target_modules,
-            rank=args.lora_rank,
+            target_modules=args.lora_target_modules,
+            r=args.lora_rank,
             lora_alpha=args.lora_alpha,
             lora_dropout=args.lora_dropout_p,
             pretrained_weights=args.ckpt_path)
         logger.info(f'lora_config: {lora_config}')
         model = Swift.prepare_model(model, lora_config)
+        state_dict = torch.load(args.ckpt_path, map_location='cpu')
+        model.load_state_dict(state_dict)
     elif args.sft_type == 'full':
         state_dict = torch.load(args.ckpt_path, map_location='cpu')
         model.load_state_dict(state_dict)
@@ -20,6 +20,7 @@ from functools import partial
 from typing import List, Optional

 import torch
+from swift import LoRAConfig, Swift
 from torch import Tensor
 from utils import (DATASET_MAPPING, DEFAULT_PROMPT, MODEL_MAPPING,
                    data_collate_fn, get_dataset, get_model_tokenizer,
@@ -29,7 +30,6 @@ from utils import (DATASET_MAPPING, DEFAULT_PROMPT, MODEL_MAPPING,
                    tokenize_function)

 from modelscope import get_logger
-from modelscope.swift import LoRAConfig, Swift
 from modelscope.trainers import EpochBasedTrainer
 from modelscope.utils.config import Config

@@ -141,8 +141,8 @@ def llm_sft(args: SftArguments) -> None:
     # ### Preparing lora
     if args.sft_type == 'lora':
         lora_config = LoRAConfig(
-            replace_modules=args.lora_target_modules,
-            rank=args.lora_rank,
+            target_modules=args.lora_target_modules,
+            r=args.lora_rank,
             lora_alpha=args.lora_alpha,
             lora_dropout=args.lora_dropout_p)
         logger.info(f'lora_config: {lora_config}')
@@ -5,47 +5,33 @@ import os
 import random
 import re
 import sys
-from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

-import json
 import matplotlib.pyplot as plt
 import numpy as np
 #
 import torch
-import torch.nn as nn
-import torch.optim as optim
-from matplotlib.axes import Axes
 from matplotlib.figure import Figure
-from numpy import ndarray
+from swift import LoRAConfig, Swift
 from tensorboard.backend.event_processing.event_accumulator import \
     EventAccumulator
 from torch import Tensor
 from torch import device as Device
-from torch import dtype as Dtype
 from torch.nn import Module
-from torch.nn.parameter import Parameter
 from torch.nn.utils.rnn import pad_sequence
-from torch.optim import Optimizer
-from torch.optim import lr_scheduler as lrs
-from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils.data import Dataset
 #
 from torchmetrics import Accuracy, MeanMetric
 #
 from tqdm import tqdm

 #
-from modelscope import (Model, MsDataset, get_logger, read_config,
-                        snapshot_download)
+from modelscope import Model, MsDataset, get_logger, read_config
 from modelscope.metrics.base import Metric
 from modelscope.metrics.builder import METRICS
 from modelscope.models.nlp.chatglm2 import ChatGLM2Tokenizer
 from modelscope.msdatasets.dataset_cls.custom_datasets import \
     TorchCustomDataset
-from modelscope.swift import LoRAConfig, Swift
-from modelscope.trainers import EpochBasedTrainer
-from modelscope.utils.config import Config, ConfigDict
+from modelscope.utils.config import ConfigDict
 from modelscope.utils.registry import default_group

 #
@@ -209,8 +209,8 @@
 "LORA_ALPHA = 32\n",
 "LORA_DROPOUT_P = 0  # Arbitrary value\n",
 "lora_config = LoRAConfig(\n",
-"    replace_modules=LORA_TARGET_MODULES,\n",
-"    rank=LORA_RANK,\n",
+"    target_modules=LORA_TARGET_MODULES,\n",
+"    r=LORA_RANK,\n",
 "    lora_alpha=LORA_ALPHA,\n",
 "    lora_dropout=LORA_DROPOUT_P,\n",
 "    pretrained_weights=CKPT_FAPTH)\n",
@@ -224,8 +224,8 @@
 "LORA_ALPHA = 32\n",
 "LORA_DROPOUT_P = 0.1\n",
 "lora_config = LoRAConfig(\n",
-"    replace_modules=LORA_TARGET_MODULES,\n",
-"    rank=LORA_RANK,\n",
+"    target_modules=LORA_TARGET_MODULES,\n",
+"    r=LORA_RANK,\n",
 "    lora_alpha=LORA_ALPHA,\n",
 "    lora_dropout=LORA_DROPOUT_P)\n",
 "logger.info(f'lora_config: {lora_config}')\n",
@@ -212,8 +212,8 @@
 "LORA_ALPHA = 32\n",
 "LORA_DROPOUT_P = 0  # Arbitrary value\n",
 "lora_config = LoRAConfig(\n",
-"    replace_modules=LORA_TARGET_MODULES,\n",
-"    rank=LORA_RANK,\n",
+"    target_modules=LORA_TARGET_MODULES,\n",
+"    r=LORA_RANK,\n",
 "    lora_alpha=LORA_ALPHA,\n",
 "    lora_dropout=LORA_DROPOUT_P,\n",
 "    pretrained_weights=CKPT_FAPTH)\n",
@@ -234,8 +234,8 @@
 "LORA_ALPHA = 32\n",
 "LORA_DROPOUT_P = 0.1\n",
 "lora_config = LoRAConfig(\n",
-"    replace_modules=LORA_TARGET_MODULES,\n",
-"    rank=LORA_RANK,\n",
+"    target_modules=LORA_TARGET_MODULES,\n",
+"    r=LORA_RANK,\n",
 "    lora_alpha=LORA_ALPHA,\n",
 "    lora_dropout=LORA_DROPOUT_P)\n",
 "logger.info(f'lora_config: {lora_config}')\n",
@@ -2,16 +2,10 @@

 from typing import Dict

-import numpy as np
-from sklearn.metrics import accuracy_score, f1_score
-
 from modelscope.metainfo import Metrics
-from modelscope.outputs import OutputKeys
 from modelscope.utils.registry import default_group
-from modelscope.utils.tensor_utils import (torch_nested_detach,
-                                           torch_nested_numpify)
 from .base import Metric
-from .builder import METRICS, MetricKeys
+from .builder import METRICS


 @METRICS.register_module(
@@ -13,22 +13,20 @@ from diffusers import (AutoencoderKL, DDPMScheduler, DiffusionPipeline,
                        utils)
 from diffusers.models import cross_attention
 from diffusers.utils import deprecation_utils
+from swift import AdapterConfig, LoRAConfig, PromptConfig, Swift
 from transformers import CLIPTextModel, CLIPTokenizer

 from modelscope import snapshot_download
 from modelscope.metainfo import Models
 from modelscope.models import TorchModel
 from modelscope.models.builder import MODELS
+from modelscope.models.multi_modal.efficient_diffusion_tuning.sd_lora import \
+    LoRATuner
 from modelscope.outputs import OutputKeys
-from modelscope.swift import Swift
-from modelscope.swift.adapter import AdapterConfig
-from modelscope.swift.control_sd_lora import ControlLoRATuner
-from modelscope.swift.lora import LoRAConfig
-from modelscope.swift.prompt import PromptConfig
-from modelscope.swift.sd_lora import LoRATuner
 from modelscope.utils.checkpoint import save_checkpoint, save_configuration
 from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile, Tasks
+from .control_sd_lora import ControlLoRATuner

 utils.deprecate = lambda *arg, **kwargs: None
 deprecation_utils.deprecate = lambda *arg, **kwargs: None
@@ -1,38 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import TYPE_CHECKING
-
-from modelscope.utils.import_utils import LazyImportModule
-
-if TYPE_CHECKING:
-    from .optimizers.child_tuning_adamw_optimizer import calculate_fisher, ChildTuningAdamW
-    from .adapter import Adapter, AdapterConfig, AdapterModule
-    from .lora import LoRA, LoRAConfig, Linear, MergedLinear, Embedding, Conv2d
-    from .prompt import Prompt, PromptConfig, PromptModule
-    from .control_sd_lora import ControlLoRACrossAttnProcessor, ControlLoRACrossAttnProcessorV2, ControlLoRATuner
-    from .base import SwiftConfig, Swift
-else:
-    _import_structure = {
-        'optimizers.child_tuning_adamw_optimizer':
-        ['calculate_fisher', 'ChildTuningAdamW'],
-        'adapter': ['Adapter', 'AdapterConfig', 'AdapterModule'],
-        'lora': [
-            'LoRA', 'LoRAConfig', 'Linear', 'MergedLinear', 'Embedding',
-            'Conv2d'
-        ],
-        'prompt': ['Prompt', 'PromptConfig', 'PromptModule'],
-        'control_sd_lora': [
-            'ControlLoRACrossAttnProcessor', 'ControlLoRACrossAttnProcessorV2',
-            'ControlLoRATuner'
-        ],
-        'base': ['SwiftConfig', 'Swift']
-    }
-
-    import sys
-
-    sys.modules[__name__] = LazyImportModule(
-        __name__,
-        globals()['__file__'],
-        _import_structure,
-        module_spec=__spec__,
-        extra_objects={},
-    )
@@ -1,201 +0,0 @@
-import inspect
-import os
-import re
-import types
-from dataclasses import dataclass, field
-from typing import Union
-
-import torch
-from torch import nn
-
-from modelscope import snapshot_download
-from modelscope.utils.constant import ModelFile
-from .base import SwiftConfig
-
-
-@dataclass
-class AdapterConfig(SwiftConfig):
-    """
-    The configuration class for the adapter module.
-
-    Adapters project input tokens by an MLP layer.
-    'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019)
-    See http://arxiv.org/abs/1902.00751
-
-    Args:
-        dim: The dimension of the hidden states
-        module_name: The feedforward module to be replaced, in regex format
-        hidden_pos: The position of the hidden state to passed into the adapter, can be int (args) or str (kwargs)
-        method_name: The method to be replaced, default to replace the forward method
-        adapter_length: The length of the adapter length (intermediate length)
-        act_layer: The activation layer of the adapter
-        only_adapter_trainable: Whether to train only adapters
-        pretrained_weights: The pretrained adapter weights.
-            Can be a local dir, local file, or a model id from modelscope
-    """
-
-    dim: int = field(metadata={'help': 'The dimension of the hidden states'})
-
-    module_name: str = field(
-        metadata={
-            'help': 'The feedforward module to be replaced, in regex format'
-        })
-
-    hidden_pos: Union[str, int] = field(
-        metadata={
-            'help':
-            'The position of the hidden state to passed into the adapter, can be int (args) or str (kwargs)'
-        })
-
-    method_name: str = field(
-        default='forward',
-        metadata={
-            'help':
-            'The method to be replaced, default to replace the forward method'
-        })
-
-    adapter_length: int = field(
-        default=128,
-        metadata={
-            'help': 'The length of the adapter length (intermediate length)'
-        })
-
-    act_layer: nn.Module = field(
-        default=nn.GELU,
-        metadata={'help': 'The activation layer of the adapter'})
-
-    only_adapter_trainable: bool = field(
-        default=True, metadata={'help': 'Whether to train only adapters'})
-
-    pretrained_weights: str = field(
-        default=None,
-        metadata={
-            'help':
-            'The pretrained adapter weights. Can be a local dir, local file, or a model id from modelscope'
-        })
-
-
-class Adapter:
-
-    @staticmethod
-    def prepare_model(model: nn.Module, config: AdapterConfig):
-        module_keys = [key for key, _ in model.named_modules()]
-
-        for module_key in module_keys:
-            if re.fullmatch(config.module_name, module_key):  # noqa
-                module = model.get_submodule(module_key)
-
-                def _forward(self, *args, **kwargs):
-                    args = self.forward_origin(*args, **kwargs)
-                    if isinstance(args, (tuple, list, dict)):
-                        if isinstance(config.hidden_pos, int):
-                            return args[0:config.hidden_pos] + args[
-                                config.hidden_pos] + getattr(self, 'adapter')(args[config.hidden_pos]) \
-                                + args[config.hidden_pos + 1:]  # noqa
-                        else:
-                            kwargs[config.hidden_pos] = args[
-                                config.hidden_pos] + getattr(self, 'adapter')(
-                                    args[config.hidden_pos])
-                    elif isinstance(args, torch.Tensor):
-                        args = getattr(self, 'adapter')(args)
-                    return args
-
-                def _feed_forward_chunk(self, attention_output):
-                    return _forward(self, attention_output)
-
-                module.forward_origin = getattr(module, config.method_name)
-                num_args_in_forward_chunk_fn = len(
-                    inspect.signature(module.forward_origin).parameters)
-                if config.method_name == 'feed_forward_chunk' and num_args_in_forward_chunk_fn == 1:
-                    setattr(module, config.method_name,
-                            types.MethodType(_feed_forward_chunk, module))
-                else:
-                    setattr(module, config.method_name,
-                            types.MethodType(_forward, module))
-
-                if isinstance(module, torch.nn.Linear):
-                    input_dim = module.out_features
-                else:
-                    input_dim = config.dim
-
-                adapter_module = AdapterModule(input_dim,
-                                               config.adapter_length,
-                                               config.act_layer)
-                setattr(module, 'adapter', adapter_module)
-
-        if config.only_adapter_trainable:
-            for n, p in model.named_parameters():
-                if 'adapter' not in n:
-                    p.requires_grad = False
-
-        def state_dict_hook(module, destination, prefix, local_metadata):
-            return {
-                key: value
-                for key, value in destination.items() if 'adapter' in key
-            }
-
-        model.state_dict_hook_handle = model._register_state_dict_hook(
-            state_dict_hook)
-
-        def load_state_dict(self, state_dict, strict=True):
-            return self.load_state_dict_origin(state_dict, False)
-
-        model.load_state_dict_origin = model.load_state_dict
-        model.load_state_dict = types.MethodType(load_state_dict, model)
-
-        if config.pretrained_weights is not None:
-            if not os.path.exists(config.pretrained_weights):
-                model_dir = snapshot_download(config.pretrained_weights)
-                pretrained_weights = os.path.join(
-                    model_dir, ModelFile.TORCH_MODEL_BIN_FILE)
-            elif os.path.isfile(config.pretrained_weights):
-                pretrained_weights = config.pretrained_weights
-            else:
-                pretrained_weights = os.path.join(
-                    config.pretrained_weights, ModelFile.TORCH_MODEL_BIN_FILE)
-            model.load_state_dict(torch.load(pretrained_weights))
-        return model
-
-
-class AdapterModule(nn.Module):
-    """The implementation of adapter tuning method.
-
-    Adapters project input tokens by an MLP layer.
-    'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019)
-    See http://arxiv.org/abs/1902.00751
-
-    Attributes:
-        dim: An integer indicating the embedding dimension.
-        adapter_length: An integer indicating the length of adapter tuning.
-    """
-
-    def __init__(
-        self,
-        dim,
-        adapter_length=None,
-        act_layer=nn.GELU,
-    ):
-        super(AdapterModule, self).__init__()
-        self.dim = dim
-        self.adapter_length = adapter_length
-        # self.adapter_type = adapter_type
-        self.ln1 = nn.Linear(dim, adapter_length)
-        self.activate = act_layer()
-        self.ln2 = nn.Linear(adapter_length, dim)
-        self.init_weights()
-
-    def init_weights(self):
-
-        def _init_weights(m):
-            if isinstance(m, nn.Linear):
-                nn.init.xavier_uniform_(m.weight)
-                nn.init.normal_(m.bias, std=1e-6)
-
-        self.apply(_init_weights)
-
-    def forward(self, x, identity=None):
-        out = self.ln2(self.activate(self.ln1(x)))
-        if identity is None:
-            identity = x
-        out = identity + out
-        return out
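For reference, the `AdapterModule` deleted above is just a residual bottleneck MLP: down-project, non-linearity, up-project, then add the input back. A standalone sketch of that forward pass in plain PyTorch (independent of the deleted class; the shapes are arbitrary):

import torch
import torch.nn as nn

class TinyAdapter(nn.Module):
    """Residual bottleneck adapter: x + up(act(down(x)))."""

    def __init__(self, dim: int, adapter_length: int, act_layer=nn.GELU):
        super().__init__()
        self.down = nn.Linear(dim, adapter_length)  # down projection
        self.act = act_layer()
        self.up = nn.Linear(adapter_length, dim)    # up projection

    def forward(self, x, identity=None):
        out = self.up(self.act(self.down(x)))
        return (x if identity is None else identity) + out

hidden = torch.randn(2, 16, 768)                 # (batch, seq, dim) stand-in
print(TinyAdapter(768, 128)(hidden).shape)       # torch.Size([2, 16, 768])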
@@ -1,31 +0,0 @@
-from dataclasses import dataclass
-
-
-@dataclass
-class SwiftConfig:
-    pass
-
-
-class Swift:
-
-    @staticmethod
-    def prepare_model(model, config: SwiftConfig):
-        """Prepare the module and returns the new module.
-
-        Args:
-            model: The model to tune.
-            config: The config of the tuner.
-
-        Returns:
-            The tuned model.
-        """
-        from .lora import LoRA, LoRAConfig
-        from .adapter import Adapter, AdapterConfig
-        from .prompt import Prompt, PromptConfig
-        if isinstance(config, LoRAConfig):
-            return LoRA.prepare_model(model, config)
-        elif isinstance(config, AdapterConfig):
-            return Adapter.prepare_model(model, config)
-        elif isinstance(config, PromptConfig):
-            return Prompt.prepare_model(model, config)
-        return None
@@ -1,700 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
-import logging
-import math
-import os.path
-import re
-import types
-from dataclasses import dataclass, field
-from typing import Dict, List
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from modelscope import snapshot_download
-from modelscope.utils.constant import ModelFile
-from .base import SwiftConfig
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class LoRAConfig(SwiftConfig):
-    """
-    The configuration class for the loRA module.
-
-    Args:
-        rank: The rank of the LoRA module
-        replace_modules: The modules to be replaced by LoRA, can be the end of the module name or a regex string
-        lora_alpha: The factor to add the lora weights
-        lora_dropout: The dropout rate of the lora module
-        merge_weights: Whether to merge weights when validating
-        use_merged_linear: Whether to replace with merged linear layer
-        enable_lora: The modules need to be turned on when using the merged linear layer
-        fan_in_fan_out: Set this to True if the layer to replace stores weight like (fan_in, fan_out)
-        bias: Bias type. Values ca be "none", "all" or "lora_only"
-        only_lora_trainable: Whether to train only lora
-        pretrained_weights: The pretrained lora weights.
-            Can be a local dir, local file, or a model id from modelscope
-    """
-
-    rank: int = field(
-        default=6, metadata={'help': 'The rank of the LoRA module'})
-    replace_modules: List = field(
-        default=None,
-        metadata={
-            'help':
-            'The modules to be replaced by LoRA, can be the end of the module name or a regex string'
-        })
-    lora_alpha: float = field(
-        default=1., metadata={'help': 'The factor to add the lora weights'})
-    lora_dropout: float = field(
-        default=0., metadata={'help': 'The dropout rate of the lora module'})
-    merge_weights: bool = field(
-        default=True,
-        metadata={'help': 'Whether to merge weights when validating'})
-    use_merged_linear: bool = field(
-        default=False,
-        metadata={'help': 'Whether to replace with merged linear layer'})
-    enable_lora: List = field(
-        default=None,
-        metadata={
-            'help':
-            'The modules need to be turned on when using the merged linear layer'
-        })
-    fan_in_fan_out: bool = field(
-        default=False,
-        metadata={
-            'help':
-            'Set this to True if the layer to replace stores weight like (fan_in, fan_out)'
-        })
-    bias: str = field(
-        default='none',
-        metadata={
-            'help': 'Bias type. Values ca be "none", "all" or "lora_only"'
-        })
-    only_lora_trainable: bool = field(
-        default=True, metadata={'help': 'Whether to train only lora'})
-    pretrained_weights: str = field(
-        default=None,
-        metadata={
-            'help':
-            'The pretrained lora weights. Can be a local dir, local file, or a model id from modelscope'
-        })
-
-
-class LoRA:
-
-    @staticmethod
-    def prepare_model(model: nn.Module, config: LoRAConfig):
-        """Tune a model with LoRA.
-
-        Args:
-            config: The LoRAConfig instance.
-
-        Returns:
-            The lora modules
-        """
-        LoRA._dynamic_patch_lora(
-            model,
-            replace_modules=config.replace_modules,
-            r=config.rank,
-            lora_alpha=config.lora_alpha,
-            lora_dropout=config.lora_dropout,
-            merge_weights=config.merge_weights,
-            use_merged_linear=config.use_merged_linear,
-            enable_lora=config.enable_lora,
-            fan_in_fan_out=config.fan_in_fan_out)
-
-        if config.only_lora_trainable:
-            mark_only_lora_as_trainable(model, config.bias)
-
-        def state_dict_hook(module, destination, prefix, local_metadata):
-            return lora_state_dict(destination, config.bias)
-
-        model.state_dict_hook_handle = model._register_state_dict_hook(
-            state_dict_hook)
-
-        def load_state_dict(self, state_dict, strict=True):
-            return self.load_state_dict_origin(state_dict, False)
-
-        model.load_state_dict_origin = model.load_state_dict
-        model.load_state_dict = types.MethodType(load_state_dict, model)
-
-        if config.pretrained_weights is not None:
-            if not os.path.exists(config.pretrained_weights):
-                model_dir = snapshot_download(config.pretrained_weights)
-                pretrained_weights = os.path.join(
-                    model_dir, ModelFile.TORCH_MODEL_BIN_FILE)
-            elif os.path.isfile(config.pretrained_weights):
-                pretrained_weights = config.pretrained_weights
-            else:
-                pretrained_weights = os.path.join(
-                    config.pretrained_weights, ModelFile.TORCH_MODEL_BIN_FILE)
-            model.load_state_dict(torch.load(pretrained_weights))
-
-        return model
-
-    @staticmethod
-    def _dynamic_patch_lora(model, replace_modules, use_merged_linear,
-                            **kwargs):
-        """Dynamic patch lora to model
-
-        Args:
-            model: The torch.nn.Module containing the target module to be patched.
-            replace_modules: The module names to be replaced, the replacing strategy is `end with`.
-            use_merged_linear: Whether to replace with merged linear layer
-            **kwargs: The arguments passed from `tune` which are needed by lora.
-
-        Returns:
-            The lora modules
-        """
-        modules = []
-        module_keys = [key for key, _ in model.named_modules()]
-        assert isinstance(replace_modules, (str, list))
-        if isinstance(replace_modules, str):
-            replace_modules = [replace_modules]
-
-        for module_key in module_keys:
-            if isinstance(replace_modules, str):
-                target_module_found = re.fullmatch(replace_modules, module_key)
-            else:
-                target_module_found = any(
-                    module_key.endswith(target_key)
-                    for target_key in replace_modules)
-            if target_module_found:  # noqa
-                parts = module_key.split('.')
-                module = model.get_submodule('.'.join(parts[:-1]))
-                sub_module = model.get_submodule(module_key)
-                _key = parts[-1]
-
-                lora_module = None
-                if isinstance(sub_module, torch.nn.Linear):
-                    if use_merged_linear:
-                        lora_module = MergedLinear(
-                            sub_module.in_features,
-                            sub_module.out_features,
-                            bias=sub_module.bias is not None,
-                            **kwargs)
-                    else:
-                        kwargs.pop('enable_lora', None)
-                        lora_module = Linear(
-                            sub_module.in_features,
-                            sub_module.out_features,
-                            bias=sub_module.bias is not None,
-                            **kwargs)
-                elif isinstance(sub_module, torch.nn.Conv2d):
-                    kwargs.pop('fan_in_fan_out', None)
-                    lora_module = Conv2d(
-                        sub_module.in_channels,
-                        sub_module.out_channels,
-                        kernel_size=sub_module.kernel_size,
-                        stride=sub_module.stride,
-                        padding=sub_module.padding,
-                        dilation=sub_module.dilation,
-                        groups=sub_module.groups,
-                        **kwargs)
-
-                if lora_module is not None:
-                    lora_module.weight = sub_module.weight
-                    if sub_module.bias is not None:
-                        lora_module.bias = sub_module.bias
-                    lora_module.to(sub_module.weight.device).to(
-                        sub_module.weight.dtype)
-                    setattr(module, _key, lora_module)
-                    modules.append(lora_module)
-        return modules
-
-    @staticmethod
-    def unpatch_lora(model, config: LoRAConfig):
-        """Unpatch lora modules and merge the weights to original modules.
-
-        LoRA constructs an additional layer with low-rank decomposition matrices of the weights in the network.
-        'LoRA: Low-Rank Adaptation of Large Language Models' by Hu et al.(2021)
-        See https://arxiv.org/abs/2106.09685
-
-        Args:
-            model: The model called with `tune` function.
-            replace_modules: The module names to be replaced, the replacing strategy is `end with`.
-
-        Returns:
-            The lora modules.
-        """
-        modules = []
-        module_keys = [key for key, _ in model.named_modules()]
-        assert isinstance(config.replace_modules, (str, list))
-        replace_modules = config.replace_modules
-
-        for module_key in module_keys:
-            if isinstance(replace_modules, str):
-                target_module_found = re.fullmatch(replace_modules, module_key)
-            else:
-                target_module_found = any(
-                    module_key.endswith(target_key)
-                    for target_key in replace_modules)
-            if target_module_found:  # noqa
-                parts = module_key.split('.')
-                module = model.get_submodule('.'.join(parts[:-1]))
-                sub_module = model.get_submodule(module_key)
-                _key = parts[-1]
-
-                origin_module = None
-                if isinstance(sub_module, Linear):
-                    origin_module = torch.nn.Linear(
-                        sub_module.in_features,
-                        sub_module.out_features,
-                        bias=sub_module.bias is not None)
-                elif isinstance(sub_module, Conv2d):
-                    origin_module = torch.nn.Conv2d(
-                        sub_module.in_channels,
-                        sub_module.out_channels,
-                        kernel_size=sub_module.kernel_size,
-                        stride=sub_module.stride,
-                        padding=sub_module.padding,
-                        dilation=sub_module.dilation,
-                        groups=sub_module.groups)
-
-                if origin_module is not None:
-                    sub_module.merge_weights = True
-                    sub_module.eval()
-                    origin_module.weight = sub_module.weight
-                    if sub_module.bias is not None:
-                        origin_module.bias = sub_module.bias
-                    origin_module.to(sub_module.weight.device).to(
-                        sub_module.weight.dtype)
-                    setattr(module, _key, origin_module)
-                    modules.append(sub_module)
-
-        model.state_dict_hook_handle.remove()
-        if hasattr(model, 'load_state_dict_hook_handle'):
-            model.load_state_dict_hook_handle.remove()
-        else:
-            model.load_state_dict = model.load_state_dict_origin
-        return modules
-
-
-class LoRALayer:
-
-    def __init__(
-        self,
-        r: int,
-        lora_alpha: int,
-        lora_dropout: float,
-        merge_weights: bool,
-    ):
-        self.r = r
-        self.lora_alpha = lora_alpha
-        # Optional dropout
-        if lora_dropout > 0.:
-            self.lora_dropout = nn.Dropout(p=lora_dropout)
-        else:
-            self.lora_dropout = lambda x: x
-        # Mark the weight as unmerged
-        self.merged = False
-        self.merge_weights = merge_weights
-
-
-class Embedding(nn.Embedding, LoRALayer):
-    # LoRA implemented in a dense layer
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 r: int = 0,
-                 lora_alpha: int = 1,
-                 merge_weights: bool = True,
-                 **kwargs):
-        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
-        LoRALayer.__init__(
-            self,
-            r=r,
-            lora_alpha=lora_alpha,
-            lora_dropout=0,
-            merge_weights=merge_weights)
-        # Actual trainable parameters
-        if r > 0:
-            self.lora_A = nn.Parameter(
-                self.weight.new_zeros((r, num_embeddings)))
-            self.lora_B = nn.Parameter(
-                self.weight.new_zeros((embedding_dim, r)))
-            self.scaling = self.lora_alpha / self.r
-            # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        nn.Embedding.reset_parameters(self)
-        if hasattr(self, 'lora_A'):
-            # initialize A the same way as the default for nn.Linear and B to zero
-            nn.init.zeros_(self.lora_A)
-            nn.init.normal_(self.lora_B)
-
-    def train(self, mode: bool = True):
-        nn.Embedding.train(self, mode)
-        self.lora_A.requires_grad = mode
-        self.lora_B.requires_grad = mode
-        if mode and self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            if self.r > 0:
-                self.weight.data -= (self.lora_B
-                                     @ self.lora_A).T * self.scaling
-            self.merged = False
-        if not mode and self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0:
-                self.weight.data += (self.lora_B
-                                     @ self.lora_A).T * self.scaling
-            self.merged = True
-
-    def eval(self):
-        nn.Embedding.eval(self)
-        self.lora_A.requires_grad = False
-        self.lora_B.requires_grad = False
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0:
-                self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
-            self.merged = True
-
-    def forward(self, x: torch.Tensor):
-        if self.r > 0 and not self.merged:
-            result = nn.Embedding.forward(self, x)
-            if self.r > 0:
-                after_A = F.embedding(x, self.lora_A.T, self.padding_idx,
-                                      self.max_norm, self.norm_type,
-                                      self.scale_grad_by_freq, self.sparse)
-                result += (after_A @ self.lora_B.T) * self.scaling
-            return result
-        else:
-            return nn.Embedding.forward(self, x)
-
-
-class Linear(nn.Linear, LoRALayer):
-    # LoRA implemented in a dense layer
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        r: int = 0,
-        lora_alpha: int = 1,
-        lora_dropout: float = 0.,
-        fan_in_fan_out: bool = False,
-        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
-        merge_weights: bool = True,
-        **kwargs):
-        nn.Linear.__init__(self, in_features, out_features, **kwargs)
-        LoRALayer.__init__(
-            self,
-            r=r,
-            lora_alpha=lora_alpha,
-            lora_dropout=lora_dropout,
-            merge_weights=merge_weights)
-
-        self.fan_in_fan_out = fan_in_fan_out
-        # Actual trainable parameters
-        if r > 0:
-            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
-            self.lora_B = nn.Parameter(
-                self.weight.new_zeros((out_features, r)))
-            self.scaling = self.lora_alpha / self.r
-            # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
-        self.reset_parameters()
-        if fan_in_fan_out:
-            self.weight.data = self.weight.data.T
-
-    def reset_parameters(self):
-        nn.Linear.reset_parameters(self)
-        if hasattr(self, 'lora_A'):
-            # initialize A the same way as the default for nn.Linear and B to zero
-            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
-            nn.init.zeros_(self.lora_B)
-
-    def train(self, mode: bool = True):
-
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        nn.Linear.train(self, mode)
-        self.lora_A.requires_grad = mode
-        self.lora_B.requires_grad = mode
-        if mode and self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            if self.r > 0:
-                self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
-            self.merged = False
-        if not mode and self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0:
-                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
-            self.merged = True
-
-    def eval(self):
-
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        nn.Linear.eval(self)
-        self.lora_A.requires_grad = False
-        self.lora_B.requires_grad = False
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0:
-                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
-            self.merged = True
-
-    def forward(self, x: torch.Tensor):
-
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        if self.r > 0 and not self.merged:
-            result = F.linear(x, T(self.weight), bias=self.bias)
-            if self.r > 0:
-                result += (self.lora_dropout(x) @ self.lora_A.T
-                           @ self.lora_B.T) * self.scaling
-            return result
-        else:
-            return F.linear(x, T(self.weight), bias=self.bias)
-
-
-class MergedLinear(nn.Linear, LoRALayer):
-    # LoRA implemented in a dense layer
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 r: int = 0,
-                 lora_alpha: int = 1,
-                 lora_dropout: float = 0.,
-                 enable_lora: List[bool] = [False],
-                 fan_in_fan_out: bool = False,
-                 merge_weights: bool = True,
-                 **kwargs):
-        nn.Linear.__init__(self, in_features, out_features, **kwargs)
-        LoRALayer.__init__(
-            self,
-            r=r,
-            lora_alpha=lora_alpha,
-            lora_dropout=lora_dropout,
-            merge_weights=merge_weights)
-        assert out_features % len(enable_lora) == 0, \
-            'The length of enable_lora must divide out_features'
-        self.enable_lora = enable_lora
-        self.fan_in_fan_out = fan_in_fan_out
-        # Actual trainable parameters
-        if r > 0 and any(enable_lora):
-            self.lora_A = nn.Parameter(
-                self.weight.new_zeros((r * sum(enable_lora), in_features)))
-            self.lora_B = nn.Parameter(
-                self.weight.new_zeros(
-                    (out_features // len(enable_lora) * sum(enable_lora),
-                     r)))  # weights for Conv1D with groups=sum(enable_lora)
-            self.scaling = self.lora_alpha / self.r
-            # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
-            # Compute the indices
-            self.lora_ind = self.weight.new_zeros(
-                (out_features, ), dtype=torch.bool).view(len(enable_lora), -1)
-            self.lora_ind[enable_lora, :] = True
-            self.lora_ind = self.lora_ind.view(-1)
-        self.reset_parameters()
-        if fan_in_fan_out:
-            self.weight.data = self.weight.data.T
-
-    def reset_parameters(self):
-        nn.Linear.reset_parameters(self)
-        if hasattr(self, 'lora_A'):
-            # initialize A the same way as the default for nn.Linear and B to zero
-            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
-            nn.init.zeros_(self.lora_B)
-
-    def zero_pad(self, x):
-        result = x.new_zeros((*x.shape[:-1], self.out_features))
-        result = result.view(-1, self.out_features)
-        result[:, self.lora_ind] = x.reshape(
-            -1,
-            self.out_features // len(self.enable_lora) * sum(self.enable_lora))
-        return result.view((*x.shape[:-1], self.out_features))
-
-    def train(self, mode: bool = True):
-
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        nn.Linear.train(self, mode)
-        self.lora_A.requires_grad = mode
-        self.lora_B.requires_grad = mode
-        if mode and self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            if self.r > 0 and any(self.enable_lora):
-                delta_w = F.conv1d(
-                    self.lora_A.data.unsqueeze(0),
-                    self.lora_B.data.unsqueeze(-1),
-                    groups=sum(self.enable_lora)).squeeze(0)
-                self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
-            self.merged = False
-        if not mode and self.merge_weights and not self.merged:
-            if self.r > 0 and any(self.enable_lora):
-                delta_w = F.conv1d(
-                    self.lora_A.data.unsqueeze(0),
-                    self.lora_B.data.unsqueeze(-1),
-                    groups=sum(self.enable_lora)).squeeze(0)
-                self.weight.data += self.zero_pad(T(delta_w * self.scaling))
-            self.merged = True
-
-    def eval(self):
-
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        nn.Linear.eval(self)
-        self.lora_A.requires_grad = False
-        self.lora_B.requires_grad = False
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0 and any(self.enable_lora):
-                delta_w = F.conv1d(
-                    self.lora_A.data.unsqueeze(0),
-                    self.lora_B.data.unsqueeze(-1),
-                    groups=sum(self.enable_lora)).squeeze(0)
-                self.weight.data += self.zero_pad(T(delta_w * self.scaling))
-            self.merged = True
-
-    def forward(self, x: torch.Tensor):
-
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        if self.merged:
-            return F.linear(x, T(self.weight), bias=self.bias)
-        else:
-            result = F.linear(x, T(self.weight), bias=self.bias)
-            if self.r > 0:
-                after_A = F.linear(self.lora_dropout(x), self.lora_A)
-                after_B = F.conv1d(
-                    after_A.transpose(-2, -1),
-                    self.lora_B.unsqueeze(-1),
-                    groups=sum(self.enable_lora)).transpose(-2, -1)
-                result += self.zero_pad(after_B) * self.scaling
-            return result
-
-
-class Conv2d(nn.Conv2d, LoRALayer):
-    # LoRA implemented in a dense layer
-    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 kernel_size: int,
-                 r: int = 0,
-                 lora_alpha: int = 1,
-                 lora_dropout: float = 0.,
-                 merge_weights: bool = True,
-                 **kwargs):
-        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size,
-                           **kwargs)
-        LoRALayer.__init__(
-            self,
-            r=r,
-            lora_alpha=lora_alpha,
-            lora_dropout=lora_dropout,
-            merge_weights=merge_weights)
-        assert type(kernel_size) is int
-        # Actual trainable parameters
-        if r > 0:
-            self.lora_A = nn.Parameter(
-                self.weight.new_zeros(
-                    (r * kernel_size, in_channels * kernel_size)))
-            self.lora_B = nn.Parameter(
-                self.weight.new_zeros(
-                    (out_channels * kernel_size, r * kernel_size)))
-            self.scaling = self.lora_alpha / self.r
-            # Freezing the pre-trained weight matrix
-            self.weight.requires_grad = False
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        nn.Conv2d.reset_parameters(self)
-        if hasattr(self, 'lora_A'):
-            # initialize A the same way as the default for nn.Linear and B to zero
-            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
-            nn.init.zeros_(self.lora_B)
-
-    def train(self, mode: bool = True):
-        nn.Conv2d.train(self, mode)
-        self.lora_A.requires_grad = mode
-        self.lora_B.requires_grad = mode
-        if mode and self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            self.weight.data -= (self.lora_B @ self.lora_A).view(
-                self.weight.shape) * self.scaling
-            self.merged = False
-        if not mode and self.merge_weights and not self.merged:
-            self.weight.data += (self.lora_B @ self.lora_A).view(
-                self.weight.shape) * self.scaling
-            self.merged = True
-
-    def eval(self):
-        nn.Conv2d.eval(self)
-        self.lora_A.requires_grad = False
-        self.lora_B.requires_grad = False
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            self.weight.data += (self.lora_B @ self.lora_A).view(
-                self.weight.shape) * self.scaling
-            self.merged = True
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
if self.r > 0 and not self.merged:
|
|
||||||
return F.conv2d(
|
|
||||||
x,
|
|
||||||
self.weight + # noqa
|
|
||||||
(self.lora_B @ self.lora_A).view(self.weight.shape) # noqa
|
|
||||||
* self.scaling,
|
|
||||||
self.bias,
|
|
||||||
self.stride,
|
|
||||||
self.padding,
|
|
||||||
self.dilation,
|
|
||||||
self.groups)
|
|
||||||
return nn.Conv2d.forward(self, x)
|
|
||||||
|
|
||||||
|
|
||||||
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
|
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if 'lora_' not in n:
|
|
||||||
p.requires_grad = False
|
|
||||||
if bias == 'none':
|
|
||||||
return
|
|
||||||
elif bias == 'all':
|
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if 'bias' in n:
|
|
||||||
p.requires_grad = True
|
|
||||||
elif bias == 'lora_only':
|
|
||||||
for m in model.modules():
|
|
||||||
if isinstance(m, LoRALayer) and \
|
|
||||||
hasattr(m, 'bias') and \
|
|
||||||
m.bias is not None:
|
|
||||||
m.bias.requires_grad = True
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
def lora_state_dict(state_dict, bias: str = 'none') -> Dict[str, torch.Tensor]:
|
|
||||||
if bias == 'none':
|
|
||||||
return {k: state_dict[k] for k in state_dict if 'lora_' in k}
|
|
||||||
elif bias == 'all':
|
|
||||||
return {
|
|
||||||
k: state_dict[k]
|
|
||||||
for k in state_dict if 'lora_' in k or 'bias' in k
|
|
||||||
}
|
|
||||||
elif bias == 'lora_only':
|
|
||||||
to_return = {}
|
|
||||||
for k in state_dict:
|
|
||||||
if 'lora_' in k:
|
|
||||||
to_return[k] = state_dict[k]
|
|
||||||
bias_name = k.split('lora_')[0] + 'bias'
|
|
||||||
if bias_name in state_dict:
|
|
||||||
to_return[bias_name] = state_dict[bias_name]
|
|
||||||
return to_return
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
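Editor's note: a minimal usage sketch, not part of the diff, of how the removed Conv2d LoRA layer and the two helpers above were typically combined. The TinyNet name and the tensor shapes are illustrative; Conv2d, mark_only_lora_as_trainable and lora_state_dict refer to the definitions shown above.

import torch
import torch.nn as nn


class TinyNet(nn.Module):

    def __init__(self):
        super().__init__()
        # LoRA-augmented conv: the base weight is frozen, lora_A / lora_B stay trainable
        self.conv = Conv2d(3, 8, kernel_size=3, r=4, lora_alpha=16, padding=1)

    def forward(self, x):
        return self.conv(x)


net = TinyNet()
mark_only_lora_as_trainable(net)                 # freeze everything except 'lora_' parameters
net(torch.randn(2, 3, 16, 16)).sum().backward()  # gradients reach only lora_A / lora_B
lora_only = lora_state_dict(net.state_dict())    # keep just the LoRA tensors for checkpointing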
@@ -1,241 +0,0 @@
import os
import re
import types
from dataclasses import dataclass, field
from typing import Union

import torch
from torch import nn

from modelscope import snapshot_download
from modelscope.utils.constant import ModelFile
from .base import SwiftConfig


@dataclass
class PromptConfig(SwiftConfig):
    """
    The configuration class for the prompt module.

    Visual prompt tuning (VPT) initializes tunable prompt tokens and prepends
    them to the original tokens in the first layer or in multiple layers.
    'Visual Prompt Tuning' by Jia et al. (2022).
    See https://arxiv.org/abs/2203.12119

    Here we apply VPT to other fields as well.

    Args:
        dim: The dimension of the hidden states
        module_layer_name: The layer module to be replaced, in regex format
        embedding_pos: The position of the embedding tensor
        attention_mask_pos: The position of the attention mask
        attention_mask_value: The value to pad to the attention mask
        prompt_length: The length of the prompt tokens
        only_prompt_trainable: Whether to train only the prompt parameters
        attach_front: When set to True, the prompt is attached in front of the embedding
        extract_embedding: Whether the embedding is extracted at the final stage to keep the same dims as the inputs
        pretrained_weights: The pretrained prompt weights. Can be a local dir, a local file,
            or a model id from ModelScope
    """

    dim: int = field(metadata={'help': 'The dimension of the hidden states'})

    module_layer_name: str = field(
        metadata={'help': 'The layer module to be replaced, in regex format'})

    embedding_pos: Union[str, int] = field(
        metadata={'help': 'The position of the embedding tensor'})

    attention_mask_pos: Union[str, int] = field(
        default=None, metadata={'help': 'The position of the attention mask'})

    attention_mask_value: Union[float, int, bool] = field(
        default=0.,
        metadata={'help': 'The value to pad to the attention mask'})

    prompt_length: int = field(
        default=16, metadata={'help': 'The length of the prompt tokens'})

    only_prompt_trainable: bool = field(
        default=True, metadata={'help': 'Whether to train only prompt'})

    attach_front: bool = field(
        default=True,
        metadata={
            'help':
            'When set to True, prompt is attached in front of the embedding'
        })

    extract_embedding: bool = field(
        default=False,
        metadata={
            'help':
            'Whether the embedding is extracted at final stage to keep the same dims with inputs'
        })

    pretrained_weights: str = field(
        default=None,
        metadata={
            'help':
            'The pretrained prompt weights. Can be a local dir, local file, or a model id from modelscope'
        })


class Prompt:

    @staticmethod
    def prepare_model(model: nn.Module, config: PromptConfig):
        module_keys = [key for key, _ in model.named_modules()]
        match_module_keys = []
        for module_key in module_keys:
            if re.fullmatch(config.module_layer_name, module_key):  # noqa
                module = model.get_submodule(module_key)

                def _forward(self, *args, **kwargs):
                    if isinstance(config.embedding_pos, int):
                        input_embedding = args[config.embedding_pos]
                    else:
                        input_embedding = kwargs[config.embedding_pos]

                    input_embedding = getattr(
                        self, 'prompt').forward(input_embedding)
                    if isinstance(config.embedding_pos, int):
                        args = type(args)(
                            args[0:config.embedding_pos] + (input_embedding, )
                            + args[config.embedding_pos + 1:])
                    else:
                        kwargs[config.embedding_pos] = input_embedding

                    if config.attention_mask_pos:
                        attention_mask = None
                        if isinstance(config.attention_mask_pos, int):
                            attention_mask = args[config.attention_mask_pos]
                        elif isinstance(config.attention_mask_pos, str):
                            attention_mask = kwargs[config.attention_mask_pos]

                        if attention_mask is not None:
                            attention_mask = getattr(
                                self,
                                'prompt').patch_attention_mask(attention_mask)
                        if isinstance(config.attention_mask_pos, int):
                            args = type(args)(
                                args[0:config.attention_mask_pos]
                                + (attention_mask, )
                                + args[config.attention_mask_pos + 1:])
                        else:
                            kwargs[config.attention_mask_pos] = attention_mask

                    forward_output = self.forward_origin(*args, **kwargs)
                    if config.extract_embedding:
                        forward_output = getattr(
                            self, 'prompt').extract(forward_output)

                    return forward_output

                module.forward_origin = module.forward
                module.forward = types.MethodType(_forward, module)

                if isinstance(config.dim, list):
                    input_dim = config.dim[len(match_module_keys)]
                else:
                    input_dim = config.dim

                prompt_module = PromptModule(input_dim,
                                             int(module_key.rsplit('.')[-1]),
                                             config.prompt_length,
                                             config.attention_mask_value,
                                             config.attach_front)
                setattr(module, 'prompt', prompt_module)
                match_module_keys.append(module_key)

        if config.only_prompt_trainable:
            for n, p in model.named_parameters():
                if 'prompt' not in n:
                    p.requires_grad = False

        def state_dict_hook(module, destination, prefix, local_metadata):
            return {
                key: value
                for key, value in destination.items() if 'prompt' in key
            }

        model.state_dict_hook_handle = model._register_state_dict_hook(
            state_dict_hook)

        def load_state_dict(self, state_dict, strict=True):
            return self.load_state_dict_origin(state_dict, False)

        model.load_state_dict_origin = model.load_state_dict
        model.load_state_dict = types.MethodType(load_state_dict, model)

        if config.pretrained_weights is not None:
            if not os.path.exists(config.pretrained_weights):
                model_dir = snapshot_download(config.pretrained_weights)
                pretrained_weights = os.path.join(
                    model_dir, ModelFile.TORCH_MODEL_BIN_FILE)
            elif os.path.isfile(config.pretrained_weights):
                pretrained_weights = config.pretrained_weights
            else:
                pretrained_weights = os.path.join(
                    config.pretrained_weights, ModelFile.TORCH_MODEL_BIN_FILE)
            model.load_state_dict(torch.load(pretrained_weights))
        return model


class PromptModule(nn.Module):
    """The implementation of the visual prompt tuning method.

    Visual prompt tuning (VPT) initializes tunable prompt tokens and prepends
    them to the original tokens in the first layer or in multiple layers.
    'Visual Prompt Tuning' by Jia et al. (2022).
    See https://arxiv.org/abs/2203.12119

    Attributes:
        dim: An integer indicating the embedding dimension.
        layer_num: An integer indicating the number of layers.
        prompt_length: An integer indicating the length of the vision prompt.
    """

    def __init__(self,
                 dim,
                 layer_num,
                 prompt_length=None,
                 mask_values=0.,
                 attach_front=True):
        super(PromptModule, self).__init__()
        self.dim = dim
        self.layer_num = layer_num
        self.prompt_length = prompt_length
        self.mask_values = mask_values
        self.attach_front = attach_front

        self.prompt_token = nn.Parameter(torch.zeros(1, prompt_length, dim))
        nn.init.xavier_uniform_(self.prompt_token)

    def forward(self, x):
        prompt_token = self.prompt_token.expand(x.shape[0], -1, -1)

        if self.layer_num == 0:
            if self.attach_front:
                x = torch.cat((prompt_token, x), dim=1)
            else:
                x = torch.cat((x, prompt_token), dim=1)
        else:
            if self.attach_front:
                x = torch.cat((prompt_token, x[:, self.prompt_length:, :]),
                              dim=1)
            else:
                x = torch.cat((x[:, :-self.prompt_length, :], prompt_token),
                              dim=1)
        return x

    def patch_attention_mask(self, m):
        prefix_attention_mask = torch.full((*m.shape[:-1], self.prompt_length),
                                           self.mask_values).to(m.device)
        return torch.cat((prefix_attention_mask, m), dim=-1)

    def extract(self, x):
        if self.attach_front:
            return x[:, self.prompt_length:, :]
        else:
            return x[:, :-self.prompt_length, :]
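Editor's note: a quick shape-level illustration of the removed PromptModule, not part of the diff. The batch size and the ViT-style sequence length of 197 are hypothetical; prompt tokens are prepended to a (batch, seq, dim) embedding, the attention mask is padded with mask_values for the new positions, and extract() drops the prompt slots again.

import torch

vpt = PromptModule(dim=768, layer_num=0, prompt_length=16,
                   mask_values=0., attach_front=True)
tokens = torch.randn(2, 197, 768)             # e.g. patch embeddings
mask = torch.ones(2, 197)
with_prompt = vpt(tokens)                     # -> (2, 16 + 197, 768)
padded_mask = vpt.patch_attention_mask(mask)  # -> (2, 16 + 197), prompt slots filled with mask_values
recovered = vpt.extract(with_prompt)          # -> (2, 197, 768)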
@@ -6,9 +6,7 @@ from torch import nn
 from modelscope.metainfo import Trainers
 from modelscope.models.base import Model, TorchModel
 from modelscope.trainers.builder import TRAINERS
-from modelscope.trainers.default_config import merge_hooks
 from modelscope.trainers.trainer import EpochBasedTrainer
-from modelscope.utils.constant import ModeKeys
 
 
 @TRAINERS.register_module(module_name=Trainers.vision_efficient_tuning)
@@ -328,7 +328,8 @@ class SiameseUIETrainer(EpochBasedTrainer):
         Example:
             {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
         """
-        pipeline_uie = pipeline(Tasks.siamese_uie, self.model)
+        pipeline_uie = pipeline(
+            Tasks.siamese_uie, self.model, device=self.device)
         if checkpoint_path is not None and os.path.isfile(checkpoint_path):
             from modelscope.trainers.hooks import LoadCheckpointHook
             LoadCheckpointHook.load_checkpoint(checkpoint_path, self)
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from modelscope.swift import ChildTuningAdamW
 from .builder import OPTIMIZERS, build_optimizer
+from .child_tuning_adamw_optimizer import ChildTuningAdamW
 
 __all__ = ['OPTIMIZERS', 'build_optimizer', 'ChildTuningAdamW']
@@ -28,7 +28,6 @@ from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
 from modelscope.msdatasets.ms_dataset import MsDataset
 from modelscope.outputs import ModelOutputBase
 from modelscope.preprocessors.base import Preprocessor
-from modelscope.swift import Swift
 from modelscope.trainers.hooks.builder import HOOKS
 from modelscope.trainers.hooks.priority import Priority, get_priority
 from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
@@ -41,6 +40,7 @@ from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
 from modelscope.utils.data_utils import to_device
 from modelscope.utils.device import create_device
 from modelscope.utils.file_utils import func_receive_dict_inputs
+from modelscope.utils.import_utils import is_swift_available
 from modelscope.utils.logger import get_logger
 from modelscope.utils.registry import build_from_cfg
 from modelscope.utils.torch_utils import (compile_model, get_dist_info,
@@ -54,6 +54,8 @@ from .hooks.hook import Hook
 from .parallel.builder import build_parallel
 from .parallel.utils import is_parallel
 
+TunerConfig = Union['swift.SwiftConfig', 'swift.PeftConfig']
+
 
 @TRAINERS.register_module(module_name=Trainers.default)
 class EpochBasedTrainer(BaseTrainer):
@@ -118,7 +120,8 @@ class EpochBasedTrainer(BaseTrainer):
                  seed: int = 42,
                  callbacks: Optional[List[Hook]] = None,
                  samplers: Optional[Union[Sampler, Dict[str, Sampler]]] = None,
-                 efficient_tuners: List[Dict] = None,
+                 efficient_tuners: Union[Dict[str, TunerConfig],
+                                         TunerConfig] = None,
                  **kwargs):
 
         self._seed = seed
@@ -270,8 +273,12 @@ class EpochBasedTrainer(BaseTrainer):
 
     def tune_module(self, efficient_tuners):
         if efficient_tuners is not None:
-            for tuner in efficient_tuners:
-                self.model = Swift.prepare_model(self.model, tuner)
+            if not is_swift_available():
+                raise ValueError(
+                    'Please install swift by `pip install ms-swift` to use efficient_tuners.'
+                )
+            from swift import Swift
+            self.model = Swift.prepare_model(self.model, efficient_tuners)
 
     def place_model(self):
         """Place model to device, or to DDP
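Editor's note: with this change, efficient_tuners is a single swift config (or a dict of them) rather than a list of dicts, and tune_module hands it straight to Swift.prepare_model. A hedged sketch of the new calling convention follows; the model id is the structbert model used by the removed tests, and train_dataset / eval_dataset are assumed to be prepared elsewhere.

from swift import LoRAConfig
from modelscope.trainers import build_trainer

lora = LoRAConfig(r=8, lora_alpha=32, target_modules=['query', 'key', 'value'])
trainer = build_trainer(default_args=dict(
    model='damo/nlp_structbert_sentence-similarity_chinese-tiny',
    train_dataset=train_dataset,   # assumed to exist
    eval_dataset=eval_dataset,     # assumed to exist
    work_dir='./work_dir',
    efficient_tuners=lora))        # tune_module() calls Swift.prepare_model(self.model, lora)
trainer.train()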
@@ -1,218 +0,0 @@
# Copyright 2023-2024 The Alibaba Fundamental Vision Team Authors. All rights reserved.
# The implementation is adopted from HighCWu,
# made publicly available under the Apache License 2.0 at https://github.com/HighCWu/ControlLoRA
import os
from dataclasses import dataclass
from typing import List, Tuple, Union

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.cross_attention import CrossAttention, LoRALinearLayer
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.outputs import BaseOutput


@dataclass
class TunerOutput(BaseOutput):
    lora_states: Tuple[torch.FloatTensor]


class LoRACrossAttnProcessor(nn.Module):
    """The implementation of the LoRA attention module.
    """

    def __init__(self,
                 hidden_size,
                 cross_attention_dim=None,
                 rank=4,
                 post_add=False,
                 key_states_skipped=False,
                 value_states_skipped=False,
                 output_states_skipped=False):
        """Initialize a LoRA attention instance.
        Args:
            hidden_size (`int`): The number of channels in the embedding.
            cross_attention_dim (`int`, *optional*):
                The number of channels in the hidden_states. If not given, defaults to `hidden_size`.
            rank (`int`, *optional*, defaults to 4): The rank of the LoRA decomposition.
            post_add (`bool`, *optional*, defaults to False): Set to `True` to conduct the weighted
                adding operation after LoRA.
            key_states_skipped (`bool`, *optional*, defaults to False):
                Set to `True` to skip applying LoRA to the key states.
            value_states_skipped (`bool`, *optional*, defaults to False):
                Set to `True` to skip applying LoRA to the value states.
            output_states_skipped (`bool`, *optional*, defaults to False):
                Set to `True` to skip applying LoRA to the output states.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank
        self.post_add = post_add

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        if not key_states_skipped:
            self.to_k_lora = LoRALinearLayer(
                hidden_size if post_add else
                (cross_attention_dim or hidden_size), hidden_size, rank)
        if not value_states_skipped:
            self.to_v_lora = LoRALinearLayer(
                hidden_size if post_add else
                (cross_attention_dim or hidden_size), hidden_size, rank)
        if not output_states_skipped:
            self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

        self.key_states_skipped: bool = key_states_skipped
        self.value_states_skipped: bool = value_states_skipped
        self.output_states_skipped: bool = output_states_skipped

    def skip_key_states(self, is_skipped: bool = True):
        if not is_skipped:
            assert hasattr(self, 'to_k_lora')
        self.key_states_skipped = is_skipped

    def skip_value_states(self, is_skipped: bool = True):
        if not is_skipped:
            assert hasattr(self, 'to_q_lora')
        self.value_states_skipped = is_skipped

    def skip_output_states(self, is_skipped: bool = True):
        if not is_skipped:
            assert hasattr(self, 'to_out_lora')
        self.output_states_skipped = is_skipped

    def __call__(self,
                 attn: CrossAttention,
                 hidden_states,
                 encoder_hidden_states=None,
                 attention_mask=None,
                 scale=1.0):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask=attention_mask,
            target_length=sequence_length,
            batch_size=batch_size)

        query = attn.to_q(hidden_states)
        query = query + scale * self.to_q_lora(
            query if self.post_add else hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states)
        if not self.key_states_skipped:
            key = key + scale * self.to_k_lora(
                key if self.post_add else encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        if not self.value_states_skipped:
            value = value + scale * self.to_v_lora(
                value if self.post_add else encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        out = attn.to_out[0](hidden_states)
        if not self.output_states_skipped:
            out = out + scale * self.to_out_lora(
                out if self.post_add else hidden_states)
        hidden_states = out
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class LoRATuner(ModelMixin, ConfigMixin):

    @staticmethod
    def tune(
        model: nn.Module,
        tuner_config=None,
        pretrained_tuner=None,
    ):
        tuner = LoRATuner.from_config(tuner_config)
        if pretrained_tuner is not None and os.path.exists(pretrained_tuner):
            tuner.load_state_dict(
                torch.load(pretrained_tuner, map_location='cpu'), strict=True)
        tune_layers_list = list(
            [list(layer_list) for layer_list in tuner.lora_layers])
        assert hasattr(model, 'unet')
        unet = model.unet
        tuner.to(unet.device)
        tune_attn_procs = tuner.set_tune_layers(unet, tune_layers_list)
        unet.set_attn_processor(tune_attn_procs)
        return tuner

    def set_tune_layers(self, unet, tune_layers_list):
        n_ch = len(unet.config.block_out_channels)
        control_ids = [i for i in range(n_ch)]
        tune_attn_procs = {}

        for name in unet.attn_processors.keys():
            if name.startswith('mid_block'):
                control_id = control_ids[-1]
            elif name.startswith('up_blocks'):
                block_id = int(name[len('up_blocks.')])
                control_id = list(reversed(control_ids))[block_id]
            elif name.startswith('down_blocks'):
                block_id = int(name[len('down_blocks.')])
                control_id = control_ids[block_id]

            tune_layers = tune_layers_list[control_id]
            if len(tune_layers) != 0:
                tune_layer = tune_layers.pop(0)
                tune_attn_procs[name] = tune_layer
        return tune_attn_procs

    @register_to_config
    def __init__(
        self,
        lora_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        lora_cross_attention_dims: Tuple[List[int]] = ([
            None, 768, None, 768, None, 768, None, 768, None, 768
        ], [None, 768, None, 768, None, 768, None, 768, None,
            768], [None, 768, None, 768, None, 768, None, 768, None,
                   768], [None, 768]),
        lora_rank: int = 4,
        lora_post_add: bool = False,
        lora_key_states_skipped: bool = False,
        lora_value_states_skipped: bool = False,
        lora_output_states_skipped: bool = False,
    ):
        super().__init__()

        lora_cls = LoRACrossAttnProcessor

        self.lora_layers = nn.ModuleList([])

        for i, lora_cross_attention_dim in enumerate(
                lora_cross_attention_dims):
            self.lora_layers.append(
                nn.ModuleList([
                    lora_cls(
                        lora_block_out_channels[i],
                        cross_attention_dim=cross_attention_dim,
                        rank=lora_rank,
                        post_add=lora_post_add,
                        key_states_skipped=lora_key_states_skipped,
                        value_states_skipped=lora_value_states_skipped,
                        output_states_skipped=lora_output_states_skipped)
                    for cross_attention_dim in lora_cross_attention_dim
                ]))

    def forward(self) -> Union[TunerOutput, Tuple]:
        lora_states_list = []
        tune_layers_list = list(
            [list(layer_list) for layer_list in self.lora_layers])
        for tune_list in tune_layers_list:
            for tune_layer in tune_list:
                lora_states_list.append(tune_layer.to_q_lora.down.weight)
        return TunerOutput(lora_states=tuple(lora_states_list))
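Editor's note: a hypothetical usage sketch of the removed LoRATuner, not part of the diff. `pipe` stands for any object that exposes a diffusers UNet as `pipe.unet`; the config dict mirrors the __init__ arguments above, with unspecified fields falling back to their defaults.

tuner = LoRATuner.tune(
    pipe,                            # must expose .unet (asserted inside tune())
    tuner_config=dict(lora_rank=4),  # remaining fields use the defaults registered above
    pretrained_tuner=None)           # or a path to previously saved tuner weights
lora_states = tuner()                # TunerOutput with one down-projection weight per to_q_lora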
@@ -246,6 +246,10 @@ def is_wenetruntime_available():
     return importlib.util.find_spec('wenetruntime') is not None
 
 
+def is_swift_available():
+    return importlib.util.find_spec('swift') is not None
+
+
 def is_tf_available():
     return _tf_available
 
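Editor's note: the new helper is a plain find_spec check, so callers can guard optional swift imports before touching the package, as sketched below (the error message mirrors the one raised by tune_module).

from modelscope.utils.import_utils import is_swift_available

if not is_swift_available():
    raise ImportError('Please install swift by `pip install ms-swift`.')
from swift import Swift, LoRAConfig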
|||||||
@@ -4,6 +4,7 @@ datasets>=2.8.0,<=2.13.0
|
|||||||
einops
|
einops
|
||||||
filelock>=3.3.0
|
filelock>=3.3.0
|
||||||
gast>=0.2.2
|
gast>=0.2.2
|
||||||
|
ms-swift
|
||||||
numpy
|
numpy
|
||||||
oss2
|
oss2
|
||||||
pandas
|
pandas
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ from modelscope.metainfo import Preprocessors, Trainers
|
|||||||
from modelscope.models import Model
|
from modelscope.models import Model
|
||||||
from modelscope.msdatasets import MsDataset
|
from modelscope.msdatasets import MsDataset
|
||||||
from modelscope.pipelines import pipeline
|
from modelscope.pipelines import pipeline
|
||||||
from modelscope.swift.optimizers.child_tuning_adamw_optimizer import \
|
|
||||||
calculate_fisher
|
|
||||||
from modelscope.trainers import build_trainer
|
from modelscope.trainers import build_trainer
|
||||||
from modelscope.trainers.hooks import Hook
|
from modelscope.trainers.hooks import Hook
|
||||||
from modelscope.trainers.nlp_trainer import (EpochBasedTrainer,
|
from modelscope.trainers.nlp_trainer import (EpochBasedTrainer,
|
||||||
NlpEpochBasedTrainer)
|
NlpEpochBasedTrainer)
|
||||||
|
from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \
|
||||||
|
calculate_fisher
|
||||||
from modelscope.trainers.training_args import TrainingArgs
|
from modelscope.trainers.training_args import TrainingArgs
|
||||||
from modelscope.utils.constant import ModelFile, Tasks
|
from modelscope.utils.constant import ModelFile, Tasks
|
||||||
from modelscope.utils.data_utils import to_device
|
from modelscope.utils.data_utils import to_device
|
||||||
|
|||||||
@@ -6,11 +6,8 @@ import unittest
 
 from modelscope.metainfo import Trainers
 from modelscope.msdatasets import MsDataset
-from modelscope.swift import Swift
-from modelscope.swift.adapter import AdapterConfig
-from modelscope.swift.lora import LoRAConfig
-from modelscope.swift.prompt import PromptConfig
 from modelscope.trainers import build_trainer
+from modelscope.utils.import_utils import is_swift_available
 from modelscope.utils.test_utils import test_level
 
 
@@ -43,8 +40,10 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase):
         shutil.rmtree(self.tmp_dir)
         super().tearDown()
 
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0 and is_swift_available(),
+                         'skip test in current test level')
     def test_vision_efficient_tuning_swift_lora_train(self):
+        from swift import LoRAConfig
         model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora'
 
         def cfg_modify_fn(cfg):
@@ -56,10 +55,9 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase):
             return cfg
 
         lora_config = LoRAConfig(
-            rank=self.tune_length,
-            replace_modules=['qkv'],
+            r=self.tune_length,
+            target_modules=['qkv'],
             merge_weights=False,
-            only_lora_trainable=False,
             use_merged_linear=True,
             enable_lora=[True])
 
@@ -69,7 +67,7 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             cfg_modify_fn=cfg_modify_fn,
-            efficient_tuners=[lora_config])
+            efficient_tuners=lora_config)
 
         trainer = build_trainer(
             name=Trainers.vision_efficient_tuning, default_args=kwargs)
@@ -82,8 +80,10 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase):
         for i in range(self.max_epochs):
             self.assertIn(f'epoch_{i+1}.pth', results_files)
 
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0 and is_swift_available(),
+                         'skip test in current test level')
     def test_vision_efficient_tuning_swift_adapter_train(self):
+        from swift import AdapterConfig
         model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter'
 
         def cfg_modify_fn(cfg):
@@ -97,9 +97,8 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase):
         adapter_config = AdapterConfig(
             dim=768,
             hidden_pos=0,
-            module_name=r'.*blocks\.\d+\.mlp$',
-            adapter_length=self.tune_length,
-            only_adapter_trainable=False)
+            target_modules=r'.*blocks\.\d+\.mlp$',
+            adapter_length=self.tune_length)
 
         kwargs = dict(
             model=model_id,
@@ -107,7 +106,7 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             cfg_modify_fn=cfg_modify_fn,
-            efficient_tuners=[adapter_config])
+            efficient_tuners=adapter_config)
 
         trainer = build_trainer(
             name=Trainers.vision_efficient_tuning, default_args=kwargs)
@@ -120,8 +119,10 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase):
         for i in range(self.max_epochs):
             self.assertIn(f'epoch_{i+1}.pth', results_files)
 
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0 and is_swift_available(),
+                         'skip test in current test level')
     def test_vision_efficient_tuning_swift_prompt_train(self):
+        from swift import PromptConfig
         model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt'
 
         def cfg_modify_fn(cfg):
@@ -134,10 +135,9 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase):
 
         prompt_config = PromptConfig(
             dim=768,
-            module_layer_name=r'.*blocks\.\d+$',
+            target_modules=r'.*blocks\.\d+$',
             embedding_pos=0,
             prompt_length=self.tune_length,
-            only_prompt_trainable=False,
             attach_front=False)
 
         kwargs = dict(
@@ -146,7 +146,7 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             cfg_modify_fn=cfg_modify_fn,
-            efficient_tuners=[prompt_config])
+            efficient_tuners=prompt_config)
 
         trainer = build_trainer(
             name=Trainers.vision_efficient_tuning, default_args=kwargs)
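Editor's note: taken together, these hunks migrate the three tuner configs from modelscope.swift to the swift package. A consolidated sketch follows; tune_length is a placeholder for the value set in the test's setUp, which is not shown in this diff.

from swift import AdapterConfig, LoRAConfig, PromptConfig

tune_length = 16  # placeholder for self.tune_length

lora_config = LoRAConfig(
    r=tune_length, target_modules=['qkv'], merge_weights=False,
    use_merged_linear=True, enable_lora=[True])
adapter_config = AdapterConfig(
    dim=768, hidden_pos=0,
    target_modules=r'.*blocks\.\d+\.mlp$', adapter_length=tune_length)
prompt_config = PromptConfig(
    dim=768, target_modules=r'.*blocks\.\d+$', embedding_pos=0,
    prompt_length=tune_length, attach_front=False)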
@@ -1,81 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

import numpy as np
import torch

from modelscope import read_config
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.base import Model
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.swift import Swift
from modelscope.swift.adapter import AdapterConfig
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class TestAdapter(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip in this level')
    def test_adapter_smoke_test(self):
        dataset = MsDataset.load(
            'clue', subset_name='afqmc',
            split='train').to_hf_dataset().select(range(2))

        model_dir = snapshot_download(
            'damo/nlp_structbert_sentence-similarity_chinese-tiny')
        model = Model.from_pretrained(model_dir, adv_grad_factor=None)

        cfg_file = os.path.join(model_dir, 'configuration.json')

        model_cfg = os.path.join(model_dir, 'config.json')
        model_cfg = read_config(model_cfg)

        adapter_config = AdapterConfig(
            dim=model_cfg.hidden_size,
            module_name=r'.*layer\.\d+$',
            method_name='feed_forward_chunk',
            hidden_pos=0)
        model = Swift.prepare_model(model, adapter_config)
        kwargs = dict(
            model=model,
            cfg_file=cfg_file,
            train_dataset=dataset,
            eval_dataset=dataset,
            work_dir=self.tmp_dir)

        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)

        def pipeline_sentence_similarity(model_dir):
            model = Model.from_pretrained(model_dir)
            adapter_config.pretrained_weights = output_dir
            Swift.prepare_model(model, adapter_config)
            model.eval()
            pipeline_ins = pipeline(
                task=Tasks.sentence_similarity, model=model)
            return pipeline_ins(input=('test', 'this is a test'))

        output1 = pipeline_sentence_similarity(
            'damo/nlp_structbert_sentence-similarity_chinese-tiny')
        print(output1)


if __name__ == '__main__':
    unittest.main()
@@ -1,118 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

import numpy as np
import torch

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.base import Model
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.swift import Swift
from modelscope.swift.lora import (Linear, LoRA, LoRAConfig,
                                   mark_only_lora_as_trainable)
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class TestLora(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip in this level')
    def test_lora_base(self):

        class TestModel(torch.nn.Module):

            def __init__(self):
                super().__init__()
                self.lora = Linear(16, 16, r=4)

        model = TestModel()
        mark_only_lora_as_trainable(model)
        model.train()
        loss = model.lora(torch.ones(16, 16))
        loss = loss.sum()
        loss.backward()

        model = TestModel()
        mark_only_lora_as_trainable(model)
        model.eval()
        loss = model.lora(torch.ones(16, 16))
        loss = loss.sum()
        try:
            loss.backward()
        except Exception:
            pass
        else:
            raise Exception('No tensor needs grad, should throw an error here')

    @unittest.skipUnless(test_level() >= 0, 'skip in this level')
    def test_lora_smoke_test(self):
        dataset = MsDataset.load(
            'clue', subset_name='afqmc',
            split='train').to_hf_dataset().select(range(2))

        model_dir = snapshot_download(
            'damo/nlp_structbert_sentence-similarity_chinese-tiny')
        model = Model.from_pretrained(model_dir, adv_grad_factor=None)

        cfg_file = os.path.join(model_dir, 'configuration.json')
        lora_config = LoRAConfig(replace_modules=['query', 'key', 'value'])
        model = Swift.prepare_model(model, lora_config)

        kwargs = dict(
            model=model,
            cfg_file=cfg_file,
            train_dataset=dataset,
            eval_dataset=dataset,
            work_dir=self.tmp_dir)

        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)

        def pipeline_sentence_similarity(model_dir):
            model = Model.from_pretrained(model_dir)
            lora_config.pretrained_weights = output_dir
            Swift.prepare_model(model, lora_config)
            model.load_state_dict(
                torch.load(os.path.join(output_dir, 'pytorch_model.bin')))
            model.eval()
            pipeline_ins = pipeline(
                task=Tasks.sentence_similarity, model=model)
            return pipeline_ins(input=('test', 'this is a test'))

        output1 = pipeline_sentence_similarity(
            'damo/nlp_structbert_sentence-similarity_chinese-tiny')

        LoRA.unpatch_lora(model, lora_config)
        model.save_pretrained(
            output_dir, save_checkpoint_names='pytorch_model.bin')

        def pipeline_sentence_similarity_origin():
            model = Model.from_pretrained(output_dir)
            model.eval()
            pipeline_ins = pipeline(
                task=Tasks.sentence_similarity, model=model)
            return pipeline_ins(input=('test', 'this is a test'))

        output2 = pipeline_sentence_similarity_origin()
        print(output1, output2)
        self.assertTrue(all(np.isclose(output1['scores'], output2['scores'])))


if __name__ == '__main__':
    unittest.main()
@@ -1,83 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

import numpy as np
import torch

from modelscope import read_config
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.base import Model
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.swift import Swift
from modelscope.swift.adapter import AdapterConfig
from modelscope.swift.prompt import PromptConfig
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class TestPrompt(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip in this level')
    def test_prompt_smoke_test(self):
        dataset = MsDataset.load(
            'clue', subset_name='afqmc',
            split='train').to_hf_dataset().select(range(2))

        model_dir = snapshot_download(
            'damo/nlp_structbert_sentence-similarity_chinese-tiny')
        model = Model.from_pretrained(model_dir, adv_grad_factor=None)

        cfg_file = os.path.join(model_dir, 'configuration.json')
        model_cfg = os.path.join(model_dir, 'config.json')
        model_cfg = read_config(model_cfg)

        prompt_config = PromptConfig(
            dim=model_cfg.hidden_size,
            module_layer_name=r'.*layer\.\d+$',
            embedding_pos=0,
            attention_mask_pos=1)

        model = Swift.prepare_model(model, prompt_config)

        kwargs = dict(
            model=model,
            cfg_file=cfg_file,
            train_dataset=dataset,
            eval_dataset=dataset,
            work_dir=self.tmp_dir)

        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)

        def pipeline_sentence_similarity(model_dir):
            model = Model.from_pretrained(model_dir)
            prompt_config.pretrained_weights = output_dir
            Swift.prepare_model(model, prompt_config)
            model.eval()
            pipeline_ins = pipeline(
                task=Tasks.sentence_similarity, model=model)
            return pipeline_ins(input=('test', 'this is a test'))

        output1 = pipeline_sentence_similarity(
            'damo/nlp_structbert_sentence-similarity_chinese-tiny')
        print(output1)


if __name__ == '__main__':
    unittest.main()