Add lora_inference for baichuan. (#352)

* add lora_inference.py for baichuan

* fix linttest

* fix linttest

---------

Co-authored-by: hemu <hemu.zp@alibaba-inc.com>
This commit is contained in:
Firmament-cyou
2023-07-04 18:39:36 +08:00
committed by GitHub
parent 08c71f1f3d
commit 423e2ce940
2 changed files with 171 additions and 103 deletions

View File

@@ -0,0 +1,28 @@
import os.path as osp
import torch
from modelscope.pipelines import pipeline
from modelscope.swift import Swift
from modelscope.swift.lora import LoRAConfig
from modelscope.utils.constant import Tasks
# 使用源模型 model_id 初始化 pipeline
model_id = 'baichuan-inc/baichuan-7B'
pipe = pipeline(
task=Tasks.text_generation, model=model_id, model_revision='v1.0.2')
# lora 配置replace_modulesrankalpha 需与训练参数相同
lora_config = LoRAConfig(replace_modules=['pack'], rank=32, lora_alpha=32)
# 转 bf16需与训练精度相同
model = pipe.model.bfloat16()
# model 转 lora
Swift.prepare_model(model, lora_config)
# 加载 lora 参数,默认 link 到于 output/model 路径
work_dir = './tmp'
state_dict = torch.load(osp.join(work_dir, 'output/pytorch_model.bin'))
model.load_state_dict(state_dict)
# 使用 lora model 替换 pipeline 中的 model
pipe.model = model
# 使用 pipeline 推理
result_zh = pipe('今天天气是真的')
print(result_zh)