mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 13:15:06 +02:00
[to #42322933] Add polylm lora trainer to modelscope
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13509769 * [to #42322933] Add polylm lora trainer to modelscope
This commit is contained in:
committed by
wenmeng.zwm
parent
33605de759
commit
5e83523c9a
@@ -93,6 +93,7 @@ def llm_infer(args: InferArguments) -> None:
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
do_sample=True,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.eos_token_id)
|
||||
logger.info(f'generation_config: {generation_config}')
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
CUDA_VISIBLE_DEVICES=0,1 \
|
||||
python llm_infer.py \
|
||||
--model_type qwen-7b \
|
||||
--ckpt_path "runs/qwen-7b/vx_xxx/output_best/pytorch_model.bin" \
|
||||
--model_type polylm-13b \
|
||||
--ckpt_path "runs/polylm-13b/v0-20230802-172425/output_best/pytorch_model.bin" \
|
||||
--eval_human true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
CUDA_VISIBLE_DEVICES=0,1 \
|
||||
python llm_sft.py \
|
||||
--model_type qwen-7b \
|
||||
--model_type polylm-13b \
|
||||
--output_dir runs \
|
||||
--dataset alpaca-en,alpaca-zh \
|
||||
--dataset_sample 20000
|
||||
|
||||
@@ -42,6 +42,26 @@ def get_model_tokenizer_default(model_dir: str,
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_model_tokenizer_polylm(model_dir: str,
                               torch_dtype: Dtype,
                               load_model: bool = True):
    """Create the tokenizer and, optionally, the causal-LM for PolyLM.

    Args:
        model_dir: Directory passed to every ``from_pretrained`` call.
        torch_dtype: Dtype written onto the loaded config and forwarded to
            the model loader.
        load_model: When false, only the tokenizer is built and the model
            slot of the result is ``None``.

    Returns:
        A ``(model, tokenizer)`` tuple; ``model`` is ``None`` if
        ``load_model`` is false.
    """
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    config.torch_dtype = torch_dtype
    logger.info(f'model_config: {config}')

    # The slow tokenizer implementation is requested explicitly.
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
    if not load_model:
        return None, tokenizer

    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        config=config,
        device_map='auto',
        torch_dtype=torch_dtype,
        trust_remote_code=True)
    return model, tokenizer
|
||||
|
||||
|
||||
def get_model_tokenizer_chatglm2(model_dir: str,
|
||||
torch_dtype: Dtype,
|
||||
load_model: bool = True):
|
||||
@@ -89,6 +109,7 @@ class LoRATM(NamedTuple):
|
||||
chatglm2 = ['query_key_value']
|
||||
llama2 = ['q_proj', 'k_proj', 'v_proj']
|
||||
qwen = ['c_attn']
|
||||
polylm = ['c_attn']
|
||||
|
||||
|
||||
# Reference: 'https://modelscope.cn/models/{model_id}/summary'
|
||||
@@ -135,6 +156,13 @@ MODEL_MAPPER = {
|
||||
'get_function': get_model_tokenizer_qwen,
|
||||
'torch_dtype': torch.bfloat16,
|
||||
'lora_TM': LoRATM.qwen,
|
||||
},
|
||||
'polylm-13b': {
|
||||
'model_id': 'damo/nlp_polylm_13b_text_generation',
|
||||
'revision': 'v1.0.3',
|
||||
'get_function': get_model_tokenizer_polylm,
|
||||
'torch_dtype': torch.bfloat16,
|
||||
'lora_TM': LoRATM.polylm
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user