# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Copyright (c) Alibaba, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from typing import Dict, Optional

import torch
import tqdm
import transformers

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.utils.checkpoint import save_pretrained
from modelscope.utils.logger import get_logger

logger = get_logger()


def smart_tokenizer_and_embedding_resize(special_tokens_dict: Dict, tokenizer,
                                         model):
    """Resize the tokenizer and the embedding matrices.

    Note: this is the unoptimized version, which may leave the embedding size
    not divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Initialize the new rows with the mean of the pre-existing embeddings.
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
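
# A minimal usage sketch of the resize helper. The model id and the pad token
# below are illustrative placeholders, not values taken from this repo:
#   tokenizer = transformers.AutoTokenizer.from_pretrained('org/raw-model')
#   model = Model.from_pretrained('org/raw-model', device='cpu')
#   smart_tokenizer_and_embedding_resize(
#       special_tokens_dict={'pad_token': '[PAD]'},
#       tokenizer=tokenizer,
#       model=model)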


def make_same_shape(model_raw: Model, model_convert: Model, tokenizer_raw,
                    tokenizer_convert):
    if model_raw.__class__ != model_convert.__class__:
        logger.error(
            f'weight diff: These two models should be of the same class. model_raw: '
            f'{model_raw.__class__} vs model_convert: {model_convert.__class__}.')

    # Collect special tokens that exist in the converted tokenizer but not in
    # the raw one, so that both embedding matrices end up the same shape.
    special_tokens = {}
    for k, v in tokenizer_convert.special_tokens_map_extended.items():
        if k not in tokenizer_raw.special_tokens_map_extended:
            special_tokens[k] = v

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens,
        model=model_raw,
        tokenizer=tokenizer_raw,
    )

    state_dict_tuned = model_convert.state_dict()
    state_dict_raw = model_raw.state_dict()
    for key in tqdm.tqdm(state_dict_tuned):
        if state_dict_tuned[key].shape != state_dict_raw[key].shape:
            logger.error(
                f'weight diff: shape mismatch. {key}, model_raw shape: {state_dict_raw[key].shape}'
                f' vs model_convert shape: {state_dict_tuned[key].shape}.')
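
# Illustrative invariant (hypothetical state dicts): after make_same_shape,
# every parameter pair lines up elementwise, so the in-place add/subtract in
# _weight_diff below is well defined:
#   assert all(state_dict_tuned[k].shape == state_dict_raw[k].shape
#              for k in state_dict_tuned)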


def _weight_diff(model_raw,
                 model_convert,
                 tokenizer_raw,
                 tokenizer_convert,
                 path_to_save=None,
                 make_diff_or_recover='diff'):
    make_same_shape(model_raw, model_convert, tokenizer_raw, tokenizer_convert)

    state_dict_raw = model_raw.state_dict()
    state_dict_convert = model_convert.state_dict()
    if make_diff_or_recover == 'diff':
        # diff = tuned - raw, computed in place on the converted state dict.
        for key in tqdm.tqdm(state_dict_convert):
            state_dict_convert[key].add_(-state_dict_raw[key])
    elif make_diff_or_recover == 'recover':
        # recovered = diff + raw, the inverse of the operation above.
        for key in tqdm.tqdm(state_dict_convert):
            state_dict_convert[key].add_(state_dict_raw[key])

    if path_to_save:
        model_convert.save_pretrained(path_to_save, 'pytorch_model.bin')
        tokenizer_convert.save_pretrained(path_to_save)

    return model_convert, tokenizer_convert
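
# Round-trip sketch (hypothetical objects, not executed here): making a diff
# and then recovering from it reproduces the tuned weights, since
# (tuned - raw) + raw == tuned:
#   diff_model, _ = _weight_diff(raw, tuned, tok_raw, tok_tuned,
#                                make_diff_or_recover='diff')
#   recovered, _ = _weight_diff(raw, diff_model, tok_raw, tok_tuned,
#                               make_diff_or_recover='recover')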


@torch.inference_mode()
def weight_diff(path_raw: str,
                path_convert: str,
                path_to_save: str,
                make_diff_or_recover,
                device='cpu'):
    """Make the weight diff.

    This function is provided for full transparency about how the weight diff
    was created.
    """
    if not os.path.exists(path_raw):
        logger.info(
            f'Path `{path_raw}` not found. Trying to load from cache or remote.')
        path_raw = snapshot_download(path_raw)
    if not os.path.exists(path_convert):
        logger.info(
            f'Path `{path_convert}` not found. Trying to load from cache or remote.')
        path_convert = snapshot_download(path_convert)

    model_raw = Model.from_pretrained(path_raw, device=device)
    model_convert = Model.from_pretrained(path_convert, device=device)

    tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_raw)
    tokenizer_convert: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_convert)

    return _weight_diff(
        model_raw,
        model_convert,
        tokenizer_raw,
        tokenizer_convert,
        path_to_save=path_to_save,
        make_diff_or_recover=make_diff_or_recover)
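
# Programmatic usage sketch (model ids and paths are illustrative placeholders):
#   model, tokenizer = weight_diff('org/raw-model', 'org/diff-model',
#                                  './recovered', 'recover')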


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description=
        'Make the weight diff between the raw model and the tuned model, or '
        'recover tuned weights from the released weight diff.')

    parser.add_argument(
        'make_diff_or_recover',
        choices=['diff', 'recover'],
        help='mode selection: make the weight diff, or recover weights from '
        'the weight diff.')
    parser.add_argument(
        'path_raw', type=str, help='path to the raw pretrained model.')
    parser.add_argument(
        'path_convert',
        type=str,
        help='path to the tuned model in mode `diff`, or path to the diff '
        'model in mode `recover`.')
    parser.add_argument(
        'path_to_save',
        type=str,
        help='path to save the diff or recovered output files.')
    args = parser.parse_args()

    weight_diff(args.path_raw, args.path_convert, args.path_to_save,
                args.make_diff_or_recover)
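
# Example invocations (the script file name and the bracketed paths are
# placeholders; the positional argument order follows the parser above):
#   python weight_diff.py diff <path_raw> <path_tuned> <diff_output_dir>
#   python weight_diff.py recover <path_raw> <diff_output_dir> <recovered_dir>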