Fix/chatglm2 (#384)

Jintao committed 2023-07-15 09:59:53 +08:00 (committed by GitHub)
parent 442bdc74a4
commit c6189d68a0
11 changed files with 190 additions and 229 deletions


@@ -51,25 +51,15 @@ from modelscope.utils.config import Config, ConfigDict
 from modelscope.utils.registry import default_group

 #
-TEST_SPLIT_P = 0.01
-SPLIT_SEED = 42
-MAX_LENGTH: Optional[int] = 2048
 COLOR, COLOR_S = '#FFE2D9', '#FF7043'
-PROMPT = """### 用户
-{instruction}
-### AI助手
-"""
+PROMPT = """Human: {instruction}
+AI: """
 logger = get_logger()

 #
-def get_model_dir(model_id: str, model_revision: Optional[str] = None) -> str:
-    model_dir = snapshot_download(model_id, model_revision)
-    return model_dir
 def _get_version(work_dir: str) -> int:
     if os.path.isdir(work_dir):
         fnames = os.listdir(work_dir)
@@ -96,28 +86,40 @@ def get_work_dir(work_dir: str) -> str:
     return work_dir


-def select_device(device_ids: List[int]) -> Device:
+def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
+    if isinstance(device, list):
+        device_ids = device
+        device_str = ','.join([str(d) for d in device])
+    else:
+        device_ids = [int(d) for d in device.split(',') if d != '-1']
+        device_str = device
+    device_str = device_str.replace(' ', '')
+    return device_ids, device_str
+
+
+def select_device(device: Union[List[int], str]) -> Device:
     """Call this function before cuda is initialized.
+    Return: master device
+    device: e.g. []: 'cpu', [0], [0, 1, 2]
+        e.g. '-1': 'cpu', '0', '0,1,2'
     """
     if torch.cuda.is_initialized():
         logger.warning('CUDA has been initialized! Device selection fails!')
         return torch.device('cuda:0')
     #
+    device_ids, device_str = _format_device(device)
+    #
+    os.environ['CUDA_VISIBLE_DEVICES'] = device_str
     log_s = 'Using device: '
-    if len(device_ids) == 0:  # cpu
-        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
-        device: str = 'cpu'
-        log_s += device
+    if len(device_ids) == 0:
+        master_device: str = 'cpu'
+        log_s += 'cpu'
     else:
-        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
-            [str(d) for d in device_ids])
         assert torch.cuda.is_available(
         ) and torch.cuda.device_count() >= len(device_ids)
-        log_s += f"cuda:{','.join([str(d) for d in device_ids])}"  # e.g. 'cuda:1,7,8'
-        device = 'cuda:0'
+        master_device = 'cuda:0'
+        log_s += f'cuda:{device_str}'
     logger.info(log_s)
-    return torch.device(device)
+    return torch.device(master_device)


 def seed_everything(seed: Optional[int] = None, gpu_dtm: bool = False) -> int:
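For orientation, a minimal usage sketch of the reworked select_device (illustrative only, not part of this commit): it now accepts either a list of device ids or a comma-separated string, exports CUDA_VISIBLE_DEVICES via _format_device, and returns the master device.

# Illustrative sketch, not part of the diff; call before any CUDA initialization.
master_device = select_device([0, 1])  # sets CUDA_VISIBLE_DEVICES='0,1', returns torch.device('cuda:0')
master_device = select_device('0,1')   # equivalent string form
master_device = select_device('-1')    # CPU only, returns torch.device('cpu')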
@@ -148,7 +150,9 @@ def get_T_max(dataset_len: int, batch_size: int, max_epochs: int,
     return T_max


-def tokenize_function(example: Dict[str, str], tokenizer) -> Dict[str, Any]:
+def tokenize_function(example: Dict[str, str],
+                      tokenizer,
+                      max_length: Optional[int] = 2048) -> Dict[str, Any]:
     """Only applicable to baichuan and chatglm2. Other models need to be tested"""
     instruction = example['instruction']
     input_: str = example['input']
@@ -159,12 +163,12 @@ def tokenize_function(example: Dict[str, str], tokenizer) -> Dict[str, Any]:
     else:
         instruction = instruction + input_
     output = example['output']
-    src_text = PROMPT.format(instruction=instruction, add_special_tokens=False)
+    src_text = PROMPT.format(instruction=instruction)
     src_input_ids: List[int] = tokenizer(
         src_text, return_attention_mask=False,
         add_special_tokens=True)['input_ids']
-    # tokenizer.bos_token_id: Avoid `tgt_input_ids` being empty
-    tgt_input_ids = [tokenizer.bos_token_id]
+    #
+    tgt_input_ids = []
     if output is not None:
         tgt_input_ids += tokenizer(
             output, return_attention_mask=False,
@@ -175,10 +179,10 @@ def tokenize_function(example: Dict[str, str], tokenizer) -> Dict[str, Any]:
         labels = None
     input_ids = src_input_ids + tgt_input_ids
     #
-    if MAX_LENGTH is not None:
-        input_ids = input_ids[-MAX_LENGTH:]
+    if max_length is not None:
+        input_ids = input_ids[-max_length:]
         if labels is not None:
-            labels = labels[-MAX_LENGTH:]
+            labels = labels[-max_length:]
     #
     return {'input_ids': input_ids, 'labels': labels}
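To make the new max_length parameter concrete, here is a sketch of a single call (not part of this commit; it assumes a loaded tokenizer, and the masking of the prompt part with -100 is inferred from print_examples below rather than shown in this hunk):

# Illustrative sketch, not part of the diff.
example = {'instruction': 'Translate to English', 'input': '你好', 'output': 'Hello'}
d = tokenize_function(example, tokenizer, max_length=2048)
# d['input_ids']: prompt tokens (PROMPT.format(...)) followed by the target tokens
# d['labels']: presumably -100 over the prompt part, real token ids over the target
# both are truncated from the left to at most max_length tokens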
@@ -200,8 +204,10 @@ def stat_dataset(dataset: HFDataset) -> None:
 def print_examples(examples: Dict[str, Any], tokenizer) -> None:
     input_ids, labels = examples['input_ids'], examples['labels']
-    print(f'[INPUT_IDS] {tokenizer.decode(input_ids)}')
+    print(f'[INPUT_IDS] {input_ids}')
+    print(f'[INPUT] {tokenizer.decode(input_ids)}')
     print()
     print(f'[LABELS_IDS] {labels}')
     print(
         f'[LABELS] {tokenizer.decode([lb if lb != -100 else 0 for lb in labels])}'
     )
@@ -283,16 +289,25 @@ class MyMetric(Metric):
         }

     def merge(self, other: 'MyMetric') -> None:
-        """This script does not support ddp"""
+        """This script does not support ddp. TODO"""
         raise NotImplementedError


-def get_baichuan7B_model_tokenizer(model_dir: Optional[str] = None,
-                                   load_model: bool = True):
-    if model_dir is None:
-        model_id = 'baichuan-inc/baichuan-7B'
-        model_dir = get_model_dir(model_id, None)
-    #
+def _add_special_token(tokenizer):
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token_id = 2
+    if tokenizer.bos_token_id is None:
+        tokenizer.bos_token_id = 1
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = 0
+    logger.info(f'bos_token_id: {tokenizer.bos_token_id}, '
+                f'eos_token_id: {tokenizer.eos_token_id}, '
+                f'pad_token_id: {tokenizer.pad_token_id}')
+
+
+def get_baichuan7B_model_tokenizer(model_dir: str,
+                                   load_model: bool = True,
+                                   add_special_token: bool = True):
     sys.path.insert(0, model_dir)
     from configuration_baichuan import BaiChuanConfig
     from tokenization_baichuan import BaiChuanTokenizer
@@ -309,15 +324,14 @@ def get_baichuan7B_model_tokenizer(model_dir: Optional[str] = None,
             device_map='auto',
             torch_dtype=torch.float16)
     #
+    if add_special_token:
+        _add_special_token(tokenizer)
     return model, tokenizer


-def get_baichuan13B_model_tokenizer(model_dir: Optional[str] = None,
-                                    load_model: bool = True):
-    if model_dir is None:
-        model_id = 'baichuan-inc/Baichuan-13B-Base'
-        model_dir = get_model_dir(model_id, 'v1.0.1')
-    #
+def get_baichuan13B_model_tokenizer(model_dir: str,
+                                    load_model: bool = True,
+                                    add_special_token: bool = True):
     sys.path.insert(0, model_dir)
     from configuration_baichuan import BaichuanConfig
     from tokenization_baichuan import BaichuanTokenizer
@@ -334,15 +348,14 @@ def get_baichuan13B_model_tokenizer(model_dir: Optional[str] = None,
             device_map='auto',
             torch_dtype=torch.float16)
     #
+    if add_special_token:
+        _add_special_token(tokenizer)
     return model, tokenizer


-def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
-                                 load_model: bool = True):
-    if model_dir is None:
-        model_id = 'ZhipuAI/chatglm2-6b'
-        model_dir = snapshot_download(model_id, None)
-    #
+def get_chatglm2_model_tokenizer(model_dir: str,
+                                 load_model: bool = True,
+                                 add_special_token: bool = True):
     config = read_config(model_dir)
     config['model'] = ConfigDict({'type': 'chatglm2-6b'})
     tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir)
@@ -353,12 +366,16 @@ def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
             cfg_dict=config,
             device_map='auto',
             torch_dtype=torch.float16)
+    if add_special_token:
+        _add_special_token(tokenizer)
     return model, tokenizer
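For callers, the practical effect of dropping the Optional model_dir defaults is that the snapshot directory is resolved up front, e.g. (a sketch, not part of this commit; the model id and revision are taken from the defaults removed above):

# Illustrative sketch, not part of the diff.
model_dir = snapshot_download('ZhipuAI/chatglm2-6b', None)
model, tokenizer = get_chatglm2_model_tokenizer(model_dir, add_special_token=True)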
 def get_alpaca_en_zh_dataset(
         tokenize_function,
-        only_val: bool = False) -> Tuple[HFDataset, HFDataset]:
+        only_val: bool = False,
+        test_split_p: float = 0.01,
+        split_seed: int = 42) -> Tuple[HFDataset, HFDataset]:
     """
     split: Literal['train', 'validation', None]
     """
@@ -371,7 +388,7 @@ def get_alpaca_en_zh_dataset(
     dataset: HFDataset = concatenate_datasets([dataset_zh, dataset_en])
     #
     # dataset = dataset.select(range(1000))  # for debug
-    dataset = dataset.train_test_split(TEST_SPLIT_P, seed=SPLIT_SEED)
+    dataset = dataset.train_test_split(test_split_p, seed=split_seed)
     if only_val:
         dataset = dataset['test']
     if tokenize_function is not None:
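Finally, a sketch of how the new keyword arguments replace the former module-level constants (not part of this commit; binding tokenizer and max_length with functools.partial is an assumption about the calling code, not something shown here):

# Illustrative sketch, not part of the diff.
from functools import partial
tokenize_func = partial(tokenize_function, tokenizer=tokenizer, max_length=2048)
train_dataset, val_dataset = get_alpaca_en_zh_dataset(
    tokenize_func, test_split_p=0.01, split_seed=42)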