# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Copyright (c) Alibaba, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from typing import Dict, Optional

import torch
import tqdm
import transformers

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.utils.checkpoint import save_pretrained
from modelscope.utils.logger import get_logger

logger = get_logger()


def smart_tokenizer_and_embedding_resize(special_tokens_dict: Dict, tokenizer,
                                         model):
    """Resize the tokenizer and the model embeddings.

    Note: this is the unoptimized version, which may make your embedding size
    not divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Initialize the rows for the newly added tokens with the mean of the
        # pre-existing embeddings.
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
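
# A minimal sketch of how this helper is typically called (the token value
# below is an illustrative assumption, not something this script adds on its
# own):
#
#   special_tokens = {'pad_token': '[PAD]'}
#   smart_tokenizer_and_embedding_resize(
#       special_tokens_dict=special_tokens,
#       tokenizer=tokenizer_raw,
#       model=model_raw)
#
# The new rows of both embedding matrices are set to the mean of the existing
# rows, which keeps the added tokens close to the model's embedding
# distribution.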


def make_same_shape(model_raw: Model, model_convert: Model, tokenizer_raw,
                    tokenizer_convert):
    """Align the raw model with the tuned model.

    Special tokens that exist only in the tuned tokenizer are added to the raw
    tokenizer and the raw embeddings are resized, so the two state dicts can be
    added or subtracted key by key.
    """
    if model_raw.__class__ != model_convert.__class__:
        logger.error(
            f'weight diff: These two models should be of the same class. model_raw: '
            f'{model_raw.__class__} vs model_convert: {model_convert.__class__}.'
        )

    special_tokens = {}
    for k, v in tokenizer_convert.special_tokens_map_extended.items():
        if k not in tokenizer_raw.special_tokens_map_extended:
            special_tokens[k] = v
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens,
        model=model_raw,
        tokenizer=tokenizer_raw,
    )

    state_dict_tuned = model_convert.state_dict()
    state_dict_raw = model_raw.state_dict()
    for key in tqdm.tqdm(state_dict_tuned):
        if state_dict_tuned[key].shape != state_dict_raw[key].shape:
            logger.error(
                f'weight diff: shape mismatch. {key}, model_raw shape: {state_dict_raw[key].shape}'
                f' vs model_convert shape: {state_dict_tuned[key].shape}.')


def _weight_diff(model_raw,
                 model_convert,
                 tokenizer_raw,
                 tokenizer_convert,
                 path_to_save=None,
                 make_diff_or_recover='diff'):
    make_same_shape(model_raw, model_convert, tokenizer_raw, tokenizer_convert)
    state_dict_raw = model_raw.state_dict()
    state_dict_convert = model_convert.state_dict()

    if make_diff_or_recover == 'diff':
        for key in tqdm.tqdm(state_dict_convert):
            state_dict_convert[key].add_(-state_dict_raw[key])
    elif make_diff_or_recover == 'recover':
        for key in tqdm.tqdm(state_dict_convert):
            state_dict_convert[key].add_(state_dict_raw[key])

    if path_to_save:
        model_convert.save_pretrained(path_to_save, 'pytorch_model.bin')
        tokenizer_convert.save_pretrained(path_to_save)
    return model_convert, tokenizer_convert
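
# A sketch of the per-tensor arithmetic above, written with hypothetical
# tensors for clarity:
#
#   diff  = tuned - raw    # mode 'diff'
#   tuned = diff + raw     # mode 'recover'
#
# So running 'recover' on a diff produced by 'diff' against the same raw
# checkpoint reproduces the tuned weights, up to the floating-point rounding
# of the in-place add_ calls.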


@torch.inference_mode()
def weight_diff(path_raw: str,
                path_convert: str,
                path_to_save: str,
                make_diff_or_recover,
                device='cpu'):
    """Make the weight diff or recover the tuned weights.

    This function is provided for full transparency of how the weight diff
    was created.
    """
    if not os.path.exists(path_raw):
        logger.info(
            f'Path `{path_raw}` not found. Try to load from cache or remote.')
        path_raw = snapshot_download(path_raw)
    if not os.path.exists(path_convert):
        logger.info(
            f'Path `{path_convert}` not found. Try to load from cache or remote.'
        )
        path_convert = snapshot_download(path_convert)

    model_raw = Model.from_pretrained(path_raw, device=device)
    model_convert = Model.from_pretrained(path_convert, device=device)
    tokenizer_raw: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_raw)
    tokenizer_convert: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
        path_convert)

    return _weight_diff(
        model_raw,
        model_convert,
        tokenizer_raw,
        tokenizer_convert,
        path_to_save=path_to_save,
        make_diff_or_recover=make_diff_or_recover)
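
# Programmatic usage sketch (the model ID and paths below are placeholders,
# not identifiers this repository guarantees to exist):
#
#   model_diff, tokenizer_diff = weight_diff(
#       path_raw='damo/raw_base_model',
#       path_convert='./tuned_model',
#       path_to_save='./weight_diff_out',
#       make_diff_or_recover='diff')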


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description=
        'Make the weight diff between the raw model and the tuned model, or '
        'recover the tuned weights from the released weight diff.')
    parser.add_argument(
        'make_diff_or_recover',
        choices=['diff', 'recover'],
        help=
        'mode selection: make the weight diff, or recover weights from the weight diff.'
    )
    parser.add_argument(
        'path_raw', type=str, help='path to the raw pretrained model.')
    parser.add_argument(
        'path_convert',
        type=str,
        help=
        'path to the tuned model in mode `diff`, or path to the diff model in mode `recover`.'
    )
    parser.add_argument(
        'path_to_save',
        type=str,
        help='path to save the diff or recovered output files.')
    args = parser.parse_args()
    weight_diff(args.path_raw, args.path_convert, args.path_to_save,
                args.make_diff_or_recover)
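
# Example command lines, assuming the script is invoked from the repository
# root (all paths are placeholders):
#
#   python tools/weight_diff.py diff <path_raw> <path_tuned> <path_to_save_diff>
#   python tools/weight_diff.py recover <path_raw> <path_diff> <path_to_save_tuned>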