Files
modelscope/examples/pytorch/baichuan/lora_inference.py
Firmament-cyou 423e2ce940 Add lora_inference for baichuan. (#352)
* add lora_inference.py for baichuan

* fix linttest

* fix linttest

---------

Co-authored-by: hemu <hemu.zp@alibaba-inc.com>
2023-07-04 18:39:36 +08:00

29 lines
989 B
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os.path as osp
import torch
from modelscope.pipelines import pipeline
from modelscope.swift import Swift
from modelscope.swift.lora import LoRAConfig
from modelscope.utils.constant import Tasks
# Run text-generation inference with a LoRA-tuned baichuan-7B model.
#
# Steps: build a ModelScope pipeline around the base model, inject LoRA
# layers into the pipeline's model, load the fine-tuned weights from disk,
# then run one prompt through the pipeline and print the result.

# Build the pipeline from the source (base) model id and revision.
model_id = 'baichuan-inc/baichuan-7B'
pipe = pipeline(
    task=Tasks.text_generation, model=model_id, model_revision='v1.0.2')

# LoRA settings: replace_modules, rank and lora_alpha must be the same as
# the values used during training.
lora_config = LoRAConfig(replace_modules=['pack'], rank=32, lora_alpha=32)

# Cast to bfloat16; this must match the precision used during training.
model = pipe.model.bfloat16()

# Inject LoRA layers into the model. The return value is not used here, so
# this presumably modifies `model` in place -- NOTE(review): confirm against
# modelscope.swift.
Swift.prepare_model(model, lora_config)

# Load the fine-tuned LoRA weights from <work_dir>/output/pytorch_model.bin.
# work_dir should presumably point at the training run's work directory;
# adjust it as needed.
work_dir = './tmp'
ckpt_path = osp.join(work_dir, 'output', 'pytorch_model.bin')
# Deserialize onto the CPU: load_state_dict copies the values into the
# model's existing parameters, so the checkpoint never has to be placed on
# the device it was saved from (which may be missing or busy at inference).
# torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(state_dict)

# Make the pipeline use the LoRA-augmented model explicitly.
pipe.model = model

# Run inference. The prompt is a completion-style Chinese sentence start
# ("The weather today is really ...").
result_zh = pipe('今天天气是真的')
print(result_zh)