Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
fix baichuan eval and support sequence_length
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13404289

* fix baichuan eval
* support sequence_length and ppl
* fix typo
* fix bug for palm
* fix bug
@@ -1,10 +1,11 @@
 import os
 import sys
+import types
 from dataclasses import dataclass, field

 from transformers import AutoModelForCausalLM, AutoTokenizer

-from modelscope import (EpochBasedTrainer, MsDataset, TrainingArgs,
+from modelscope import (EpochBasedTrainer, MsDataset, TorchModel, TrainingArgs,
                         build_dataset_from_file, snapshot_download)
 from modelscope.metainfo import Trainers
 from modelscope.preprocessors import TextGenerationTransformersPreprocessor
@@ -40,11 +41,11 @@ class TextGenerationArguments(TrainingArgs):
             'cfg_node': 'preprocessor.tgt_txt'
         })

-    preprocessor: str = field(
+    sequence_length: int = field(
         default=None,
         metadata={
-            'help': 'The preprocessor type',
-            'cfg_node': 'preprocessor.type'
+            'help': 'The sequence length of preprocessor',
+            'cfg_node': 'preprocessor.sequence_length'
         })

     lr_scheduler: str = field(
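The new sequence_length argument reaches the trainer configuration through its 'cfg_node' metadata, which names a dotted path inside the config tree. The snippet below is only a rough illustration of how such a dotted path can be resolved into a nested dict; the set_by_dotted_path helper is hypothetical and is not the modelscope TrainingArgs/Config implementation.

# Illustrative only: resolve a dotted 'cfg_node' path into a nested config dict.
def set_by_dotted_path(cfg: dict, dotted_key: str, value) -> None:
    keys = dotted_key.split('.')
    node = cfg
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value

cfg = {}
set_by_dotted_path(cfg, 'preprocessor.sequence_length', 128)
assert cfg == {'preprocessor': {'sequence_length': 128}}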
@@ -54,25 +55,6 @@ class TextGenerationArguments(TrainingArgs):
             'cfg_node': 'train.lr_scheduler.type'
         })

-    world_size: int = field(
-        default=None,
-        metadata={
-            'help': 'The parallel world size',
-            'cfg_node': 'megatron.world_size'
-        })
-
-    tensor_model_parallel_size: int = field(
-        default=None,
-        metadata={
-            'help': 'The tensor model parallel size',
-            'cfg_node': 'megatron.tensor_model_parallel_size'
-        })
-
-    use_megatron: bool = field(
-        default=None, metadata={
-            'help': 'Whether to use MegatronHook',
-        })
-
     bf16: bool = field(
         default=False,
         metadata={
@@ -144,17 +126,18 @@ def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer,

 config, args = TextGenerationArguments().parse_cli().to_config()
 print(config, args)
+pipeline_type = None


 def cfg_modify_fn(cfg):
+    global pipeline_type
+    pipeline_type = cfg.pipeline.type
     if args.use_model_config:
         cfg.merge_from_dict(config)
     else:
         cfg = config
     if 'hooks' not in cfg.train:
         cfg.train['hooks'] = []
-    if args.use_megatron:
-        cfg.train.hooks.append({'type': 'MegatronHook'})
     if args.deepspeed:
         cfg.train.hooks.append({
             'type': 'DeepspeedHook',
@@ -166,6 +149,13 @@ def cfg_modify_fn(cfg):
     return cfg


+def custom_save_pretrained(self, *args, **kwargs):
+    config = kwargs.pop('config')
+    if config is not None:
+        config.pipeline = {'type': pipeline_type}
+    TorchModel.save_pretrained(self, *args, config=config, **kwargs)
+
+
 if args.dataset_json_file is None:
     train_dataset = MsDataset.load(
         args.train_dataset_name,
@@ -185,6 +175,8 @@ model_dir = snapshot_download(args.model)
 sys.path.append(model_dir)
 model = AutoModelForCausalLM.from_pretrained(
     model_dir, trust_remote_code=True, device_map=args.device_map)
+model.model_dir = model_dir
+model.save_pretrained = types.MethodType(custom_save_pretrained, model)
 cfg_file = os.path.join(model_dir, 'configuration.json')
 tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

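The two added lines attach the snapshot directory to the model and rebind save_pretrained to custom_save_pretrained with types.MethodType, so the pipeline type captured in cfg_modify_fn is written back into the saved configuration. Below is a minimal, self-contained sketch of that instance-level rebinding pattern; the Saver class and patched_save function are invented names for illustration and are not part of modelscope.

# Sketch of the types.MethodType pattern: bind a plain function to one instance
# so that only this object gets the new behaviour.
import types


class Saver:

    def save(self, path):
        print(f'default save to {path}')


def patched_save(self, path):
    # Do some extra bookkeeping, then delegate to the original implementation.
    print('recording pipeline type before saving')
    Saver.save(self, path)


saver = Saver()
saver.save = types.MethodType(patched_save, saver)
saver.save('./tmp')  # runs the patched version; other Saver instances are untouched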
@@ -208,7 +200,8 @@ preprocessor = TextGenerationTransformersPreprocessor(
     model_dir,
     tokenizer=tokenizer,
     src_txt=config.preprocessor.src_txt,
-    tgt_txt=config.preprocessor.tgt_txt)
+    tgt_txt=config.preprocessor.tgt_txt,
+    sequence_length=getattr(config.preprocessor, 'sequence_length', None))

 if args.use_lora != 0:
     lora_config = LoRAConfig(
@@ -228,7 +221,7 @@ kwargs = dict(
     seed=args.seed,
     cfg_modify_fn=cfg_modify_fn,
     # No placement for model, leave the model to `device_map`
-    device='cpu')
+    device='cpu' if args.device_map else 'gpu')

 trainer: EpochBasedTrainer = build_trainer(
     name=args.trainer, default_args=kwargs)
examples/pytorch/baichuan/lora_inference.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+import os.path as osp
+
+import torch
+
+from modelscope.pipelines import pipeline
+from modelscope.swift import Swift
+from modelscope.swift.lora import LoRAConfig
+from modelscope.utils.constant import Tasks
+
+# Initialize the pipeline with the source model_id
+model_id = 'baichuan-inc/baichuan-7B'
+pipe = pipeline(
+    task=Tasks.text_generation, model=model_id, model_revision='v1.0.2')
+# LoRA config: replace_modules, rank and alpha must match the training parameters
+lora_config = LoRAConfig(replace_modules=['pack'], rank=32, lora_alpha=32)
+# Cast to bf16; this must match the training precision
+model = pipe.model.bfloat16()
+# Wrap the model with LoRA
+Swift.prepare_model(model, lora_config)
+# Load the LoRA parameters; by default they are linked under the output/model path
+work_dir = './tmp'
+state_dict = torch.load(osp.join(work_dir, 'output/pytorch_model.bin'))
+model.load_state_dict(state_dict)
+# Replace the model in the pipeline with the LoRA model
+pipe.model = model
+# Run inference through the pipeline
+result_zh = pipe('今天天气是真的')
+print(result_zh)
@@ -9,13 +9,23 @@ torchrun examples/pytorch/baichuan/finetune_baichuan.py \
     --val_split 'test' \
     --src_txt 'text1' \
     --tgt_txt 'text2' \
-    --max_epochs 1 \
+    --sequence_length 128 \
+    --max_epochs 2 \
     --per_device_train_batch_size 8 \
     --per_device_eval_batch_size 32 \
     --train_data_worker 0 \
     --eval_data_worker 0 \
     --optimizer 'AdamW' \
     --lr 2e-5 \
     --lr_scheduler 'CosineAnnealingLR' \
-    --eval_strategy 'no' \
+    --eval_strategy 'by_epoch' \
+    --bf16 1 \
+    --use_lora 1 \
-    --eval_metrics 'text-gen-metric' \
+    --use_model_config 1 \
+    --eval_metrics 'ppl' \
+    --T_max 1 \
+    --save_strategy no \
+    --save_best true \
+    --metric_for_best_model ppl \
+    --metric_rule_for_best_model min \
+    --device_map 'auto' \
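The revised script turns evaluation on per epoch and tracks perplexity, keeping the checkpoint with the lowest value. Perplexity is the exponential of the average per-token cross-entropy, so lower is better, which is why --metric_rule_for_best_model min is used. A tiny sketch of that relationship, with an assumed example loss value:

# Perplexity from an average per-token negative log-likelihood: ppl = exp(mean_nll).
import math

mean_nll = 2.1  # assumed example value for illustration
print(f'perplexity = {math.exp(mean_nll):.2f}')  # ~8.17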
@@ -1,13 +1,13 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

-from collections.abc import Mapping
+from typing import Any, Dict

 import torch

-from modelscope.metainfo import Trainers
+from modelscope.metainfo import Metrics, Trainers
 from modelscope.outputs.outputs import ModelOutputBase
 from modelscope.trainers import NlpEpochBasedTrainer
 from modelscope.trainers.builder import TRAINERS
 from modelscope.utils.file_utils import func_receive_dict_inputs


 @TRAINERS.register_module(module_name=Trainers.text_generation_trainer)
@@ -20,12 +20,19 @@ class TextGenerationTrainer(NlpEpochBasedTrainer):
     def evaluation_step(self, data):
         model = self.model.module if self._dist else self.model
         model.eval()
+        output = dict()

         with torch.no_grad():
-            result = model.generate(data)
+            output.update(self._eval_genarate(model, data))
+            if Metrics.PPL in self.metrics or Metrics.loss_metric in self.metrics:
+                output.update(model.forward(**data))
+        return output
+
+    def _eval_genarate(self, model, data) -> Dict[str, Any]:
+        result = model.generate(data)
         if isinstance(result, ModelOutputBase):
             result = result.to_dict()
         result['preds'] = [self._decode(seq) for seq in result['sequences']]
         data['tgts'] = [self._decode(seq) for seq in data['labels']]
         assert len(result['preds']) == len(data['tgts'])

         return result
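_eval_genarate relies on self._decode (not shown in this hunk) to turn generated token ids and the gold labels back into text, so preds and tgts can be compared by a text-generation metric. A rough, hypothetical stand-in for that decoding step, assuming a Hugging Face tokenizer, is sketched below; the tokenizer name and token ids are placeholders and the actual modelscope helper may differ.

# Hypothetical sketch: decode token id sequences into strings for a text metric.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')


def decode(seq):
    return tokenizer.decode(seq, skip_special_tokens=True)


generated = [[15496, 995]]  # placeholder generated token ids
labels = [[15496, 995]]     # placeholder reference token ids
preds = [decode(seq) for seq in generated]
tgts = [decode(seq) for seq in labels]
assert len(preds) == len(tgts)
print(preds, tgts)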