modelscope/examples/pytorch/llm/utils/models.py

import os
from types import MethodType
from typing import Any, Dict, NamedTuple, Optional

import torch
from torch import dtype as Dtype

from swift import get_logger
from modelscope import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, Model,
                        read_config, snapshot_download)
from modelscope.models.nlp.chatglm2 import ChatGLM2Config, ChatGLM2Tokenizer
from modelscope.models.nlp.llama2 import Llama2Config, Llama2Tokenizer

logger = get_logger()


def _add_special_token(tokenizer,
                       special_token_mapper: Dict[str, Any]) -> None:
    for k, v in special_token_mapper.items():
        setattr(tokenizer, k, v)
    assert tokenizer.eos_token is not None
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token


def get_model_tokenizer_from_repo(model_dir: str,
                                  torch_dtype: Dtype,
                                  load_model: bool = True,
                                  model_config=None,
                                  **model_kwargs):
    """Load a model and tokenizer from a standalone model repository."""
    if model_config is None:
        model_config = AutoConfig.from_pretrained(
            model_dir, trust_remote_code=True)
    model_config.torch_dtype = torch_dtype
    logger.info(f'model_config: {model_config}')
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir, trust_remote_code=True)
    model = None
    if load_model:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            config=model_config,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
            **model_kwargs)
    return model, tokenizer


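# `model_kwargs` above is forwarded verbatim to
# `AutoModelForCausalLM.from_pretrained`, so the usual transformers loading
# options apply. A minimal sketch (the local path is hypothetical):
#
#     model, tokenizer = get_model_tokenizer_from_repo(
#         '/path/to/local/model', torch.float16, device_map='auto')

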
def get_model_tokenizer_from_sdk(config_class: type,
                                 tokenizer_class: type,
                                 model_dir: str,
                                 torch_dtype: Dtype,
                                 load_model: bool = True,
                                 model_config=None,
                                 **model_kwargs):
    """Load a model and tokenizer through the ModelScope SDK classes."""
    config = read_config(model_dir)
    logger.info(config)
    if model_config is None:
        model_config = config_class.from_pretrained(model_dir)
    model_config.torch_dtype = torch_dtype
    logger.info(model_config)
    tokenizer = tokenizer_class.from_pretrained(model_dir)
    model = None
    if load_model:
        model = Model.from_pretrained(
            model_dir,
            cfg_dict=config,
            config=model_config,
            torch_dtype=torch_dtype,
            **model_kwargs)
    return model, tokenizer


def get_model_tokenizer_baichuan13b(model_dir: str,
                                    torch_dtype: Dtype,
                                    load_model: bool = True,
                                    **model_kwargs):
    model, tokenizer = get_model_tokenizer_from_repo(model_dir, torch_dtype,
                                                     load_model,
                                                     **model_kwargs)
    # baichuan-13b does not implement the `get_input_embeddings` function;
    # patch it onto the instance. Guard against `load_model=False`, where
    # `model` is None and the original code would raise an AttributeError.
    if model is not None:
        model.get_input_embeddings = MethodType(
            lambda self: self.model.embed_tokens, model)
    return model, tokenizer


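# After the patch above, `model.get_input_embeddings()` returns the token
# embedding module (`model.model.embed_tokens`), which helpers such as PEFT's
# k-bit training preparation expect to exist (an assumption about the callers).

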
def get_model_tokenizer_chatglm2(model_dir: str,
                                 torch_dtype: Dtype,
                                 load_model: bool = True,
                                 **model_kwargs):
    if 'quantization_config' in model_kwargs:
        # Keep chatglm2's output projection (`output_layer`) out of int8
        # quantization.
        model_kwargs['quantization_config'].llm_int8_skip_modules = [
            'output_layer'
        ]
    return get_model_tokenizer_from_sdk(ChatGLM2Config, ChatGLM2Tokenizer,
                                        model_dir, torch_dtype, load_model,
                                        **model_kwargs)


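# A minimal sketch of passing an int8 config through `model_kwargs`
# (`BitsAndBytesConfig` comes from `transformers`; whether the surrounding
# training script actually does this is an assumption here):
#
#     from transformers import BitsAndBytesConfig
#     model, tokenizer = get_model_tokenizer_chatglm2(
#         model_dir, torch.float16,
#         quantization_config=BitsAndBytesConfig(load_in_8bit=True))

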
def get_model_tokenizer_llama2(model_dir: str,
                               torch_dtype: Dtype,
                               load_model: bool = True,
                               **model_kwargs):
    model_config = AutoConfig.from_pretrained(
        model_dir, trust_remote_code=True)
    # `pretraining_tp > 1` selects the sliced (tensor-parallel-style) linear
    # forward in the HF llama implementation; force the standard path.
    model_config.pretraining_tp = 1
    return get_model_tokenizer_from_sdk(Llama2Config, Llama2Tokenizer,
                                        model_dir, torch_dtype, load_model,
                                        model_config, **model_kwargs)


def get_model_tokenizer_qwen(model_dir: str,
                             torch_dtype: Dtype,
                             load_model: bool = True,
                             **kwargs):
    model_config = AutoConfig.from_pretrained(
        model_dir, trust_remote_code=True)
    # qwen's config selects the compute dtype via boolean flags; enable
    # exactly the one matching `torch_dtype`.
    mapper = {
        torch.float16: 'fp16',
        torch.bfloat16: 'bf16',
        torch.float32: 'fp32'
    }
    k_true = mapper[torch_dtype]
    for k in mapper.values():
        setattr(model_config, k, k == k_true)
    model_config.use_flash_attn = kwargs.pop('use_flash_attn', 'auto')
    return get_model_tokenizer_from_repo(model_dir, torch_dtype, load_model,
                                         model_config, **kwargs)


class LoRATM(NamedTuple):
    # default lora target modules
    baichuan = ['W_pack']
    chatglm2 = ['query_key_value']
    llama2 = ['q_proj', 'k_proj', 'v_proj']
    qwen = ['c_attn']
    polylm = ['c_attn']


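# These lists are intended as LoRA `target_modules`; a sketch with peft (an
# assumption -- the consuming training script lives outside this file):
#
#     from peft import LoraConfig
#     lora_config = LoraConfig(r=8, target_modules=LoRATM.llama2)

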
# Reference: 'https://modelscope.cn/models/{model_id}/summary'
# keys: 'model_id', 'revision', 'get_function',
#       'ignore_file_pattern', 'special_token_mapper', 'lora_TM'
MODEL_MAPPING = {
    'baichuan-7b': {
        'model_id': 'baichuan-inc/baichuan-7B',  # model id or model dir
        'revision': 'v1.0.7',
        'lora_TM': LoRATM.baichuan
    },
    'baichuan-13b': {
        'model_id': 'baichuan-inc/Baichuan-13B-Base',
        'revision': 'v1.0.3',
        'get_function': get_model_tokenizer_baichuan13b,
        'lora_TM': LoRATM.baichuan
    },
    'chatglm2-6b': {
        'model_id': 'ZhipuAI/chatglm2-6b',
        'revision': 'v1.0.7',
        'get_function': get_model_tokenizer_chatglm2,
        'lora_TM': LoRATM.chatglm2
    },
    'llama2-7b': {
        'model_id': 'modelscope/Llama-2-7b-ms',
        'revision': 'v1.0.2',
        'get_function': get_model_tokenizer_llama2,
        'ignore_file_pattern': [r'.+\.bin$'],  # use safetensors
        'lora_TM': LoRATM.llama2
    },
    'llama2-13b': {
        'model_id': 'modelscope/Llama-2-13b-ms',
        'revision': 'v1.0.2',
        'get_function': get_model_tokenizer_llama2,
        'ignore_file_pattern': [r'.+\.bin$'],
        'lora_TM': LoRATM.llama2
    },
    'llama2-70b': {
        'model_id': 'modelscope/Llama-2-70b-ms',
        'revision': 'v1.0.0',
        'get_function': get_model_tokenizer_llama2,
        'ignore_file_pattern': [r'.+\.bin$'],
        'lora_TM': LoRATM.llama2
    },
    'openbuddy-llama2-13b': {
        'model_id': 'OpenBuddy/openbuddy-llama2-13b-v8.1-fp16',
        'revision': 'v1.0.0',
        'lora_TM': LoRATM.llama2,
    },
    'qwen-7b': {
        'model_id': 'qwen/Qwen-7B',
        'revision': 'v.1.0.4',
        'get_function': get_model_tokenizer_qwen,
        'lora_TM': LoRATM.qwen,
        'special_token_mapper': {
            'eos_token': '<|endoftext|>'
        }
    }
}


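# Registering an additional model is just another entry here; a sketch with a
# hypothetical model id (omitted keys fall back to the defaults applied in
# `get_model_tokenizer` below):
#
#     MODEL_MAPPING['my-llama2-7b'] = {
#         'model_id': 'my-org/my-llama2-7b',  # hypothetical
#         'get_function': get_model_tokenizer_llama2,
#         'lora_TM': LoRATM.llama2,
#     }

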
def get_model_tokenizer(model_type: str,
                        torch_dtype: Optional[Dtype] = None,
                        load_model: bool = True,
                        **kwargs):
    data = MODEL_MAPPING.get(model_type)
    if data is None:
        raise ValueError(
            f'model_type: {model_type} not found in MODEL_MAPPING')
    model_id = data['model_id']
    get_function = data.get('get_function', get_model_tokenizer_from_repo)
    ignore_file_pattern = data.get('ignore_file_pattern', [])
    special_token_mapper = data.get('special_token_mapper', {})
    if torch_dtype is None:
        torch_dtype = data.get('torch_dtype', torch.float16)
    model_dir = kwargs.pop('model_dir', None)
    if model_dir is None:
        model_dir = model_id
        # `model_id` may also be a local directory; only download when it
        # does not exist on disk.
        if not os.path.exists(model_id):
            revision = data.get('revision', 'master')
            model_dir = snapshot_download(
                model_id, revision, ignore_file_pattern=ignore_file_pattern)
    model, tokenizer = get_function(model_dir, torch_dtype, load_model,
                                    **kwargs)
    _add_special_token(tokenizer, special_token_mapper)
    return model, tokenizer, model_dir
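

if __name__ == '__main__':
    # Minimal usage sketch: resolve the 'qwen-7b' entry, download (or reuse a
    # cached) snapshot, and build only the tokenizer. `load_model=False`
    # skips the weights; tokenizer/config files are still fetched on first run.
    _, tokenizer, model_dir = get_model_tokenizer('qwen-7b', load_model=False)
    logger.info(f'model_dir: {model_dir}')
    logger.info(tokenizer)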