diff --git a/examples/pytorch/llm/_common.py b/examples/pytorch/llm/_common.py
index 79a958ec..161c99bf 100644
--- a/examples/pytorch/llm/_common.py
+++ b/examples/pytorch/llm/_common.py
@@ -51,25 +51,15 @@ from modelscope.utils.config import Config, ConfigDict
 from modelscope.utils.registry import default_group
 
 #
-TEST_SPLIT_P = 0.01
-SPLIT_SEED = 42
-MAX_LENGTH: Optional[int] = 2048
 COLOR, COLOR_S = '#FFE2D9', '#FF7043'
-PROMPT = """### 用户
-{instruction}
-### AI助手
-"""
+PROMPT = """Human: {instruction}
+AI: """
 
 logger = get_logger()
 
 #
-def get_model_dir(model_id: str, model_revision: Optional[str] = None) -> str:
-    model_dir = snapshot_download(model_id, model_revision)
-    return model_dir
-
-
 def _get_version(work_dir: str) -> int:
     if os.path.isdir(work_dir):
         fnames = os.listdir(work_dir)
@@ -96,28 +86,40 @@ def get_work_dir(work_dir: str) -> str:
     return work_dir
 
 
-def select_device(device_ids: List[int]) -> Device:
+def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
+    if isinstance(device, list):
+        device_ids = device
+        device_str = ','.join([str(d) for d in device])
+    else:
+        device_ids = [int(d) for d in device.split(',') if d != '-1']
+        device_str = device
+    device_str = device_str.replace(' ', '')
+    return device_ids, device_str
+
+
+def select_device(device: Union[List[int], str]) -> Device:
     """Call this function before cuda is initialized.
-    Return: master device
+    device: e.g. []: 'cpu', [0], [0, 1, 2]
+        e.g. '-1': 'cpu', '0', '0,1,2'
     """
     if torch.cuda.is_initialized():
         logger.warning('CUDA has been initialized! Device selection fails!')
         return torch.device('cuda:0')
     #
+    device_ids, device_str = _format_device(device)
+    #
+    os.environ['CUDA_VISIBLE_DEVICES'] = device_str
     log_s = 'Using device: '
-    if len(device_ids) == 0:  # cpu
-        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
-        device: str = 'cpu'
-        log_s += device
+    if len(device_ids) == 0:
+        master_device: str = 'cpu'
+        log_s += 'cpu'
     else:
-        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
-            [str(d) for d in device_ids])
         assert torch.cuda.is_available(
         ) and torch.cuda.device_count() >= len(device_ids)
-        log_s += f"cuda:{','.join([str(d) for d in device_ids])}"  # e.g. 'cuda:1,7,8'
-        device = 'cuda:0'
+        master_device = 'cuda:0'
+        log_s += f'cuda:{device_str}'
     logger.info(log_s)
-    return torch.device(device)
+    return torch.device(master_device)
 
 
 def seed_everything(seed: Optional[int] = None, gpu_dtm: bool = False) -> int:
@@ -148,7 +150,9 @@ def get_T_max(dataset_len: int, batch_size: int, max_epochs: int,
     return T_max
 
 
-def tokenize_function(example: Dict[str, str], tokenizer) -> Dict[str, Any]:
+def tokenize_function(example: Dict[str, str],
+                      tokenizer,
+                      max_length: Optional[int] = 2048) -> Dict[str, Any]:
     """Only applicable to baichuan and chatglm2.
     Other models need to be tested"""
     instruction = example['instruction']
     input_: str = example['input']
@@ -159,12 +163,12 @@ def tokenize_function(example: Dict[str, str], tokenizer) -> Dict[str, Any]:
     else:
         instruction = instruction + input_
     output = example['output']
-    src_text = PROMPT.format(instruction=instruction, add_special_tokens=False)
+    src_text = PROMPT.format(instruction=instruction)
     src_input_ids: List[int] = tokenizer(
         src_text, return_attention_mask=False,
         add_special_tokens=True)['input_ids']
-    # tokenizer.bos_token_id: Avoid `tgt_input_ids` being empty
-    tgt_input_ids = [tokenizer.bos_token_id]
+    #
+    tgt_input_ids = []
     if output is not None:
         tgt_input_ids += tokenizer(
             output, return_attention_mask=False,
@@ -175,10 +179,10 @@
         labels = None
     input_ids = src_input_ids + tgt_input_ids
     #
-    if MAX_LENGTH is not None:
-        input_ids = input_ids[-MAX_LENGTH:]
+    if max_length is not None:
+        input_ids = input_ids[-max_length:]
         if labels is not None:
-            labels = labels[-MAX_LENGTH:]
+            labels = labels[-max_length:]
     #
     return {'input_ids': input_ids, 'labels': labels}
@@ -200,8 +204,10 @@ def stat_dataset(dataset: HFDataset) -> None:
 
 def print_examples(examples: Dict[str, Any], tokenizer) -> None:
     input_ids, labels = examples['input_ids'], examples['labels']
-    print(f'[INPUT_IDS] {tokenizer.decode(input_ids)}')
+    print(f'[INPUT_IDS] {input_ids}')
+    print(f'[INPUT] {tokenizer.decode(input_ids)}')
     print()
+    print(f'[LABELS_IDS] {labels}')
     print(
         f'[LABLES] {tokenizer.decode([lb if lb != -100 else 0 for lb in labels])}'
     )
@@ -283,16 +289,25 @@ class MyMetric(Metric):
         }
 
     def merge(self, other: 'MyMetric') -> None:
-        """This script does not support ddp"""
+        """This script does not support ddp.
+        TODO"""
         raise NotImplementedError
 
 
-def get_baichuan7B_model_tokenizer(model_dir: Optional[str] = None,
-                                   load_model: bool = True):
-    if model_dir is None:
-        model_id = 'baichuan-inc/baichuan-7B'
-        model_dir = get_model_dir(model_id, None)
-    #
+def _add_special_token(tokenizer):
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token_id = 2
+    if tokenizer.bos_token_id is None:
+        tokenizer.bos_token_id = 1
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = 0
+    logger.info(f'bos_token_id: {tokenizer.bos_token_id}, '
+                f'eos_token_id: {tokenizer.eos_token_id}, '
+                f'pad_token_id: {tokenizer.pad_token_id}')
+
+
+def get_baichuan7B_model_tokenizer(model_dir: str,
+                                   load_model: bool = True,
+                                   add_special_token: bool = True):
     sys.path.insert(0, model_dir)
     from configuration_baichuan import BaiChuanConfig
     from tokenization_baichuan import BaiChuanTokenizer
@@ -309,15 +324,14 @@ def get_baichuan7B_model_tokenizer(model_dir: Optional[str] = None,
             device_map='auto',
             torch_dtype=torch.float16)
     #
+    if add_special_token:
+        _add_special_token(tokenizer)
     return model, tokenizer
 
 
-def get_baichuan13B_model_tokenizer(model_dir: Optional[str] = None,
-                                    load_model: bool = True):
-    if model_dir is None:
-        model_id = 'baichuan-inc/Baichuan-13B-Base'
-        model_dir = get_model_dir(model_id, 'v1.0.1')
-    #
+def get_baichuan13B_model_tokenizer(model_dir: str,
+                                    load_model: bool = True,
+                                    add_special_token: bool = True):
     sys.path.insert(0, model_dir)
     from configuration_baichuan import BaichuanConfig
     from tokenization_baichuan import BaichuanTokenizer
@@ -334,15 +348,14 @@ def get_baichuan13B_model_tokenizer(model_dir: Optional[str] = None,
             device_map='auto',
             torch_dtype=torch.float16)
     #
+    if add_special_token:
+        _add_special_token(tokenizer)
     return model, tokenizer
 
 
-def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
-                                 load_model: bool = True):
-    if model_dir is None:
-        model_id = 'ZhipuAI/chatglm2-6b'
-        model_dir = snapshot_download(model_id, None)
-    #
+def get_chatglm2_model_tokenizer(model_dir: str,
+                                 load_model: bool = True,
+                                 add_special_token: bool = True):
     config = read_config(model_dir)
     config['model'] = ConfigDict({'type': 'chatglm2-6b'})
     tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir)
@@ -353,12 +366,16 @@ def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
             cfg_dict=config,
             device_map='auto',
             torch_dtype=torch.float16)
+    if add_special_token:
+        _add_special_token(tokenizer)
     return model, tokenizer
 
 
 def get_alpaca_en_zh_dataset(
         tokenize_function,
-        only_val: bool = False) -> Tuple[HFDataset, HFDataset]:
+        only_val: bool = False,
+        test_split_p: float = 0.01,
+        split_seed: int = 42) -> Tuple[HFDataset, HFDataset]:
     """
     split: Literal['train', 'validation', None]
     """
@@ -371,7 +388,7 @@ def get_alpaca_en_zh_dataset(
     dataset: HFDataset = concatenate_datasets([dataset_zh, dataset_en])
     #
    # dataset = dataset.select(range(1000))  # for debug
-    dataset = dataset.train_test_split(TEST_SPLIT_P, seed=SPLIT_SEED)
+    dataset = dataset.train_test_split(test_split_p, seed=split_seed)
     if only_val:
         dataset = dataset['test']
     if tokenize_function is not None:
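Review note on the label masking above: `tokenize_function` now masks every prompt position with -100, so `CrossEntropyLoss(ignore_index=-100)` scores only the response tokens plus the trailing eos (the old bos-token seeding of `tgt_input_ids` is dropped, leaving the target empty at test time). A minimal sketch of the resulting layout, with made-up token ids standing in for a real tokenizer:

    # Hypothetical ids; a real tokenizer produces these from the PROMPT text.
    src_input_ids = [101, 2023, 3449]   # "Human: {instruction}\nAI: "
    tgt_input_ids = [7592, 2088, 2]     # response tokens + eos_token_id
    input_ids = src_input_ids + tgt_input_ids
    labels = [-100] * len(src_input_ids) + tgt_input_ids
    assert len(input_ids) == len(labels)
    # The -100 positions contribute nothing to the loss; only the
    # response (including eos) is supervised.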
TODO""" raise NotImplementedError -def get_baichuan7B_model_tokenizer(model_dir: Optional[str] = None, - load_model: bool = True): - if model_dir is None: - model_id = 'baichuan-inc/baichuan-7B' - model_dir = get_model_dir(model_id, None) - # +def _add_special_token(tokenizer): + if tokenizer.eos_token_id is None: + tokenizer.eos_token_id = 2 + if tokenizer.bos_token_id is None: + tokenizer.bos_token_id = 1 + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = 0 + logger.info(f'bos_token_id: {tokenizer.bos_token_id}, ' + f'eos_token_id: {tokenizer.eos_token_id}, ' + f'pad_token_id: {tokenizer.pad_token_id}') + + +def get_baichuan7B_model_tokenizer(model_dir: str, + load_model: bool = True, + add_special_token: bool = True): sys.path.insert(0, model_dir) from configuration_baichuan import BaiChuanConfig from tokenization_baichuan import BaiChuanTokenizer @@ -309,15 +324,14 @@ def get_baichuan7B_model_tokenizer(model_dir: Optional[str] = None, device_map='auto', torch_dtype=torch.float16) # + if add_special_token: + _add_special_token(tokenizer) return model, tokenizer -def get_baichuan13B_model_tokenizer(model_dir: Optional[str] = None, - load_model: bool = True): - if model_dir is None: - model_id = 'baichuan-inc/Baichuan-13B-Base' - model_dir = get_model_dir(model_id, 'v1.0.1') - # +def get_baichuan13B_model_tokenizer(model_dir: str, + load_model: bool = True, + add_special_token: bool = True): sys.path.insert(0, model_dir) from configuration_baichuan import BaichuanConfig from tokenization_baichuan import BaichuanTokenizer @@ -334,15 +348,14 @@ def get_baichuan13B_model_tokenizer(model_dir: Optional[str] = None, device_map='auto', torch_dtype=torch.float16) # + if add_special_token: + _add_special_token(tokenizer) return model, tokenizer -def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None, - load_model: bool = True): - if model_dir is None: - model_id = 'ZhipuAI/chatglm2-6b' - model_dir = snapshot_download(model_id, None) - # +def get_chatglm2_model_tokenizer(model_dir: str, + load_model: bool = True, + add_special_token: bool = True): config = read_config(model_dir) config['model'] = ConfigDict({'type': 'chatglm2-6b'}) tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir) @@ -353,12 +366,16 @@ def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None, cfg_dict=config, device_map='auto', torch_dtype=torch.float16) + if add_special_token: + _add_special_token(tokenizer) return model, tokenizer def get_alpaca_en_zh_dataset( tokenize_function, - only_val: bool = False) -> Tuple[HFDataset, HFDataset]: + only_val: bool = False, + test_split_p: float = 0.01, + split_seed: int = 42) -> Tuple[HFDataset, HFDataset]: """ split: Literal['train', 'validation', None] """ @@ -371,7 +388,7 @@ def get_alpaca_en_zh_dataset( dataset: HFDataset = concatenate_datasets([dataset_zh, dataset_en]) # # dataset = dataset.select(range(1000)) # for debug - dataset = dataset.train_test_split(TEST_SPLIT_P, seed=SPLIT_SEED) + dataset = dataset.train_test_split(test_split_p, seed=split_seed) if only_val: dataset = dataset['test'] if tokenize_function is not None: diff --git a/examples/pytorch/llm/baichuan_infer.py b/examples/pytorch/llm/baichuan_infer.py index f9a49c09..6e027347 100644 --- a/examples/pytorch/llm/baichuan_infer.py +++ b/examples/pytorch/llm/baichuan_infer.py @@ -3,24 +3,22 @@ from _common import * from transformers import TextStreamer device_ids = [0, 1] -logger.info(device_ids) select_device(device_ids) +# Note: You need to set the value of `CKPT_FPATH` +CKPT_FAPTH = 
diff --git a/examples/pytorch/llm/baichuan_sft.py b/examples/pytorch/llm/baichuan_sft.py
index 18f71d22..4addc8b5 100644
--- a/examples/pytorch/llm/baichuan_sft.py
+++ b/examples/pytorch/llm/baichuan_sft.py
@@ -3,35 +3,27 @@
 pip install modelscope
 pip install numpy pandas matplotlib scikit-learn
 pip install transformers datasets
-pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
-pip install tqdm
-pip install tensorboard
-pip install torchmetrics
-pip install sentencepiece
-pip install accelerate
+conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
+pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer accelerate
 pip install numpy -U  # Resolve torchmetrics dependencies and update numpy
 """
 from _common import *
 
-device_ids = [0, 1, 2, 3]
-logger.info(device_ids)
+device_ids = [0, 1]
 select_device(device_ids)
 seed_everything(42)
 
 # ### Loading Model and Tokenizer
 BAICHUAN_TYPE = '13B'  # Literal['7B', '13B']
 WORK_DIR = f'runs/baichuan_{BAICHUAN_TYPE}'
-LORA_TARGET_MODULES = ['W_pack']
 #
 if BAICHUAN_TYPE == '7B':
-    model_id = 'baichuan-inc/baichuan-7B'
-    model_dir = get_model_dir(model_id, None)
+    model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')
     model, tokenizer = get_baichuan7B_model_tokenizer(model_dir)
 else:
-    model_id = 'baichuan-inc/Baichuan-13B-Base'
-    model_dir = get_model_dir(model_id, 'v1.0.1')
+    model_dir = snapshot_download('baichuan-inc/Baichuan-13B-Base', 'v1.0.2')
     model, tokenizer = get_baichuan13B_model_tokenizer(model_dir)
 #
 GRADIENT_CHECKPOINTING = True
@@ -46,14 +38,9 @@ if GRADIENT_CHECKPOINTING:
         model)
     model.gradient_checkpointing_enable()
     model.enable_input_require_grads()
-if tokenizer.pad_token_id is None:
-    tokenizer.pad_token_id = tokenizer.eos_token_id
-#
-logger.info(
-    f'bos_token_id: {tokenizer.bos_token_id}, eos_token_id: {tokenizer.eos_token_id}, '
-    f'pad_token_id: {tokenizer.pad_token_id}')
 
 # ### Preparing lora
+LORA_TARGET_MODULES = ['W_pack']
 LORA_RANK = 8
 LORA_ALPHA = 32
 LORA_DROPOUT_P = 0.1
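A note on the `enable_input_require_grads()` call kept alongside `gradient_checkpointing_enable()`: with LoRA the embedding weights are frozen, so the first checkpointed block would otherwise receive an input with `requires_grad=False` and recomputation would yield no gradients for the adapter parameters. The transformers helper registers a forward hook on the input embeddings; roughly equivalent to the sketch below (an assumption about the internals, not code from this patch):

    # Rough sketch of what model.enable_input_require_grads() does.
    def make_inputs_require_grad(module, input, output):
        output.requires_grad_(True)

    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)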
diff --git a/examples/pytorch/llm/chatglm2_infer.py b/examples/pytorch/llm/chatglm2_infer.py
index 741f9b18..c1a544cb 100644
--- a/examples/pytorch/llm/chatglm2_infer.py
+++ b/examples/pytorch/llm/chatglm2_infer.py
@@ -3,22 +3,17 @@ from _common import *
 from transformers import TextStreamer
 
 device_ids = [0, 1]
-logger.info(device_ids)
 select_device(device_ids)
+# Note: You need to set the value of `CKPT_FPATH`
+CKPT_FPATH = '/path/to/your/iter_xxx.pth'
 
 # ### Loading Model and Tokenizer
-# Note: You need to set the value of `CKPT_FPATH`
-CKPT_FAPTH = '/path/to/your/xxx.pth'
-LORA_TARGET_MODULES = ['query_key_value']
-
-model, tokenizer = get_chatglm2_model_tokenizer()
-if tokenizer.eos_token_id is None:
-    tokenizer.eos_token_id = tokenizer.pad_token_id
-if tokenizer.bos_token_id is None:
-    tokenizer.bos_token_id = 1
+model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')
+model, tokenizer = get_chatglm2_model_tokenizer(model_dir)
 model.bfloat16()  # Consistent with training
 
 # ### Preparing lora
+LORA_TARGET_MODULES = ['query_key_value']
 LORA_RANK = 8
 LORA_ALPHA = 32
 LORA_DROPOUT_P = 0  # Arbitrary value
@@ -36,7 +31,8 @@ _, test_dataset = get_alpaca_en_zh_dataset(None, True)
 
 # ### Inference
 streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-for d in test_dataset[:5]:
+mini_test_dataset = test_dataset.select(range(5))
+for d in mini_test_dataset:
     output = d['output']
     d['output'] = None
     input_ids = tokenize_function(d, tokenizer)['input_ids']
@@ -48,9 +44,10 @@
         max_new_tokens=512,
         attention_mask=attention_mask,
         streamer=streamer,
-        pad_token_id=tokenizer.pad_token_id,
+        pad_token_id=tokenizer.eos_token_id,
         temperature=0.7,
         top_k=50,
+        top_p=0.7,
         do_sample=True)
     print()
     print(f'[LABELS]{output}')
diff --git a/examples/pytorch/llm/chatglm2_sft.py b/examples/pytorch/llm/chatglm2_sft.py
index ecd497a2..4876025b 100644
--- a/examples/pytorch/llm/chatglm2_sft.py
+++ b/examples/pytorch/llm/chatglm2_sft.py
@@ -3,46 +3,31 @@
 pip install modelscope
 pip install numpy pandas matplotlib scikit-learn
 pip install transformers datasets
-pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
-pip install tqdm
-pip install tensorboard
-pip install torchmetrics
-pip install sentencepiece
-pip install accelerate
+conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
+pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer accelerate
 pip install numpy -U  # Resolve torchmetrics dependencies and update numpy
 """
 from _common import *
 
-device_ids = [0, 1, 2, 3]
-logger.info(device_ids)
+device_ids = [0, 1]
 select_device(device_ids)
 seed_everything(42)
 
 # ### Loading Model and Tokenizer
-model_id = 'ZhipuAI/chatglm2-6b'
 WORK_DIR = 'runs/chatglm2'
-LORA_TARGET_MODULES = ['query_key_value']
 #
-model_dir = get_model_dir(model_id, None)
+model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')
 model, tokenizer = get_chatglm2_model_tokenizer(model_dir)
-# chatglm2 does not support gradient_checkpointing
-GRADIENT_CHECKPOINTING = False
+#
+GRADIENT_CHECKPOINTING = True
 if GRADIENT_CHECKPOINTING:
     model.gradient_checkpointing_enable()
     model.enable_input_require_grads()
-logger.info(tokenizer.special_tokens)
-if tokenizer.eos_token_id is None:
-    tokenizer.eos_token_id = tokenizer.pad_token_id
-if tokenizer.bos_token_id is None:
-    tokenizer.bos_token_id = 1
-#
-logger.info(
-    f'bos_token_id: {tokenizer.bos_token_id}, eos_token_id: {tokenizer.eos_token_id}, '
-    f'pad_token_id: {tokenizer.pad_token_id}')
 
 # ### Preparing lora
+LORA_TARGET_MODULES = ['query_key_value']
 LORA_RANK = 8
 LORA_ALPHA = 32
 LORA_DROPOUT_P = 0.1
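`select_device` now takes either a list of ids or a `CUDA_VISIBLE_DEVICES`-style string, and in both cases it sets the environment variable before CUDA initializes — which is why every script calls it before touching the model. Equivalent calls under the new signature:

    select_device([0, 1])   # sets CUDA_VISIBLE_DEVICES='0,1', returns torch.device('cuda:0')
    select_device('0, 1')   # same; spaces are stripped from the string form
    select_device([])       # CPU
    select_device('-1')     # CPU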
diff --git a/examples/pytorch/llm_agent/_common.py b/examples/pytorch/llm_agent/_common.py
index 04097b50..dd07ef31 100644
--- a/examples/pytorch/llm_agent/_common.py
+++ b/examples/pytorch/llm_agent/_common.py
@@ -49,11 +49,9 @@ from modelscope.utils.config import Config, ConfigDict
 from modelscope.utils.registry import default_group
 
 #
-SYSTEM_TEXT = """{system}"""
-USER_TEXT = """\n\n### 用户
-{user}"""
-ASSISTANT_PROMPT = """\n\n### 助手
-"""
+PROMPT = """System: {system}
+Human: {user}
+AI: """
 MAX_LENGTH = 2048
 TEST_MAX_LENGTH = MAX_LENGTH
 
@@ -62,11 +60,6 @@
 logger = get_logger()
 
 #
-def get_model_dir(model_id: str, model_revision: Optional[str] = None) -> str:
-    model_dir = snapshot_download(model_id, model_revision)
-    return model_dir
-
-
 def _get_version(work_dir: str) -> int:
     if os.path.isdir(work_dir):
         fnames = os.listdir(work_dir)
@@ -93,28 +86,40 @@ def get_work_dir(work_dir: str) -> str:
     return work_dir
 
 
-def select_device(device_ids: List[int]) -> Device:
+def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
+    if isinstance(device, list):
+        device_ids = device
+        device_str = ','.join([str(d) for d in device])
+    else:
+        device_ids = [int(d) for d in device.split(',') if d != '-1']
+        device_str = device
+    device_str = device_str.replace(' ', '')
+    return device_ids, device_str
+
+
+def select_device(device: Union[List[int], str]) -> Device:
     """Call this function before cuda is initialized.
-    Return: master device
+    device: e.g. []: 'cpu', [0], [0, 1, 2]
+        e.g. '-1': 'cpu', '0', '0,1,2'
     """
     if torch.cuda.is_initialized():
         logger.warning('CUDA has been initialized! Device selection fails!')
         return torch.device('cuda:0')
     #
+    device_ids, device_str = _format_device(device)
+    #
+    os.environ['CUDA_VISIBLE_DEVICES'] = device_str
     log_s = 'Using device: '
-    if len(device_ids) == 0:  # cpu
-        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
-        device: str = 'cpu'
-        log_s += device
+    if len(device_ids) == 0:
+        master_device: str = 'cpu'
+        log_s += 'cpu'
 else:
-        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
-            [str(d) for d in device_ids])
         assert torch.cuda.is_available(
         ) and torch.cuda.device_count() >= len(device_ids)
-        log_s += f"cuda:{','.join([str(d) for d in device_ids])}"  # e.g. 'cuda:1,7,8'
-        device = 'cuda:0'
+        master_device = 'cuda:0'
+        log_s += f'cuda:{device_str}'
     logger.info(log_s)
-    return torch.device(device)
+    return torch.device(master_device)
 
 
 def seed_everything(seed: Optional[int] = None, gpu_dtm: bool = False) -> int:
@@ -148,37 +153,27 @@ def get_T_max(dataset_len: int, batch_size: int, max_epochs: int,
 def tokenize_function(system: str, user: str, assistant: Optional[str],
                       tokenizer) -> Dict[str, Any]:
     """Only applicable to baichuan and chatglm2.
     Other models need to be tested"""
-    system_text = SYSTEM_TEXT.format(system=system)
-    user_text = USER_TEXT.format(user=user)
-    system_text_ids: List[int] = tokenizer(
-        system_text, return_attention_mask=False,
+    src_text = PROMPT.format(system=system, user=user)
+    src_input_ids: List[int] = tokenizer(
+        src_text, return_attention_mask=False,
         add_special_tokens=True)['input_ids']
-    user_text_ids: List[int] = tokenizer(
-        user_text, return_attention_mask=False,
-        add_special_tokens=False)['input_ids']
-    assistant_p_input_ids: List[int] = tokenizer(
-        ASSISTANT_PROMPT,
-        return_attention_mask=False,
-        add_special_tokens=False)['input_ids']
-
-    # tokenizer.bos_token_id: Avoid `assistant` being empty
-    assistant_input_ids: List[int] = [tokenizer.bos_token_id]
+    #
+    tgt_input_ids: List[int] = []
     if assistant is not None:
-        assistant_input_ids += tokenizer(
+        tgt_input_ids += tokenizer(
             assistant, return_attention_mask=False,
             add_special_tokens=False)['input_ids']
-        assistant_input_ids += [tokenizer.eos_token_id]
+        tgt_input_ids += [tokenizer.eos_token_id]
+        labels = [-100] * len(src_input_ids) + tgt_input_ids
+    else:
+        labels = None
+    input_ids = src_input_ids + tgt_input_ids
     #
-    input_ids = system_text_ids + user_text_ids + assistant_p_input_ids + assistant_input_ids
-    if assistant is not None:  # train, val
+    if assistant is not None:
         if len(input_ids) > MAX_LENGTH:
             return {}
-        len_mask = len(input_ids) - len(assistant_input_ids)
-        labels = [-100] * len_mask + assistant_input_ids
-    else:  # test
+    else:
         input_ids = input_ids[-TEST_MAX_LENGTH:]
-        labels = None
-    #
     return {'input_ids': input_ids, 'labels': labels}
@@ -305,12 +300,21 @@
         raise NotImplementedError
 
 
-def get_baichuan_model_tokenizer(model_dir: Optional[str] = None,
-                                 load_model: bool = True):
-    if model_dir is None:
-        model_id = 'baichuan-inc/baichuan-7B'
-        model_dir = get_model_dir(model_id, None)
-    #
+def _add_special_token(tokenizer):
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token_id = 2
+    if tokenizer.bos_token_id is None:
+        tokenizer.bos_token_id = 1
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = 0
+    logger.info(f'bos_token_id: {tokenizer.bos_token_id}, '
+                f'eos_token_id: {tokenizer.eos_token_id}, '
+                f'pad_token_id: {tokenizer.pad_token_id}')
+
+
+def get_baichuan7B_model_tokenizer(model_dir: str,
+                                   load_model: bool = True,
+                                   add_special_token: bool = True):
     sys.path.insert(0, model_dir)
     from configuration_baichuan import BaiChuanConfig
     from tokenization_baichuan import BaiChuanTokenizer
@@ -327,15 +331,14 @@ def get_baichuan_model_tokenizer(model_dir: Optional[str] = None,
             device_map='auto',
             torch_dtype=torch.float16)
     #
+    if add_special_token:
+        _add_special_token(tokenizer)
     return model, tokenizer
 
 
-def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
-                                 load_model: bool = True):
-    if model_dir is None:
-        model_id = 'ZhipuAI/chatglm2-6b'
-        model_dir = snapshot_download(model_id, None)
-    #
+def get_chatglm2_model_tokenizer(model_dir: str,
+                                 load_model: bool = True,
+                                 add_special_token: bool = True):
     config = read_config(model_dir)
     config['model'] = ConfigDict({'type': 'chatglm2-6b'})
     tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir)
@@ -346,6 +349,8 @@ def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
             cfg_dict=config,
             device_map='auto',
             torch_dtype=torch.float16)
+    if add_special_token:
+        _add_special_token(tokenizer)
     return model, tokenizer
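The hard-coded fallbacks in `_add_special_token` (eos=2, bos=1, pad=0) match the LLaMA-style SentencePiece vocabularies these two model families use; the function is a no-op for tokenizers that already define the ids, and it logs the final values either way. If the scripts are adapted to another model, the ids should be verified against that tokenizer rather than trusted, for example:

    # Hypothetical sanity check when porting to a new tokenizer.
    print(tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id)
    assert tokenizer.eos_token_id is not None, 'set eos_token_id explicitly'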
diff --git a/examples/pytorch/llm_agent/baichuan_infer.ipynb b/examples/pytorch/llm_agent/baichuan_infer.ipynb
index 03f8f46b..7ef29951 100644
--- a/examples/pytorch/llm_agent/baichuan_infer.ipynb
+++ b/examples/pytorch/llm_agent/baichuan_infer.ipynb
@@ -54,7 +54,6 @@
    "from _common import *\n",
    "from transformers import TextStreamer\n",
    "device_ids = [0, 1]\n",
-   "logger.info(device_ids)\n",
    "select_device(device_ids)"
   ]
  },
@@ -146,9 +145,8 @@
    "CKPT_FAPTH = '/home/hackathon/my_git/agent/runs/baichuan/v10-20230702-172449/output_best/pytorch_model.bin'\n",
    "LORA_TARGET_MODULES = ['W_pack']\n",
    "\n",
-   "model, tokenizer = get_baichuan_model_tokenizer()\n",
-   "if tokenizer.pad_token_id is None:\n",
-   "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
+   "model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')\n",
+   "model, tokenizer = get_baichuan7B_model_tokenizer(model_dir)\n",
    "model.bfloat16()  # Consistent with training"
   ]
  },
@@ -451,8 +449,8 @@
    "    attention_mask = torch.ones_like(input_ids)\n",
    "    generate_ids = model.generate(input_ids=input_ids, max_new_tokens=512,\n",
    "                                  attention_mask=attention_mask,\n",
-   "                                  streamer=streamer, pad_token_id=tokenizer.pad_token_id, \n",
-   "                                  temperature=0.7, top_k=50, do_sample=True)\n",
+   "                                  streamer=streamer, pad_token_id=tokenizer.eos_token_id, \n",
+   "                                  temperature=0.7, top_k=50, top_p=0.7, do_sample=True)\n",
    "    print()\n",
    "    print(f'[LABELS]{assistant}')\n",
    "    print('-----------------------------------------------------------------------------------')\n",
diff --git a/examples/pytorch/llm_agent/baichuan_sft.ipynb b/examples/pytorch/llm_agent/baichuan_sft.ipynb
index cb732612..75a9240e 100644
--- a/examples/pytorch/llm_agent/baichuan_sft.ipynb
+++ b/examples/pytorch/llm_agent/baichuan_sft.ipynb
@@ -33,16 +33,12 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "# !pip install modelscope -U\n",
+   "# !pip install modelscope\n",
    "# !pip install numpy pandas matplotlib scikit-learn\n",
    "# !pip install transformers datasets\n",
-   "# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
-   "# !pip install tqdm\n",
-   "# !pip install tensorboard\n",
-   "# !pip install torchmetrics\n",
-   "# !pip install sentencepiece\n",
-   "# !pip install accelerate\n",
-   "#\n",
+   "# !conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia\n",
+   "# !pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer accelerate\n",
+   "\n",
    "# !pip install numpy -U  # Resolve torchmetrics dependencies and update numpy"
  ]
 },
@@ -75,8 +71,7 @@
   ],
  "source": [
    "from _common import *\n",
-   "device_ids = [0, 1, 2, 3]\n",
-   "logger.info(device_ids)\n",
+   "device_ids = [0, 1]\n",
    "select_device(device_ids)\n",
    "_ = seed_everything(42)"
  ]
 },
@@ -132,22 +127,16 @@
   }
  ],
 "source": [
-   "model_id = 'baichuan-inc/baichuan-7B'\n",
    "WORK_DIR = 'runs/baichuan'\n",
    "LORA_TARGET_MODULES = ['W_pack']\n",
    "#\n",
-   "model_dir = get_model_dir(model_id, None)\n",
-   "model, tokenizer = get_baichuan_model_tokenizer(model_dir)\n",
+   "model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')\n",
+   "model, tokenizer = get_baichuan7B_model_tokenizer(model_dir)\n",
    "#\n",
    "GRADIENT_CHECKPOINTING = True\n",
    "if GRADIENT_CHECKPOINTING:\n",
    "    model.gradient_checkpointing_enable()\n",
-   "    model.enable_input_require_grads()\n",
-   "if tokenizer.pad_token_id is None:\n",
-   "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
-   "#\n",
-   "logger.info(f'bos_token_id: {tokenizer.bos_token_id}, eos_token_id: {tokenizer.eos_token_id}, '\n",
-   "            f'pad_token_id: {tokenizer.pad_token_id}')"
+   "    model.enable_input_require_grads()"
  ]
 },
 {
diff --git a/examples/pytorch/llm_agent/chatglm2_infer.ipynb b/examples/pytorch/llm_agent/chatglm2_infer.ipynb
index 237d27c8..821da5e6 100644
--- a/examples/pytorch/llm_agent/chatglm2_infer.ipynb
+++ b/examples/pytorch/llm_agent/chatglm2_infer.ipynb
@@ -55,7 +55,6 @@
    "from _common import *\n",
    "from transformers import TextStreamer\n",
    "device_ids = [0, 1]\n",
-   "logger.info(device_ids)\n",
    "select_device(device_ids)"
   ]
  },
@@ -143,11 +142,8 @@
    "CKPT_FAPTH = '/home/hackathon/my_git/agent/runs/chatglm2/v1-20230702-203505/output_best/pytorch_model.bin'\n",
    "LORA_TARGET_MODULES = ['query_key_value']\n",
    "\n",
-   "model, tokenizer = get_chatglm2_model_tokenizer()\n",
-   "if tokenizer.eos_token_id is None:\n",
-   "    tokenizer.eos_token_id = tokenizer.pad_token_id\n",
-   "if tokenizer.bos_token_id is None:\n",
-   "    tokenizer.bos_token_id = 1\n",
+   "model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')\n",
+   "model, tokenizer = get_chatglm2_model_tokenizer(model_dir)\n",
    "model.bfloat16()  # Consistent with training"
  ]
 },
@@ -484,8 +480,8 @@
    "    attention_mask = torch.ones_like(input_ids)\n",
    "    generate_ids = model.generate(input_ids=input_ids, max_new_tokens=512,\n",
    "                                  attention_mask=attention_mask,\n",
-   "                                  streamer=streamer, pad_token_id=tokenizer.pad_token_id, \n",
-   "                                  temperature=0.7, top_k=50, do_sample=True)\n",
+   "                                  streamer=streamer, pad_token_id=tokenizer.eos_token_id, \n",
+   "                                  temperature=0.7, top_k=50, top_p=0.7, do_sample=True)\n",
    "    print()\n",
    "    print(f'[LABELS]{assistant}')\n",
    "    print('-----------------------------------------------------------------------------------')\n",
diff --git a/examples/pytorch/llm_agent/chatglm2_sft.ipynb b/examples/pytorch/llm_agent/chatglm2_sft.ipynb
index 70d9b8a1..4810e4b9 100644
--- a/examples/pytorch/llm_agent/chatglm2_sft.ipynb
+++ b/examples/pytorch/llm_agent/chatglm2_sft.ipynb
@@ -40,22 +40,18 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "# !pip install modelscope -U\n",
+   "# !pip install modelscope\n",
    "# !pip install numpy pandas matplotlib scikit-learn\n",
    "# !pip install transformers datasets\n",
-   "# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
-   "# !pip install tqdm\n",
-   "# !pip install tensorboard\n",
-   "# !pip install torchmetrics\n",
-   "# !pip install sentencepiece\n",
-   "# !pip install accelerate\n",
-   "#\n",
+   "# !conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia\n",
+   "# !pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer accelerate\n",
+   "\n",
    "# !pip install numpy -U  # Resolve torchmetrics dependencies and update numpy"
  ]
 },
 {
  "cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
  "metadata": {},
  "outputs": [
  {
@@ -80,8 +76,7 @@
   ],
  "source": [
    "from _common import *\n",
-   "device_ids = [0, 1, 2, 3]\n",
-   "logger.info(device_ids)\n",
+   "device_ids = [0, 1]\n",
    "select_device(device_ids)\n",
    "_ = seed_everything(42)"
  ]
 },
@@ -136,25 +131,16 @@
   }
  ],
 "source": [
-   "model_id = 'ZhipuAI/chatglm2-6b'\n",
    "WORK_DIR = 'runs/chatglm2'\n",
    "LORA_TARGET_MODULES = ['query_key_value']\n",
    "#\n",
-   "model_dir = get_model_dir(model_id, None)\n",
+   "model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')\n",
    "model, tokenizer = get_chatglm2_model_tokenizer(model_dir)\n",
-   "# chatglm2 does not support gradient_checkpointing\n",
-   "GRADIENT_CHECKPOINTING = False\n",
+   "#\n",
+   "GRADIENT_CHECKPOINTING = True\n",
    "if GRADIENT_CHECKPOINTING:\n",
    "    model.gradient_checkpointing_enable()\n",
-   "    model.enable_input_require_grads()\n",
-   "logger.info(tokenizer.special_tokens)\n",
-   "if tokenizer.eos_token_id is None:\n",
-   "    tokenizer.eos_token_id = tokenizer.pad_token_id\n",
-   "if tokenizer.bos_token_id is None:\n",
-   "    tokenizer.bos_token_id = 1\n",
-   "#\n",
-   "logger.info(f'bos_token_id: {tokenizer.bos_token_id}, eos_token_id: {tokenizer.eos_token_id}, '\n",
-   "            f'pad_token_id: {tokenizer.pad_token_id}')"
+   "    model.enable_input_require_grads()"
  ]
 },
 {
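On the shared generation settings in the infer scripts and notebooks above: `pad_token_id=tokenizer.eos_token_id` is the standard way to silence the "Setting `pad_token_id` to `eos_token_id`" warning for decoder-only models without a pad token (with batch size 1 nothing is actually padded), and the new `top_p=0.7` stacks nucleus sampling on top of the existing `top_k=50` filter: candidates are first cut to the 50 most likely tokens, then to the smallest set whose cumulative probability reaches 0.7. The same configuration works on any HF causal LM:

    # Sketch; `model`, `tokenizer`, `input_ids` as in the infer scripts above.
    generate_ids = model.generate(
        input_ids=input_ids,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_k=50,    # keep the 50 most likely tokens...
        top_p=0.7,   # ...then the smallest nucleus reaching 70% probability
        pad_token_id=tokenizer.eos_token_id)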
diff --git a/modelscope/models/nlp/chatglm2/text_generation.py b/modelscope/models/nlp/chatglm2/text_generation.py
index aed855cb..082e16e7 100644
--- a/modelscope/models/nlp/chatglm2/text_generation.py
+++ b/modelscope/models/nlp/chatglm2/text_generation.py
@@ -1095,6 +1095,7 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-100)
+            shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(
                 shift_logits.view(-1, shift_logits.size(-1)),
                 shift_labels.view(-1))
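The one-line change to `text_generation.py` guards against a device mismatch: under `device_map='auto'` the lm_head (and therefore `shift_logits`) may sit on a different GPU than the incoming `labels`, and `CrossEntropyLoss` requires both tensors on the same device. A standalone reproduction of the pattern:

    import torch
    from torch.nn import CrossEntropyLoss

    logits = torch.randn(4, 10)          # may land on e.g. 'cuda:1' under device_map='auto'
    labels = torch.randint(0, 10, (4,))  # e.g. still on 'cuda:0' or cpu
    labels = labels.to(logits.device)    # the added line: align devices first
    loss = CrossEntropyLoss(ignore_index=-100)(logits, labels)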