2023-07-29 00:06:27 +08:00
|
|
|
import os
|
|
|
|
|
from typing import Any, Dict, NamedTuple, Optional
|
2023-07-26 18:12:55 +08:00
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from torch import dtype as Dtype
|
|
|
|
|
|
|
|
|
|
from modelscope import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, Model,
|
|
|
|
|
get_logger, read_config, snapshot_download)
|
|
|
|
|
from modelscope.models.nlp.chatglm2 import ChatGLM2Config, ChatGLM2Tokenizer
|
2023-08-02 09:25:21 +08:00
|
|
|
from modelscope.models.nlp.qwen import QWenConfig, QWenTokenizer
|
2023-07-26 18:12:55 +08:00
|
|
|
|
|
|
|
|
logger = get_logger()
|
|
|
|
|
|
|
|
|
|
|
2023-07-29 00:06:27 +08:00
|
|
|
def _add_special_token(tokenizer, special_token_mapper: Dict[str,
|
|
|
|
|
Any]) -> None:
|
|
|
|
|
for k, v in special_token_mapper:
|
|
|
|
|
setattr(tokenizer, k, v)
|
|
|
|
|
assert tokenizer.eos_token is not None
|
|
|
|
|
if tokenizer.pad_token is None:
|
|
|
|
|
tokenizer.pad_token = tokenizer.eos_token
|
2023-07-26 18:12:55 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model_tokenizer_default(model_dir: str,
|
2023-07-29 00:06:27 +08:00
|
|
|
torch_dtype: Dtype,
|
|
|
|
|
load_model: bool = True):
|
2023-07-26 18:12:55 +08:00
|
|
|
"""load from an independent repository"""
|
|
|
|
|
model_config = AutoConfig.from_pretrained(
|
|
|
|
|
model_dir, trust_remote_code=True)
|
|
|
|
|
model_config.torch_dtype = torch_dtype
|
|
|
|
|
logger.info(f'model_config: {model_config}')
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
|
model_dir, trust_remote_code=True)
|
|
|
|
|
model = None
|
|
|
|
|
if load_model:
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
model_dir,
|
|
|
|
|
config=model_config,
|
|
|
|
|
device_map='auto',
|
|
|
|
|
torch_dtype=torch_dtype,
|
|
|
|
|
trust_remote_code=True)
|
|
|
|
|
return model, tokenizer
|
|
|
|
|
|
|
|
|
|
|
2023-08-10 20:57:47 +08:00
|
|
|
def get_model_tokenizer_polylm(model_dir: str,
|
|
|
|
|
torch_dtype: Dtype,
|
|
|
|
|
load_model: bool = True):
|
|
|
|
|
"""load from an independent repository"""
|
|
|
|
|
model_config = AutoConfig.from_pretrained(
|
|
|
|
|
model_dir, trust_remote_code=True)
|
|
|
|
|
model_config.torch_dtype = torch_dtype
|
|
|
|
|
logger.info(f'model_config: {model_config}')
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
|
|
|
|
|
model = None
|
|
|
|
|
if load_model:
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
model_dir,
|
|
|
|
|
config=model_config,
|
|
|
|
|
device_map='auto',
|
|
|
|
|
torch_dtype=torch_dtype,
|
|
|
|
|
trust_remote_code=True)
|
|
|
|
|
return model, tokenizer
|
|
|
|
|
|
|
|
|
|
|
2023-07-26 18:12:55 +08:00
|
|
|
def get_model_tokenizer_chatglm2(model_dir: str,
|
2023-07-29 00:06:27 +08:00
|
|
|
torch_dtype: Dtype,
|
|
|
|
|
load_model: bool = True):
|
2023-07-26 18:12:55 +08:00
|
|
|
"""load from ms library"""
|
|
|
|
|
config = read_config(model_dir)
|
|
|
|
|
logger.info(config)
|
|
|
|
|
model_config = ChatGLM2Config.from_pretrained(model_dir)
|
|
|
|
|
model_config.torch_dtype = torch_dtype
|
|
|
|
|
logger.info(model_config)
|
|
|
|
|
tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir)
|
|
|
|
|
model = None
|
|
|
|
|
if load_model:
|
|
|
|
|
model = Model.from_pretrained(
|
|
|
|
|
model_dir,
|
|
|
|
|
cfg_dict=config,
|
|
|
|
|
config=model_config,
|
|
|
|
|
device_map='auto',
|
|
|
|
|
torch_dtype=torch_dtype)
|
|
|
|
|
return model, tokenizer
|
|
|
|
|
|
|
|
|
|
|
2023-08-02 09:25:21 +08:00
|
|
|
def get_model_tokenizer_qwen(model_dir: str,
|
|
|
|
|
torch_dtype: Dtype,
|
|
|
|
|
load_model: bool = True):
|
|
|
|
|
config = read_config(model_dir)
|
|
|
|
|
logger.info(config)
|
|
|
|
|
model_config = QWenConfig.from_pretrained(model_dir)
|
|
|
|
|
model_config.torch_dtype = torch_dtype
|
|
|
|
|
logger.info(model_config)
|
|
|
|
|
tokenizer = QWenTokenizer.from_pretrained(model_dir)
|
|
|
|
|
model = None
|
|
|
|
|
if load_model:
|
|
|
|
|
model = Model.from_pretrained(
|
|
|
|
|
model_dir,
|
|
|
|
|
cfg_dict=config,
|
|
|
|
|
config=model_config,
|
|
|
|
|
device_map='auto',
|
|
|
|
|
torch_dtype=torch_dtype)
|
|
|
|
|
return model, tokenizer
|
|
|
|
|
|
|
|
|
|
|
2023-07-26 18:12:55 +08:00
|
|
|
class LoRATM(NamedTuple):
|
|
|
|
|
# default lora target modules
|
|
|
|
|
baichuan = ['W_pack']
|
|
|
|
|
chatglm2 = ['query_key_value']
|
|
|
|
|
llama2 = ['q_proj', 'k_proj', 'v_proj']
|
2023-08-02 09:25:21 +08:00
|
|
|
qwen = ['c_attn']
|
2023-08-10 20:57:47 +08:00
|
|
|
polylm = ['c_attn']
|
2023-07-26 18:12:55 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Reference: 'https://modelscope.cn/models/{model_id}/summary'
|
2023-07-29 00:06:27 +08:00
|
|
|
# keys: 'model_id', 'revision', 'torch_dtype', 'get_function',
|
|
|
|
|
# 'ignore_file_pattern', 'special_token_mapper', 'lora_TM'
|
2023-07-26 18:12:55 +08:00
|
|
|
MODEL_MAPPER = {
|
|
|
|
|
'baichuan-7b': {
|
2023-07-29 00:06:27 +08:00
|
|
|
'model_id': 'baichuan-inc/baichuan-7B', # model id or model dir
|
2023-07-26 18:12:55 +08:00
|
|
|
'revision': 'v1.0.7',
|
|
|
|
|
'lora_TM': LoRATM.baichuan
|
|
|
|
|
},
|
|
|
|
|
'baichuan-13b': {
|
|
|
|
|
'model_id': 'baichuan-inc/Baichuan-13B-Base',
|
|
|
|
|
'revision': 'v1.0.3',
|
2023-07-29 00:06:27 +08:00
|
|
|
'torch_dtype': torch.bfloat16,
|
2023-07-26 18:12:55 +08:00
|
|
|
'lora_TM': LoRATM.baichuan
|
|
|
|
|
},
|
2023-07-29 00:06:27 +08:00
|
|
|
'chatglm2-6b': {
|
2023-07-26 18:12:55 +08:00
|
|
|
'model_id': 'ZhipuAI/chatglm2-6b',
|
|
|
|
|
'revision': 'v1.0.6',
|
|
|
|
|
'get_function': get_model_tokenizer_chatglm2,
|
|
|
|
|
'lora_TM': LoRATM.chatglm2
|
|
|
|
|
},
|
|
|
|
|
'llama2-7b': {
|
|
|
|
|
'model_id': 'modelscope/Llama-2-7b-ms',
|
|
|
|
|
'revision': 'v1.0.2',
|
|
|
|
|
'ignore_file_pattern': [r'.+\.bin$'], # use safetensors
|
|
|
|
|
'lora_TM': LoRATM.llama2
|
|
|
|
|
},
|
|
|
|
|
'llama2-13b': {
|
|
|
|
|
'model_id': 'modelscope/Llama-2-13b-ms',
|
|
|
|
|
'revision': 'v1.0.2',
|
|
|
|
|
'ignore_file_pattern': [r'.+\.bin$'],
|
|
|
|
|
'lora_TM': LoRATM.llama2
|
|
|
|
|
},
|
|
|
|
|
'openbuddy-llama2-13b': {
|
|
|
|
|
'model_id': 'OpenBuddy/openbuddy-llama2-13b-v8.1-fp16',
|
2023-08-02 09:25:21 +08:00
|
|
|
'revision': 'v1.0.0',
|
2023-07-26 18:12:55 +08:00
|
|
|
'lora_TM': LoRATM.llama2
|
2023-08-02 09:25:21 +08:00
|
|
|
},
|
|
|
|
|
'qwen-7b': {
|
|
|
|
|
'model_id': 'QWen/qwen-7b',
|
|
|
|
|
'revision': 'v1.0.0',
|
|
|
|
|
'get_function': get_model_tokenizer_qwen,
|
|
|
|
|
'torch_dtype': torch.bfloat16,
|
|
|
|
|
'lora_TM': LoRATM.qwen,
|
2023-08-10 20:57:47 +08:00
|
|
|
},
|
|
|
|
|
'polylm-13b': {
|
|
|
|
|
'model_id': 'damo/nlp_polylm_13b_text_generation',
|
|
|
|
|
'revision': 'v1.0.3',
|
|
|
|
|
'get_function': get_model_tokenizer_polylm,
|
|
|
|
|
'torch_dtype': torch.bfloat16,
|
|
|
|
|
'lora_TM': LoRATM.polylm
|
2023-07-26 18:12:55 +08:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model_tokenizer(model_type: str,
|
2023-07-29 00:06:27 +08:00
|
|
|
torch_dtype: Optional[Dtype] = None,
|
|
|
|
|
load_model: bool = True):
|
2023-07-26 18:12:55 +08:00
|
|
|
data = MODEL_MAPPER.get(model_type)
|
|
|
|
|
if data is None:
|
|
|
|
|
raise ValueError(f'model_type: {model_type}')
|
2023-07-29 00:06:27 +08:00
|
|
|
|
2023-07-26 18:12:55 +08:00
|
|
|
model_id = data['model_id']
|
|
|
|
|
get_function = data.get('get_function', get_model_tokenizer_default)
|
|
|
|
|
ignore_file_pattern = data.get('ignore_file_pattern', [])
|
2023-07-29 00:06:27 +08:00
|
|
|
special_token_mapper = data.get('special_token_mapper', {})
|
|
|
|
|
if torch_dtype is None:
|
|
|
|
|
torch_dtype = data.get('torch_dtype', torch.float16)
|
|
|
|
|
|
|
|
|
|
model_dir = model_id
|
|
|
|
|
if not os.path.exists(model_id):
|
|
|
|
|
revision = data.get('revision', 'master')
|
|
|
|
|
model_dir = snapshot_download(
|
|
|
|
|
model_id, revision, ignore_file_pattern=ignore_file_pattern)
|
|
|
|
|
|
|
|
|
|
model, tokenizer = get_function(model_dir, torch_dtype, load_model)
|
|
|
|
|
_add_special_token(tokenizer, special_token_mapper)
|
2023-07-26 18:12:55 +08:00
|
|
|
return model, tokenizer, model_dir
|