Fix/chatglm2 (#384)

This commit is contained in:
Jintao
2023-07-15 09:59:53 +08:00
committed by GitHub
parent 442bdc74a4
commit c6189d68a0
11 changed files with 190 additions and 229 deletions

View File

@@ -3,22 +3,17 @@ from _common import *
from transformers import TextStreamer
device_ids = [0, 1]
logger.info(device_ids)
select_device(device_ids)
# Note: You need to set the value of `CKPT_FPATH`
CKPT_FAPTH = '/path/to/your/iter_xxx.pth'
# ### Loading Model and Tokenizer
# Note: You need to set the value of `CKPT_FPATH`
CKPT_FAPTH = '/path/to/your/xxx.pth'
LORA_TARGET_MODULES = ['query_key_value']
model, tokenizer = get_chatglm2_model_tokenizer()
if tokenizer.eos_token_id is None:
tokenizer.eos_token_id = tokenizer.pad_token_id
if tokenizer.bos_token_id is None:
tokenizer.bos_token_id = 1
model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')
model, tokenizer = get_chatglm2_model_tokenizer(model_dir)
model.bfloat16() # Consistent with training
# ### Preparing lora
LORA_TARGET_MODULES = ['query_key_value']
LORA_RANK = 8
LORA_ALPHA = 32
LORA_DROPOUT_P = 0 # Arbitrary value
@@ -36,7 +31,8 @@ _, test_dataset = get_alpaca_en_zh_dataset(None, True)
# ### Inference
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
for d in test_dataset[:5]:
mini_test_dataset = test_dataset.select(range(5))
for d in mini_test_dataset:
output = d['output']
d['output'] = None
input_ids = tokenize_function(d, tokenizer)['input_ids']
@@ -48,9 +44,10 @@ for d in test_dataset[:5]:
max_new_tokens=512,
attention_mask=attention_mask,
streamer=streamer,
pad_token_id=tokenizer.pad_token_id,
pad_token_id=tokenizer.eos_token_id,
temperature=0.7,
top_k=50,
top_p=0.7,
do_sample=True)
print()
print(f'[LABELS]{output}')