Fix/chatglm2 (#384)

Jintao
2023-07-15 09:59:53 +08:00
committed by GitHub
parent 442bdc74a4
commit c6189d68a0
11 changed files with 190 additions and 229 deletions

@@ -49,11 +49,9 @@ from modelscope.utils.config import Config, ConfigDict
 from modelscope.utils.registry import default_group
 #
-SYSTEM_TEXT = """{system}"""
-USER_TEXT = """\n\n### 用户
-{user}"""
-ASSISTANT_PROMPT = """\n\n### 助手
-"""
+PROMPT = """System: {system}
+Human: {user}
+AI: """
 MAX_LENGTH = 2048
 TEST_MAX_LENGTH = MAX_LENGTH
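
Note: the three-part template (SYSTEM_TEXT / USER_TEXT / ASSISTANT_PROMPT) is collapsed into a single PROMPT. A minimal sketch of how the new template renders; the example values are illustrative, not from the commit:

```python
PROMPT = """System: {system}
Human: {user}
AI: """

# Illustrative values only.
text = PROMPT.format(system='You are a helpful assistant.',
                     user='What is 1 + 1?')
# text ==
# 'System: You are a helpful assistant.\nHuman: What is 1 + 1?\nAI: '
```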
@@ -62,11 +60,6 @@ logger = get_logger()
 #
-def get_model_dir(model_id: str, model_revision: Optional[str] = None) -> str:
-    model_dir = snapshot_download(model_id, model_revision)
-    return model_dir
 def _get_version(work_dir: str) -> int:
     if os.path.isdir(work_dir):
         fnames = os.listdir(work_dir)
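
The removed get_model_dir was a one-line wrapper around snapshot_download, so call sites are expected to call it directly now. A sketch of the replacement pattern, assuming the usual modelscope import (the model id is copied from the chatglm2 default removed later in this diff):

```python
from modelscope.hub.snapshot_download import snapshot_download

# Same behavior as the deleted helper: resolve a model id
# (optionally pinned to a revision) to a local directory.
model_dir = snapshot_download('ZhipuAI/chatglm2-6b', None)
```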
@@ -93,28 +86,40 @@ def get_work_dir(work_dir: str) -> str:
     return work_dir
-def select_device(device_ids: List[int]) -> Device:
+def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
+    if isinstance(device, list):
+        device_ids = device
+        device_str = ','.join([str(d) for d in device])
+    else:
+        device_ids = [int(d) for d in device.split(',') if d != '-1']
+        device_str = device
+    device_str = device_str.replace(' ', '')
+    return device_ids, device_str
+def select_device(device: Union[List[int], str]) -> Device:
     """Call this function before cuda is initialized.
     Return: master device
     device: e.g. []: 'cpu', [0], [0, 1, 2]
+        e.g. '-1': 'cpu', '0', '0,1,2'
     """
+    if torch.cuda.is_initialized():
+        logger.warning('CUDA has been initialized! Device selection fails!')
+        return torch.device('cuda:0')
+    #
+    device_ids, device_str = _format_device(device)
+    #
+    os.environ['CUDA_VISIBLE_DEVICES'] = device_str
     log_s = 'Using device: '
-    if len(device_ids) == 0:  # cpu
-        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
-        device: str = 'cpu'
-        log_s += device
+    if len(device_ids) == 0:
+        master_device: str = 'cpu'
+        log_s += 'cpu'
     else:
-        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
-            [str(d) for d in device_ids])
         assert torch.cuda.is_available(
         ) and torch.cuda.device_count() >= len(device_ids)
-        log_s += f"cuda:{','.join([str(d) for d in device_ids])}"  # e.g. 'cuda:1,7,8'
-        device = 'cuda:0'
+        master_device = 'cuda:0'
+        log_s += f'cuda:{device_str}'
     logger.info(log_s)
-    return torch.device(device)
+    return torch.device(master_device)
 def seed_everything(seed: Optional[int] = None, gpu_dtm: bool = False) -> int:
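
select_device now accepts either a list of device ids or a CUDA_VISIBLE_DEVICES-style string, with _format_device normalizing both to an (ids, str) pair. A hedged sketch of the new call patterns, following the signatures in this hunk:

```python
# Both forms normalize through _format_device:
select_device([0, 1])   # sets CUDA_VISIBLE_DEVICES='0,1', returns device('cuda:0')
select_device('0, 1')   # spaces stripped -> '0,1', same result
select_device([])       # cpu
select_device('-1')     # also cpu: '-1' entries are filtered out of device_ids
# The master device is always 'cuda:0' when GPUs are selected, because
# the visible GPUs are renumbered from zero once CUDA_VISIBLE_DEVICES is set.
```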
@@ -148,37 +153,27 @@ def get_T_max(dataset_len: int, batch_size: int, max_epochs: int,
 def tokenize_function(system: str, user: str, assistant: Optional[str],
                       tokenizer) -> Dict[str, Any]:
     """Only applicable to baichuan and chatglm2. Other models need to be tested"""
-    system_text = SYSTEM_TEXT.format(system=system)
-    user_text = USER_TEXT.format(user=user)
-    system_text_ids: List[int] = tokenizer(
-        system_text, return_attention_mask=False,
+    src_text = PROMPT.format(system=system, user=user)
+    src_input_ids: List[int] = tokenizer(
+        src_text, return_attention_mask=False,
         add_special_tokens=True)['input_ids']
-    user_text_ids: List[int] = tokenizer(
-        user_text, return_attention_mask=False,
-        add_special_tokens=False)['input_ids']
-    assistant_p_input_ids: List[int] = tokenizer(
-        ASSISTANT_PROMPT,
-        return_attention_mask=False,
-        add_special_tokens=False)['input_ids']
-    # tokenizer.bos_token_id: Avoid `assistant` being empty
-    assistant_input_ids: List[int] = [tokenizer.bos_token_id]
     #
+    tgt_input_ids: List[int] = []
     if assistant is not None:
-        assistant_input_ids += tokenizer(
+        tgt_input_ids += tokenizer(
             assistant, return_attention_mask=False,
             add_special_tokens=False)['input_ids']
-        assistant_input_ids += [tokenizer.eos_token_id]
+        tgt_input_ids += [tokenizer.eos_token_id]
+        labels = [-100] * len(src_input_ids) + tgt_input_ids
+    else:
+        labels = None
+    input_ids = src_input_ids + tgt_input_ids
     #
-    input_ids = system_text_ids + user_text_ids + assistant_p_input_ids + assistant_input_ids
-    if assistant is not None:  # train, val
+    if assistant is not None:
         if len(input_ids) > MAX_LENGTH:
             return {}
-        len_mask = len(input_ids) - len(assistant_input_ids)
-        labels = [-100] * len_mask + assistant_input_ids
-    else:  # test
+    else:
         input_ids = input_ids[-TEST_MAX_LENGTH:]
-        labels = None
-    #
     return {'input_ids': input_ids, 'labels': labels}
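
The rewritten tokenize_function builds labels directly as prompt-masked targets: every src (prompt) position gets -100, and only the assistant reply plus eos is supervised. A schematic example with made-up token ids:

```python
# Made-up token ids, for illustration only.
src_input_ids = [5, 6, 7, 8]     # PROMPT.format(...) tokenized
tgt_input_ids = [40, 41, 2]      # assistant reply + eos_token_id

input_ids = src_input_ids + tgt_input_ids
labels = [-100] * len(src_input_ids) + tgt_input_ids
assert input_ids == [5, 6, 7, 8, 40, 41, 2]
assert labels == [-100, -100, -100, -100, 40, 41, 2]
# -100 is the ignore_index of torch's CrossEntropyLoss, so the loss
# is computed only on the reply tokens.
```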
@@ -305,12 +300,21 @@ class MyMetric(Metric):
         raise NotImplementedError
-def get_baichuan_model_tokenizer(model_dir: Optional[str] = None,
-                                 load_model: bool = True):
-    if model_dir is None:
-        model_id = 'baichuan-inc/baichuan-7B'
-        model_dir = get_model_dir(model_id, None)
-    #
+def _add_special_token(tokenizer):
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token_id = 2
+    if tokenizer.bos_token_id is None:
+        tokenizer.bos_token_id = 1
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = 0
+    logger.info(f'bos_token_id: {tokenizer.bos_token_id}, '
+                f'eos_token_id: {tokenizer.eos_token_id}, '
+                f'pad_token_id: {tokenizer.pad_token_id}')
+def get_baichuan7B_model_tokenizer(model_dir: str,
+                                   load_model: bool = True,
+                                   add_special_token: bool = True):
     sys.path.insert(0, model_dir)
     from configuration_baichuan import BaiChuanConfig
     from tokenization_baichuan import BaiChuanTokenizer
@@ -327,15 +331,14 @@ def get_baichuan_model_tokenizer(model_dir: Optional[str] = None,
             device_map='auto',
             torch_dtype=torch.float16)
     #
+    if add_special_token:
+        _add_special_token(tokenizer)
     return model, tokenizer
-def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
-                                 load_model: bool = True):
-    if model_dir is None:
-        model_id = 'ZhipuAI/chatglm2-6b'
-        model_dir = snapshot_download(model_id, None)
-    #
+def get_chatglm2_model_tokenizer(model_dir: str,
+                                 load_model: bool = True,
+                                 add_special_token: bool = True):
     config = read_config(model_dir)
     config['model'] = ConfigDict({'type': 'chatglm2-6b'})
     tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir)
@@ -346,6 +349,8 @@ def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
             cfg_dict=config,
             device_map='auto',
             torch_dtype=torch.float16)
+    if add_special_token:
+        _add_special_token(tokenizer)
     return model, tokenizer
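
Taken together, the loaders now require an explicit model_dir and patch missing special-token ids by default. A hedged sketch of the intended call sequence (the model id is copied from the removed default; token ids 1/2/0 come from _add_special_token above):

```python
# Sketch of the new usage, not a verbatim excerpt from the repo.
model_dir = snapshot_download('ZhipuAI/chatglm2-6b', None)
model, tokenizer = get_chatglm2_model_tokenizer(model_dir)
# add_special_token=True (the default) guarantees bos/eos/pad ids are set:
# bos_token_id=1, eos_token_id=2, pad_token_id=0 if the tokenizer lacks them.
```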