mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-22 19:19:21 +01:00
Fix/chatglm2 (#384)
This commit is contained in:
@@ -3,22 +3,17 @@ from _common import *
|
||||
from transformers import TextStreamer
|
||||
|
||||
device_ids = [0, 1]
|
||||
logger.info(device_ids)
|
||||
select_device(device_ids)
|
||||
# Note: You need to set the value of `CKPT_FAPTH` (the checkpoint file path)
|
||||
CKPT_FAPTH = '/path/to/your/iter_xxx.pth'
|
||||
|
||||
# ### Loading Model and Tokenizer
|
||||
# Note: You need to set the value of `CKPT_FAPTH` (the checkpoint file path)
|
||||
CKPT_FAPTH = '/path/to/your/xxx.pth'
|
||||
LORA_TARGET_MODULES = ['query_key_value']
|
||||
|
||||
model, tokenizer = get_chatglm2_model_tokenizer()
|
||||
if tokenizer.eos_token_id is None:
|
||||
tokenizer.eos_token_id = tokenizer.pad_token_id
|
||||
if tokenizer.bos_token_id is None:
|
||||
tokenizer.bos_token_id = 1
|
||||
model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')
|
||||
model, tokenizer = get_chatglm2_model_tokenizer(model_dir)
|
||||
model.bfloat16() # Consistent with training
|
||||
|
||||
# ### Preparing lora
|
||||
LORA_TARGET_MODULES = ['query_key_value']
|
||||
LORA_RANK = 8
|
||||
LORA_ALPHA = 32
|
||||
LORA_DROPOUT_P = 0 # Arbitrary value
|
||||
@@ -36,7 +31,8 @@ _, test_dataset = get_alpaca_en_zh_dataset(None, True)
|
||||
|
||||
# ### Inference
|
||||
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
for d in test_dataset[:5]:
|
||||
mini_test_dataset = test_dataset.select(range(5))
|
||||
for d in mini_test_dataset:
|
||||
output = d['output']
|
||||
d['output'] = None
|
||||
input_ids = tokenize_function(d, tokenizer)['input_ids']
|
||||
@@ -48,9 +44,10 @@ for d in test_dataset[:5]:
|
||||
max_new_tokens=512,
|
||||
attention_mask=attention_mask,
|
||||
streamer=streamer,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
temperature=0.7,
|
||||
top_k=50,
|
||||
top_p=0.7,
|
||||
do_sample=True)
|
||||
print()
|
||||
print(f'[LABELS]{output}')
|
||||
|
||||
Reference in New Issue
Block a user